
Commit f791ffc

Update chunking
1 parent b952d49 commit f791ffc

File tree

1 file changed: +42 -20 lines changed


comfy/ldm/kandinsky5/model.py

Lines changed: 42 additions & 20 deletions
@@ -110,23 +110,43 @@ def __init__(self, num_channels, head_dim, operation_settings=None):
         self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 2
 
-    def forward(self, x, freqs, transformer_options={}):
-        def compute_q(x):
-            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.query_norm(q), freqs)
-
-        def compute_k(x):
-            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.key_norm(k), freqs)
+    def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+        return apply_rope1(norm_fn(result), freqs)
 
-        q = compute_q(x)
-        k = compute_k(x)
+    def _forward(self, x, freqs, transformer_options={}):
+        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        return self.out_layer(out)
 
+    def _forward_chunked(self, x, freqs, transformer_options={}):
+        def process_chunks(proj_fn, norm_fn):
+            B, L, _ = x.shape
+            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            chunks = []
+            for i in range(0, L, chunk_size):
+                end_idx = min(i + chunk_size, L)
+                x_chunk = x[:, i:end_idx]
+                freqs_chunk = freqs[:, i:end_idx]
+                chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+            return torch.cat(chunks, dim=1)
+
+        q = process_chunks(self.to_query, self.query_norm)
+        k = process_chunks(self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
+    def forward(self, x, freqs, transformer_options={}):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+        else:
+            return self._forward(x, freqs, transformer_options=transformer_options)
+
 
 class CrossAttention(SelfAttention):
     def get_qkv(self, x, context):
@@ -150,22 +170,24 @@ def __init__(self, dim, ff_dim, operation_settings=None):
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.num_chunks = 4
 
-    def forward(self, x):
-        #return self.out_layer(self.activation(self.in_layer(x)))
-        # ffn is the peak memory consumer, chunking here helps
-        B, L, C = x.shape
+    def _forward(self, x):
+        return self.out_layer(self.activation(self.in_layer(x)))
+
+    def _forward_chunked(self, x):
+        B, L, _ = x.shape
         chunk_size = (L + self.num_chunks - 1) // self.num_chunks
         output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
-
         for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
-            def compute_chunk(x_chunk):
-                activated = self.activation(self.in_layer(x_chunk))
-                return self.out_layer(activated)
-            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
-
+            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
         return output
 
+    def forward(self, x):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x)
+        else:
+            return self._forward(x)
+
 
 class OutLayer(nn.Module):
     def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
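
For context, here is a minimal self-contained sketch of the sequence-chunking pattern this commit applies, not the actual ComfyUI module: the FFN intermediate activation of shape (B, L, ff_dim) is the peak memory consumer, so long sequences are processed in slices along L and written into a preallocated output, keeping only one chunk's activation alive at a time. Module names, the GELU activation, and the configurable threshold are illustrative assumptions; the real model uses comfy's operations.Linear and a fixed cutoff of 8192 tokens.

# Hypothetical, minimal sketch of the chunked-FFN pattern (assumed names; not the ComfyUI code).
import torch
import torch.nn as nn

class ChunkedFeedForward(nn.Module):
    def __init__(self, dim, ff_dim, num_chunks=4, chunk_threshold=8192):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()  # assumption; the real model may use a different activation
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks
        self.chunk_threshold = chunk_threshold

    def _forward(self, x):
        # Single pass: materializes the full (B, L, ff_dim) intermediate activation.
        return self.out_layer(self.activation(self.in_layer(x)))

    def _forward_chunked(self, x):
        # Slice along the sequence dimension so only one chunk's intermediate
        # activation is alive at a time; results go into a preallocated output.
        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks  # ceil division
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
        return output

    def forward(self, x):
        # Chunk only for long sequences; short ones keep the cheaper single-pass path.
        if x.shape[1] > self.chunk_threshold:
            return self._forward_chunked(x)
        return self._forward(x)

if __name__ == "__main__":
    ffn = ChunkedFeedForward(dim=64, ff_dim=256)
    x = torch.randn(1, 10000, 64)
    with torch.no_grad():
        full, chunked = ffn._forward(x), ffn(x)  # 10000 > 8192, so ffn(x) takes the chunked path
    print((full - chunked).abs().max())  # ~0: chunking changes memory use, not the result

The SelfAttention change in the first hunk applies the same idea to the q/k projection and RoPE step only: each chunk is projected and rotated independently, the chunks are concatenated back with torch.cat, and the attention itself still runs over the full sequence, so the result is unchanged while the transient projection buffers shrink.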

0 commit comments
