@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
         dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)
 
-        # calculate causal attention for local convolution
+        # use padding of 0 on tensor of 1s and unfold for padding mask
 
         i, j = dots_image.shape[-2:]
-        img_seq = torch.arange(img_seq_len, device = device)
-        k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
-        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
-        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+        ones = torch.ones((img_seq_len,), device = device)
+        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+        ones = F.pad(ones, causal_padding, value = 0.)
+        ones = F.unfold(ones, kernel_size, dilation = dilation)
+        ones = rearrange(ones, 'b j i -> b i j')
 
         # mask image attention
 
-        padding_mask = k_img_indices == img_seq_len
+        padding_mask = ones == 0.
 
         # concat text mask with image causal mask
 
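The change drops the arange-based key indices in favor of a simpler trick: start from a tensor of ones, zero-pad it on the causal (top/left) side, and unfold it the same way the keys are unfolded. Any slot in a query's local window that came from padding unfolds to 0, which directly gives the padding mask. Below is a minimal standalone sketch of that idea; the values of img_size, kernel_size, dilation, and causal_padding are illustrative stand-ins for the ones computed elsewhere in the module, not taken from this diff.

import torch
import torch.nn.functional as F
from einops import rearrange

# Illustrative values; in the module these are derived from the model config.
img_size = 4                                # feature-map side length
img_seq_len = img_size ** 2                 # 16 image tokens
kernel_size = 3
dilation = 1
padding = dilation * (kernel_size - 1)
causal_padding = (padding, 0, padding, 0)   # pad left/top only, keeping attention causal

# Tensor of ones shaped like the image grid, zero-padded on the causal side.
ones = torch.ones((img_seq_len,))
ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
ones = F.pad(ones, causal_padding, value = 0.)

# Unfold into local windows, one per query position; padded slots come out as 0.
ones = F.unfold(ones, kernel_size, dilation = dilation)
ones = rearrange(ones, 'b j i -> b i j')

padding_mask = ones == 0.                   # True where a window slot is padding
print(padding_mask.shape)                   # torch.Size([1, 16, 9])
print(padding_mask[0, 0])                   # top-left query: only 1 of 9 slots is a real key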