@@ -675,6 +675,53 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
 
     return key_map
 
+def z_image_to_diffusers(mmdit_config, output_prefix=""):
+    n_layers = mmdit_config.get("n_layers", 0)
+    hidden_size = mmdit_config.get("dim", 0)
+
+    key_map = {}
+
+    for index in range(n_layers):
+        prefix_from = "layers.{}".format(index)
+        prefix_to = "{}layers.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attention.".format(prefix_from)
+            qkv = "{}.attention.qkv.{}".format(prefix_to, end)
+
+            # Separate q/k/v projections map to slices of the fused qkv
+            # tensor, expressed as (target_key, (dim, start, length)).
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {
+            "attention.norm_q.weight": "attention.q_norm.weight",
+            "attention.norm_k.weight": "attention.k_norm.weight",
+            "attention.to_out.0.weight": "attention.out.weight",
+            "attention.to_out.0.bias": "attention.out.bias",
+        }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    MAP_BASIC = {
+        # Final layer
+        ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
+        ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
+        ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
+        ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
+        # X embedder
+        ("x_embedder.weight", "all_x_embedder.2-1.weight"),
+        ("x_embedder.bias", "all_x_embedder.2-1.bias"),
+    }
+
+    for k in MAP_BASIC:
+        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
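For context on how the tuple values are consumed: string values in `key_map` are plain renames, while `(target_key, (dim, start, length))` values tell the conversion machinery (presumably the same path that handles `flux_to_diffusers` maps) to copy the source tensor into a slice of a fused target tensor. Below is a minimal standalone sketch of that consumption; `apply_key_map` and the config values are made up for illustration and are not the repo's actual conversion code.

```python
import torch

def apply_key_map(state_dict, key_map):
    # Illustrative stand-in for the real conversion path. String values are
    # straight renames; tuple values (target_key, (dim, start, length)) write
    # the source tensor into a slice of a fused target tensor (here, the
    # separate to_q/to_k/to_v projections fuse into one qkv tensor).
    out = {}
    for src_key, dest in key_map.items():
        if src_key not in state_dict:
            continue
        w = state_dict[src_key]
        if isinstance(dest, str):
            out[dest] = w
            continue
        target_key, (dim, start, length) = dest
        if target_key not in out:
            # Assumes the fused tensor packs exactly three equal-size
            # projections along `dim`, matching the qkv entries built above.
            shape = list(w.shape)
            shape[dim] = length * 3
            out[target_key] = torch.empty(shape, dtype=w.dtype)
        out[target_key].narrow(dim, start, length).copy_(w)
    return out

# Hypothetical config values; a real model's n_layers/dim come from its config.
key_map = z_image_to_diffusers({"n_layers": 2, "dim": 64}, output_prefix="model.")
```

The `3 * length` allocation is an assumption that holds for the offsets generated above (q at 0, k at `hidden_size`, v at `2 * hidden_size`, each of size `hidden_size`), not a general property of these maps.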