Commit 52a32e2

Support some z image lora formats. (#10978)
1 parent b907085 commit 52a32e2

File tree

comfy/lora.py
comfy/utils.py

2 files changed: +53 -0 lines changed


comfy/lora.py

Lines changed: 8 additions & 0 deletions
@@ -313,6 +313,14 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map["transformer.{}".format(key_lora)] = k
                 key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format

+    if isinstance(model, comfy.model_base.Lumina2):
+        diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            to = diffusers_keys[k]
+            key_lora = k[:-len(".weight")]
+            key_map["diffusion_model.{}".format(key_lora)] = to
+            key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+
     return key_map
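
For illustration, a minimal sketch (not part of the commit) of what this new Lumina2 branch puts into key_map, assuming a hypothetical toy config with a single layer and dim=8 in place of a real model's unet_config:

import comfy.utils

# Toy config: 1 transformer layer, hidden size 8 (real values come from the model config).
diffusers_keys = comfy.utils.z_image_to_diffusers({"n_layers": 1, "dim": 8},
                                                  output_prefix="diffusion_model.")

key_map = {}
for k in diffusers_keys:  # mirror the loop added in the diff above
    to = diffusers_keys[k]
    key_lora = k[:-len(".weight")]
    key_map["diffusion_model.{}".format(key_lora)] = to
    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

# Both spellings of the same LoRA key resolve to the same internal target:
print(key_map["diffusion_model.layers.0.attention.to_q"])
print(key_map["lycoris_layers_0_attention_to_q"])
# -> ('diffusion_model.layers.0.attention.qkv.weight', (0, 0, 8)) for both

Mapping both the diffusers-style key and the SimpleTuner lycoris spelling to the same target is what lets either LoRA file format load against the same model weights.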

comfy/utils.py

Lines changed: 45 additions & 0 deletions
@@ -675,6 +675,51 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):

     return key_map

+def z_image_to_diffusers(mmdit_config, output_prefix=""):
+    n_layers = mmdit_config.get("n_layers", 0)
+    hidden_size = mmdit_config.get("dim", 0)
+
+    key_map = {}
+
+    for index in range(n_layers):
+        prefix_from = "layers.{}".format(index)
+        prefix_to = "{}layers.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attention.".format(prefix_from)
+            qkv = "{}.attention.qkv.{}".format(prefix_to, end)
+
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {
+            "attention.norm_q.weight": "attention.q_norm.weight",
+            "attention.norm_k.weight": "attention.k_norm.weight",
+            "attention.to_out.0.weight": "attention.out.weight",
+            "attention.to_out.0.bias": "attention.out.bias",
+        }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    MAP_BASIC = {
+        # Final layer
+        ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
+        ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
+        ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
+        ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
+        # X embedder
+        ("x_embedder.weight", "all_x_embedder.2-1.weight"),
+        ("x_embedder.bias", "all_x_embedder.2-1.bias"),
+    }
+
+    for k in MAP_BASIC:
+        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
