Commit 329480d
Fix qwen scaled fp8 not working with kandinsky. Make basic t2i wf work. (#11162)
Parent: 4086acf

2 files changed: 13 additions & 7 deletions

2 files changed

+13
-7
lines changed

comfy/ldm/kandinsky5/model.py

Lines changed: 7 additions & 1 deletion
@@ -387,6 +387,9 @@ def block_wrap(args):
         return self.out_layer(visual_embed, time_embed)

     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+        original_dims = x.ndim
+        if original_dims == 4:
+            x = x.unsqueeze(2)
         bs, c, t_len, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

@@ -397,7 +400,10 @@ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_o
         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)

-        return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        if original_dims == 4:
+            out = out.squeeze(2)
+        return out

     def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
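For context, here is a minimal standalone sketch (not the repository's code) of the shape-handling pattern this hunk introduces: the Kandinsky 5 DiT expects a 5-D video latent (batch, channels, time, height, width), so a 4-D image latent from a basic text-to-image workflow is given a singleton time axis before the forward pass and squeezed back out afterwards. The run_dit wrapper and the dummy model below are hypothetical, used only to illustrate the unsqueeze/squeeze round trip.

import torch

def run_dit(x: torch.Tensor, dit) -> torch.Tensor:
    # Hypothetical wrapper mirroring the pattern in _forward: accept either a
    # 4-D image latent (B, C, H, W) or a 5-D video latent (B, C, T, H, W) and
    # always hand the model 5-D input.
    original_dims = x.ndim
    if original_dims == 4:
        x = x.unsqueeze(2)      # (B, C, H, W) -> (B, C, 1, H, W)
    out = dit(x)                # the model always sees a 5-D tensor
    if original_dims == 4:
        out = out.squeeze(2)    # back to (B, C, H, W) for image workflows
    return out

# A 4-D "image" latent passes through a dummy model with its rank preserved.
dummy = lambda t: t
print(run_dit(torch.zeros(1, 16, 64, 64), dummy).shape)  # torch.Size([1, 16, 64, 64])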

comfy/text_encoders/kandinsky5.py

Lines changed: 6 additions & 6 deletions
@@ -24,10 +24,10 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):

 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


@@ -56,12 +56,12 @@ def load_sd(self, sd):
         else:
             return super().load_sd(sd)

-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class Kandinsky5TEModel_(Kandinsky5TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
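The change renames the option keys so that the Qwen text encoder's quantization metadata (including scaled fp8) is carried through model_options instead of being dropped. Below is a minimal sketch of that forwarding pattern under illustrative names: Base, QwenTE, make_te, and the example metadata dict are hypothetical stand-ins, not ComfyUI classes; only the two option keys come from the diff.

# Minimal sketch of the option-forwarding pattern, assuming simplified classes.
class Base:
    def __init__(self, model_options={}):
        # The base class only understands the generic key.
        self.quantization_metadata = model_options.get("quantization_metadata")

class QwenTE(Base):
    def __init__(self, model_options={}):
        # Translate the model-specific key back to the generic one.
        meta = model_options.get("llama_quantization_metadata", None)
        if meta is not None:
            model_options = model_options.copy()
            model_options["quantization_metadata"] = meta
        super().__init__(model_options=model_options)

def make_te(llama_quantization_metadata=None):
    # Factory closes over the metadata and injects it per instance,
    # mirroring the role of te() in comfy/text_encoders/kandinsky5.py.
    def build(model_options={}):
        if llama_quantization_metadata is not None:
            model_options = model_options.copy()
            model_options["llama_quantization_metadata"] = llama_quantization_metadata
        return QwenTE(model_options=model_options)
    return build

# Placeholder metadata; the real structure is ComfyUI-internal.
te_builder = make_te(llama_quantization_metadata={"scaled_fp8": "e4m3fn"})
print(te_builder().quantization_metadata)  # {'scaled_fp8': 'e4m3fn'}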
