@@ -306,6 +306,21 @@ def forward(
 
 # main DALL-E class
 
+class SharedEmbedding(nn.Embedding):
+    def __init__(self, linear, start_index, end_index, **kwargs):
+        super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
+        del self.weight
+
+        self.linear = linear
+        self.start_index = start_index
+        self.end_index = end_index
+
+    def forward(self, input):
+        return F.embedding(
+            input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+
 class DALLE(nn.Module):
     def __init__(
         self,
@@ -329,6 +344,7 @@ def __init__(
         rotary_emb = True,
         shared_attn_ids = None,
         shared_ff_ids = None,
+        share_input_output_emb = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -340,9 +356,6 @@ def __init__(
 
         num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
 
-        self.text_emb = nn.Embedding(num_text_tokens, dim)
-        self.image_emb = nn.Embedding(num_image_tokens, dim)
-
         self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
         self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)
 
@@ -391,6 +404,13 @@ def __init__(
             nn.Linear(dim, self.total_tokens),
         )
 
+        if share_input_output_emb:
+            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
+            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
+        else:
+            self.text_emb = nn.Embedding(num_text_tokens, dim)
+            self.image_emb = nn.Embedding(num_image_tokens, dim)
+
         seq_range = torch.arange(seq_len)
         logits_range = torch.arange(total_tokens)
 
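
For readers skimming the diff, a minimal standalone sketch (hypothetical sizes, not part of the commit) of the weight-tying idea behind SharedEmbedding: the text and image embedding tables become row slices of the output projection's weight matrix, so the same parameters serve both the input token lookup and the final logits.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
dim, num_text_tokens, num_image_tokens = 512, 10000, 8192
total_tokens = num_text_tokens + num_image_tokens

# Output projection whose weight matrix is reused as the embedding tables.
to_logits = nn.Linear(dim, total_tokens, bias = False)

# Rows [0, num_text_tokens) act as the text embedding table,
# rows [num_text_tokens, total_tokens) as the image embedding table.
text_ids = torch.randint(0, num_text_tokens, (2, 7))
text_embeds = F.embedding(text_ids, to_logits.weight[:num_text_tokens])

assert text_embeds.shape == (2, 7, dim)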