@@ -111,6 +111,7 @@ def __init__(
111111 smooth_l1_loss = False ,
112112 temperature = 0.9 ,
113113 straight_through = False ,
114+ reinmax = False ,
114115 kl_div_loss_weight = 0. ,
115116 normalization = ((* ((0.5 ,) * 3 ), 0 ), (* ((0.5 ,) * 3 ), 1 ))
116117 ):
@@ -125,6 +126,8 @@ def __init__(
125126 self .num_layers = num_layers
126127 self .temperature = temperature
127128 self .straight_through = straight_through
129+ self .reinmax = reinmax
130+
128131 self .codebook = nn .Embedding (num_tokens , codebook_dim )
129132
130133 hdim = hidden_dim
@@ -227,8 +230,20 @@ def forward(
227230 return logits # return logits for getting hard image indices for DALL-E training
228231
229232 temp = default (temp , self .temperature )
230- soft_one_hot = F .gumbel_softmax (logits , tau = temp , dim = 1 , hard = self .straight_through )
231- sampled = einsum ('b n h w, n d -> b d h w' , soft_one_hot , self .codebook .weight )
233+
234+ one_hot = F .gumbel_softmax (logits , tau = temp , dim = 1 , hard = self .straight_through )
235+
236+ if self .straight_through and self .reinmax :
237+ # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
238+ # algorithm 2
239+ one_hot = one_hot .detach ()
240+ π0 = logits .softmax (dim = 1 )
241+ π1 = (one_hot + (logits / temp ).softmax (dim = 1 )) / 2
242+ π1 = ((π1 .log () - logits ).detach () + logits ).softmax (dim = 1 )
243+ π2 = 2 * π1 - 0.5 * π0
244+ one_hot = π2 - π2 .detach () + one_hot
245+
246+ sampled = einsum ('b n h w, n d -> b d h w' , one_hot , self .codebook .weight )
232247 out = self .decoder (sampled )
233248
234249 if not return_loss :
0 commit comments