
Commit 5151cff

Add some missing z image lora layers. (#10980)
1 parent af96d98 commit 5151cff

2 files changed: +42 -20 lines changed


comfy/lora.py

Lines changed: 5 additions & 4 deletions

@@ -316,10 +316,11 @@ def model_lora_keys_unet(model, key_map={}):
     if isinstance(model, comfy.model_base.Lumina2):
         diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
-            to = diffusers_keys[k]
-            key_lora = k[:-len(".weight")]
-            key_map["diffusion_model.{}".format(key_lora)] = to
-            key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = k[:-len(".weight")]
+                key_map["diffusion_model.{}".format(key_lora)] = to
+                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

     return key_map
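Note: the new guard matters because z_image_to_diffusers now also emits ".bias" keys and suffix-less entries such as x_pad_token (see the comfy/utils.py diff below); unconditionally slicing len(".weight") characters off those keys would produce garbage LoRA names. A minimal sketch of the guarded loop, using made-up keys and plain string targets rather than the real mapping output:

# Illustrative stand-in keys, not real z_image_to_diffusers output.
diffusers_keys = {
    "layers.0.attention.to_q.weight": "diffusion_model.layers.0.attention.qkv.weight",
    "all_final_layer.2-1.linear.bias": "diffusion_model.final_layer.linear.bias",
    "x_pad_token": "diffusion_model.x_pad_token",
}

key_map = {}
for k in diffusers_keys:
    if k.endswith(".weight"):           # skips the ".bias" and pad-token entries
        to = diffusers_keys[k]
        key_lora = k[:-len(".weight")]  # "layers.0.attention.to_q"
        key_map["diffusion_model.{}".format(key_lora)] = to
        key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

# key_map ends up with exactly two aliases for the one weight tensor:
#   "diffusion_model.layers.0.attention.to_q"
#   "lycoris_layers_0_attention_to_q"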

comfy/utils.py

Lines changed: 37 additions & 16 deletions

@@ -678,17 +678,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
 def z_image_to_diffusers(mmdit_config, output_prefix=""):
     n_layers = mmdit_config.get("n_layers", 0)
     hidden_size = mmdit_config.get("dim", 0)
-
+    n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
+    n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
     key_map = {}

-    for index in range(n_layers):
-        prefix_from = "layers.{}".format(index)
-        prefix_to = "{}layers.{}".format(output_prefix, index)
-
+    def add_block_keys(prefix_from, prefix_to, has_adaln=True):
         for end in ("weight", "bias"):
             k = "{}.attention.".format(prefix_from)
             qkv = "{}.attention.qkv.{}".format(prefix_to, end)
-
             key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
@@ -698,28 +695,52 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
             "attention.norm_k.weight": "attention.k_norm.weight",
             "attention.to_out.0.weight": "attention.out.weight",
             "attention.to_out.0.bias": "attention.out.bias",
+            "attention_norm1.weight": "attention_norm1.weight",
+            "attention_norm2.weight": "attention_norm2.weight",
+            "feed_forward.w1.weight": "feed_forward.w1.weight",
+            "feed_forward.w2.weight": "feed_forward.w2.weight",
+            "feed_forward.w3.weight": "feed_forward.w3.weight",
+            "ffn_norm1.weight": "ffn_norm1.weight",
+            "ffn_norm2.weight": "ffn_norm2.weight",
         }
+        if has_adaln:
+            block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
+            block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
+        for k, v in block_map.items():
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)

-        for k in block_map:
-            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+    for i in range(n_layers):
+        add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))

-    MAP_BASIC = {
-        # Final layer
+    for i in range(n_context_refiner):
+        add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
+
+    for i in range(n_noise_refiner):
+        add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
+
+    MAP_BASIC = [
         ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
         ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
         ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
         ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
-        # X embedder
         ("x_embedder.weight", "all_x_embedder.2-1.weight"),
         ("x_embedder.bias", "all_x_embedder.2-1.bias"),
-    }
-
-    for k in MAP_BASIC:
-        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+        ("x_pad_token", "x_pad_token"),
+        ("cap_embedder.0.weight", "cap_embedder.0.weight"),
+        ("cap_embedder.1.weight", "cap_embedder.1.weight"),
+        ("cap_embedder.1.bias", "cap_embedder.1.bias"),
+        ("cap_pad_token", "cap_pad_token"),
+        ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
+        ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
+        ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
+        ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
+    ]
+
+    for c, diffusers in MAP_BASIC:
+        key_map[diffusers] = "{}{}".format(output_prefix, c)

     return key_map

-
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
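Note: with the rework, the same block mapping is emitted for all three block families (layers, context_refiner, noise_refiner) plus the embedder and pad-token tensors, which is what lets the previously missing z image LoRA layers resolve. A quick sanity sketch, assuming ComfyUI is importable and using hypothetical config values:

import comfy.utils

# Hypothetical config values; real model configs come from the checkpoint.
key_map = comfy.utils.z_image_to_diffusers({"n_layers": 2, "dim": 8},
                                           output_prefix="diffusion_model.")

# Refiner block keys are new in this commit:
assert "context_refiner.0.attention.to_q.weight" in key_map
assert "noise_refiner.1.feed_forward.w1.weight" in key_map
# So are the non-".weight" basics, which the lora.py guard now skips:
assert key_map["x_pad_token"] == "diffusion_model.x_pad_token"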
