@@ -377,6 +377,7 @@ def __init__(
377377 z_image_modulation = False ,
378378 time_scale = 1.0 ,
379379 pad_tokens_multiple = None ,
380+ clip_text_dim = None ,
380381 image_model = None ,
381382 device = None ,
382383 dtype = None ,
@@ -447,6 +448,31 @@ def __init__(
447448 ),
448449 )
449450
451+ self .clip_text_pooled_proj = None
452+
453+ if clip_text_dim is not None :
454+ self .clip_text_dim = clip_text_dim
455+ self .clip_text_pooled_proj = nn .Sequential (
456+ operation_settings .get ("operations" ).RMSNorm (clip_text_dim , eps = norm_eps , elementwise_affine = True , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )),
457+ operation_settings .get ("operations" ).Linear (
458+ clip_text_dim ,
459+ clip_text_dim ,
460+ bias = True ,
461+ device = operation_settings .get ("device" ),
462+ dtype = operation_settings .get ("dtype" ),
463+ ),
464+ )
465+ self .time_text_embed = nn .Sequential (
466+ nn .SiLU (),
467+ operation_settings .get ("operations" ).Linear (
468+ min (dim , 1024 ) + clip_text_dim ,
469+ min (dim , 1024 ),
470+ bias = True ,
471+ device = operation_settings .get ("device" ),
472+ dtype = operation_settings .get ("dtype" ),
473+ ),
474+ )
475+
450476 self .layers = nn .ModuleList (
451477 [
452478 JointTransformerBlock (
@@ -585,6 +611,15 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
585611
586612 cap_feats = self .cap_embedder (cap_feats ) # (N, L, D) # todo check if able to batchify w.o. redundant compute
587613
614+ if self .clip_text_pooled_proj is not None :
615+ pooled = kwargs .get ("clip_text_pooled" , None )
616+ if pooled is not None :
617+ pooled = self .clip_text_pooled_proj (pooled )
618+ else :
619+ pooled = torch .zeros ((1 , self .clip_text_dim ), device = x .device , dtype = x .dtype )
620+
621+ adaln_input = self .time_text_embed (torch .cat ((t , pooled ), dim = - 1 ))
622+
588623 patches = transformer_options .get ("patches" , {})
589624 x_is_tensor = isinstance (x , torch .Tensor )
590625 img , mask , img_size , cap_size , freqs_cis = self .patchify_and_embed (x , cap_feats , cap_mask , t , num_tokens , transformer_options = transformer_options )
0 commit comments