@@ -517,11 +517,23 @@ def patchify_and_embed(
         B, C, H, W = x.shape
         x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
 
+        rope_options = transformer_options.get("rope_options", None)
+        h_scale = 1.0
+        w_scale = 1.0
+        h_start = 0
+        w_start = 0
+        if rope_options is not None:
+            h_scale = rope_options.get("scale_y", 1.0)
+            w_scale = rope_options.get("scale_x", 1.0)
+
+            h_start = rope_options.get("shift_y", 0.0)
+            w_start = rope_options.get("shift_x", 0.0)
+
         H_tokens, W_tokens = H // pH, W // pW
         x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
         x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
-        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+        x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
 
         if self.pad_tokens_multiple is not None:
             pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
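For context, here is a minimal standalone sketch of what the patched lines compute: the scale_x/scale_y and shift_x/shift_y entries of rope_options rescale and offset the row/column coordinates of the latent-token grid before they are used as RoPE position ids. The helper below is hypothetical (not part of the diff) and only mirrors the new position-id math on the two spatial axes.

import torch

def make_axis_pos_ids(h_tokens, w_tokens, rope_options=None, device="cpu"):
    # Hypothetical helper mirroring the patched position-id math in the diff above.
    opts = rope_options or {}
    h_scale = opts.get("scale_y", 1.0)
    w_scale = opts.get("scale_x", 1.0)
    h_start = opts.get("shift_y", 0.0)
    w_start = opts.get("shift_x", 0.0)

    rows = torch.arange(h_tokens, dtype=torch.float32, device=device) * h_scale + h_start
    cols = torch.arange(w_tokens, dtype=torch.float32, device=device) * w_scale + w_start

    # Broadcast the row coordinate across each row and the column coordinate
    # across each column, then flatten to one (row, col) pair per token,
    # matching the diff's view/repeat/flatten pattern.
    row_ids = rows.view(-1, 1).repeat(1, w_tokens).flatten()
    col_ids = cols.view(1, -1).repeat(h_tokens, 1).flatten()
    return torch.stack((row_ids, col_ids), dim=-1)

# Example: a 2x3 token grid with the width axis compressed to half frequency.
print(make_axis_pos_ids(2, 3, {"scale_x": 0.5}))
# -> [[0, 0], [0, 0.5], [0, 1], [1, 0], [1, 0.5], [1, 1]]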