@@ -678,17 +678,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
 def z_image_to_diffusers(mmdit_config, output_prefix=""):
     n_layers = mmdit_config.get("n_layers", 0)
     hidden_size = mmdit_config.get("dim", 0)
-
+    n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
+    n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
     key_map = {}
 
-    for index in range(n_layers):
-        prefix_from = "layers.{}".format(index)
-        prefix_to = "{}layers.{}".format(output_prefix, index)
-
+    def add_block_keys(prefix_from, prefix_to, has_adaln=True):
         for end in ("weight", "bias"):
             k = "{}.attention.".format(prefix_from)
             qkv = "{}.attention.qkv.{}".format(prefix_to, end)
-
             key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
@@ -698,28 +695,52 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
             "attention.norm_k.weight": "attention.k_norm.weight",
             "attention.to_out.0.weight": "attention.out.weight",
             "attention.to_out.0.bias": "attention.out.bias",
+            "attention_norm1.weight": "attention_norm1.weight",
+            "attention_norm2.weight": "attention_norm2.weight",
+            "feed_forward.w1.weight": "feed_forward.w1.weight",
+            "feed_forward.w2.weight": "feed_forward.w2.weight",
+            "feed_forward.w3.weight": "feed_forward.w3.weight",
+            "ffn_norm1.weight": "ffn_norm1.weight",
+            "ffn_norm2.weight": "ffn_norm2.weight",
         }
+        if has_adaln:
+            block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
+            block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
+        for k, v in block_map.items():
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
 
-        for k in block_map:
-            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+    for i in range(n_layers):
+        add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
 
-    MAP_BASIC = {
-        # Final layer
+    for i in range(n_context_refiner):
+        add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
+
+    for i in range(n_noise_refiner):
+        add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
+
+    MAP_BASIC = [
         ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
         ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
         ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
         ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
-        # X embedder
         ("x_embedder.weight", "all_x_embedder.2-1.weight"),
         ("x_embedder.bias", "all_x_embedder.2-1.bias"),
-    }
-
-    for k in MAP_BASIC:
-        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+        ("x_pad_token", "x_pad_token"),
+        ("cap_embedder.0.weight", "cap_embedder.0.weight"),
+        ("cap_embedder.1.weight", "cap_embedder.1.weight"),
+        ("cap_embedder.1.bias", "cap_embedder.1.bias"),
+        ("cap_pad_token", "cap_pad_token"),
+        ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
+        ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
+        ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
+        ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
+    ]
+
+    for c, diffusers in MAP_BASIC:
+        key_map[diffusers] = "{}{}".format(output_prefix, c)
 
     return key_map
 
-
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
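Note (not part of the commit above): key_map is keyed by the diffusers-style parameter name and maps to the prefixed name used by this repo. String values are plain renames; tuple values of the form (name, (dim, offset, size)) encode qkv fusion, where to_q/to_k/to_v land at offsets 0, hidden_size, and 2 * hidden_size along dim 0 of a single attention.qkv tensor. A minimal sketch of how such a map could be consumed, using a hypothetical apply_key_map helper that does not exist in the codebase:

import torch

def apply_key_map(diffusers_sd, key_map):
    # Illustrative helper only: convert a diffusers-format state dict using a
    # key map of the shape produced by z_image_to_diffusers().
    out = {}
    for k, v in diffusers_sd.items():
        target = key_map.get(k)
        if target is None:
            out[k] = v  # no mapping, keep the key unchanged
        elif isinstance(target, str):
            out[target] = v  # simple rename
        else:
            # (name, (dim, offset, size)): this tensor is one slice of a fused
            # destination tensor, e.g. to_q/to_k/to_v packed into one qkv tensor.
            name, (dim, offset, size) = target
            if name not in out:
                # Assumes three equal slices (q, k, v) along `dim`.
                shape = list(v.shape)
                shape[dim] = size * 3
                out[name] = torch.zeros(shape, dtype=v.dtype, device=v.device)
            out[name].narrow(dim, offset, size).copy_(v)
    return out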