Skip to content

Commit 73f5649

Browse files
authored
Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995)
* hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and takes these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe3 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe3 Initial commit. HEAD is now at 9d8a817 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). * model: Add temporal roll to main VAE encoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings).
1 parent 3f512f5 commit 73f5649

File tree

3 files changed

+175
-131
lines changed

3 files changed

+175
-131
lines changed

comfy/ldm/hunyuan_video/upsampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4-
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
4+
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
5+
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
56
import model_management, model_patcher
67

78
class SRResidualCausalBlock3D(nn.Module):

comfy/ldm/hunyuan_video/vae_refiner.py

Lines changed: 22 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,12 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4-
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
4+
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
55
import comfy.ops
66
import comfy.ldm.models.autoencoder
77
import comfy.model_management
88
ops = comfy.ops.disable_weight_init
99

10-
class NoPadConv3d(nn.Module):
11-
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
12-
super().__init__()
13-
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
14-
15-
def forward(self, x):
16-
return self.conv(x)
17-
18-
19-
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
20-
21-
x = xl[0]
22-
xl.clear()
23-
24-
if conv_carry_out is not None:
25-
to_push = x[:, :, -2:, :, :].clone()
26-
conv_carry_out.append(to_push)
27-
28-
if isinstance(op, NoPadConv3d):
29-
if conv_carry_in is None:
30-
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
31-
else:
32-
carry_len = conv_carry_in[0].shape[2]
33-
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
34-
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
35-
36-
out = op(x)
37-
38-
return out
39-
4010

4111
class RMS_norm(nn.Module):
4212
def __init__(self, dim):
@@ -49,7 +19,7 @@ def forward(self, x):
4919
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
5020

5121
class DnSmpl(nn.Module):
52-
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
22+
def __init__(self, ic, oc, tds, refiner_vae, op):
5323
super().__init__()
5424
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
5525
assert oc % fct == 0
@@ -109,7 +79,7 @@ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
10979

11080

11181
class UpSmpl(nn.Module):
112-
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
82+
def __init__(self, ic, oc, tus, refiner_vae, op):
11383
super().__init__()
11484
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
11585
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
163133

164134
return h + x
165135

166-
class HunyuanRefinerResnetBlock(ResnetBlock):
167-
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
168-
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
169-
170-
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
171-
h = x
172-
h = [ self.swish(self.norm1(x)) ]
173-
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
174-
175-
h = [ self.dropout(self.swish(self.norm2(h))) ]
176-
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
177-
178-
if self.in_channels != self.out_channels:
179-
x = self.nin_shortcut(x)
180-
181-
return x+h
182-
183136
class Encoder(nn.Module):
184137
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
185138
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
191144

192145
self.refiner_vae = refiner_vae
193146
if self.refiner_vae:
194-
conv_op = NoPadConv3d
147+
conv_op = CarriedConv3d
195148
norm_op = RMS_norm
196149
else:
197150
conv_op = ops.Conv3d
@@ -206,9 +159,10 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
206159

207160
for i, tgt in enumerate(block_out_channels):
208161
stage = nn.Module()
209-
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
210-
out_channels=tgt,
211-
conv_op=conv_op, norm_op=norm_op)
162+
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
163+
out_channels=tgt,
164+
temb_channels=0,
165+
conv_op=conv_op, norm_op=norm_op)
212166
for j in range(num_res_blocks)])
213167
ch = tgt
214168
if i < depth:
@@ -218,9 +172,9 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
218172
self.down.append(stage)
219173

220174
self.mid = nn.Module()
221-
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
175+
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
222176
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
223-
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
177+
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
224178

225179
self.norm_out = norm_op(ch)
226180
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ def forward(self, x):
246200
conv_carry_out = []
247201
if i == len(x) - 1:
248202
conv_carry_out = None
203+
249204
x1 = [ x1 ]
250205
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
251206

252207
for stage in self.down:
253208
for blk in stage.block:
254-
x1 = blk(x1, conv_carry_in, conv_carry_out)
209+
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
255210
if hasattr(stage, 'downsample'):
256211
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
257212

258213
out.append(x1)
259214
conv_carry_in = conv_carry_out
260215

261-
if len(out) > 1:
262-
out = torch.cat(out, dim=2)
263-
else:
264-
out = out[0]
216+
out = torch_cat_if_needed(out, dim=2)
265217

266218
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
267219
del out
@@ -288,7 +240,7 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
288240

289241
self.refiner_vae = refiner_vae
290242
if self.refiner_vae:
291-
conv_op = NoPadConv3d
243+
conv_op = CarriedConv3d
292244
norm_op = RMS_norm
293245
else:
294246
conv_op = ops.Conv3d
@@ -298,19 +250,20 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
298250
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
299251

300252
self.mid = nn.Module()
301-
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
253+
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
302254
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
303-
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
255+
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
304256

305257
self.up = nn.ModuleList()
306258
depth = (ffactor_spatial >> 1).bit_length()
307259
depth_temporal = (ffactor_temporal >> 1).bit_length()
308260

309261
for i, tgt in enumerate(block_out_channels):
310262
stage = nn.Module()
311-
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
312-
out_channels=tgt,
313-
conv_op=conv_op, norm_op=norm_op)
263+
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
264+
out_channels=tgt,
265+
temb_channels=0,
266+
conv_op=conv_op, norm_op=norm_op)
314267
for j in range(num_res_blocks + 1)])
315268
ch = tgt
316269
if i < depth:
@@ -340,7 +293,7 @@ def forward(self, z):
340293
conv_carry_out = None
341294
for stage in self.up:
342295
for blk in stage.block:
343-
x1 = blk(x1, conv_carry_in, conv_carry_out)
296+
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
344297
if hasattr(stage, 'upsample'):
345298
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
346299

@@ -350,10 +303,7 @@ def forward(self, z):
350303
conv_carry_in = conv_carry_out
351304
del x
352305

353-
if len(out) > 1:
354-
out = torch.cat(out, dim=2)
355-
else:
356-
out = out[0]
306+
out = torch_cat_if_needed(out, dim=2)
357307

358308
if not self.refiner_vae:
359309
if z.shape[-3] == 1:

0 commit comments

Comments
 (0)