Skip to content

Commit 56fa7db

Browse files
Properly load the newbie diffusion model. (#11172)
One of the text encoders is still missing, and this change has not actually been tested.
1 parent 329480d commit 56fa7db

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

comfy/ldm/lumina/model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def __init__(
377377
z_image_modulation=False,
378378
time_scale=1.0,
379379
pad_tokens_multiple=None,
380+
clip_text_dim=None,
380381
image_model=None,
381382
device=None,
382383
dtype=None,
@@ -447,6 +448,31 @@ def __init__(
447448
),
448449
)
449450

451+
self.clip_text_pooled_proj = None
452+
453+
if clip_text_dim is not None:
454+
self.clip_text_dim = clip_text_dim
455+
self.clip_text_pooled_proj = nn.Sequential(
456+
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
457+
operation_settings.get("operations").Linear(
458+
clip_text_dim,
459+
clip_text_dim,
460+
bias=True,
461+
device=operation_settings.get("device"),
462+
dtype=operation_settings.get("dtype"),
463+
),
464+
)
465+
self.time_text_embed = nn.Sequential(
466+
nn.SiLU(),
467+
operation_settings.get("operations").Linear(
468+
min(dim, 1024) + clip_text_dim,
469+
min(dim, 1024),
470+
bias=True,
471+
device=operation_settings.get("device"),
472+
dtype=operation_settings.get("dtype"),
473+
),
474+
)
475+
450476
self.layers = nn.ModuleList(
451477
[
452478
JointTransformerBlock(
@@ -585,6 +611,15 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
585611

586612
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
587613

614+
if self.clip_text_pooled_proj is not None:
615+
pooled = kwargs.get("clip_text_pooled", None)
616+
if pooled is not None:
617+
pooled = self.clip_text_pooled_proj(pooled)
618+
else:
619+
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
620+
621+
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
622+
588623
patches = transformer_options.get("patches", {})
589624
x_is_tensor = isinstance(x, torch.Tensor)
590625
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)

comfy/model_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,10 @@ def extra_conds(self, **kwargs):
11101110
if 'num_tokens' not in out:
11111111
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
11121112

1113+
clip_text_pooled = kwargs["pooled_output"] # Newbie
1114+
if clip_text_pooled is not None:
1115+
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
1116+
11131117
return out
11141118

11151119
class WAN21(BaseModel):

comfy/model_detection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
423423
dit_config["axes_lens"] = [300, 512, 512]
424424
dit_config["rope_theta"] = 10000.0
425425
dit_config["ffn_dim_multiplier"] = 4.0
426+
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
427+
if ctd_weight is not None:
428+
dit_config["clip_text_dim"] = ctd_weight.shape[0]
426429
elif dit_config["dim"] == 3840: # Z image
427430
dit_config["n_heads"] = 30
428431
dit_config["n_kv_heads"] = 30

0 commit comments

Comments
 (0)