 #
 # -----------------------------------------------------------------------------
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -34,7 +36,15 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+    def forward(
+        self,
+        input_ids,
+        vision_embeds,
+        position_ids,
+        image_idx,
+        past_key_values,
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
@@ -55,7 +65,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
-            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            use_cache=True,
         )
         image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
         return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
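The new `batch_index` argument is what lets the decode graph write each request's fresh key/value entries into its own row of a shared, pre-allocated cache during continuous batching. A minimal sketch of the tensor a caller would supply is below; the values and the commented call are illustrative assumptions, not part of this change:

```python
import torch

# Two active requests sharing one KV cache: each row of the decode batch names
# the cache slot it owns. Shape (batch_size, 1), dtype int64, matching the
# Optional[torch.LongTensor] annotation added above.
batch_index = torch.tensor([[0], [3]], dtype=torch.long)

# logits, vision_embeds, image_idx, past_key_values = wrapper(
#     input_ids, vision_embeds, position_ids, image_idx, past_key_values,
#     batch_index=batch_index,
# )
```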
@@ -75,6 +89,9 @@ def get_specializations(
         ctx_len: int,
         img_size: int,
         kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
         **compiler_options,
     ):
         num_patches = compiler_options.pop("num_patches", None)
@@ -104,24 +121,38 @@ def get_specializations(
                 "batched_num_patches": batch_size * num_patches,
             }
         ]
-        lang = [
-            {
-                "batch_size": batch_size,
-                "seq_len": prefill_seq_len,
-                "ctx_len": ctx_len,
-                "num_patches": num_patches,
-                "img_size": img_size,
-                "vision_size": vision_size,
-            },
-            {
-                "batch_size": batch_size,
-                "seq_len": "1",
-                "ctx_len": ctx_len,
-                "num_patches": num_patches,
-                "img_size": img_size,
-                "vision_size": vision_size,
-            },
-        ]
+        lang_prefill = {
+            "batch_size": 1 if continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            "num_patches": num_patches,
+            "img_size": img_size,
+            "vision_size": vision_size,
+        }
+        if continuous_batching:
+            lang_prefill["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_prefill["batch_size"] = kv_cache_batch_size
+        if full_batch_size:
+            lang_prefill["full_batch_exec_size"] = full_batch_size
+
+        lang_decode = {
+            "batch_size": full_batch_size if continuous_batching else batch_size,
+            "seq_len": "1",
+            "ctx_len": ctx_len,
+            "num_patches": num_patches,
+            "img_size": img_size,
+            "vision_size": vision_size,
+        }
+
+        if continuous_batching:
+            lang_decode["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_decode["batch_size"] = kv_cache_batch_size
+
+        lang = []
+        lang.append(lang_prefill)
+        lang.append(lang_decode)
 
         specializations = {}
 
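To make the branching above concrete, this is roughly what the two language specializations come out as for a hypothetical continuous-batching compile (all numbers are example values chosen for illustration, not defaults taken from this change):

```python
# Assumed inputs: continuous_batching=True, kv_cache_batch_size=4, full_batch_size=4,
# prefill_seq_len=128, ctx_len=4096, num_patches=13, img_size=448, vision_size=3328
lang = [
    {   # prefill: one request is filled at a time, cache sized for the full batch
        "batch_size": 1, "seq_len": 128, "ctx_len": 4096,
        "num_patches": 13, "img_size": 448, "vision_size": 3328,
        "full_batch_size": 4, "full_batch_exec_size": 4,
    },
    {   # decode: every occupied cache slot advances by one token
        "batch_size": 4, "seq_len": "1", "ctx_len": 4096,
        "num_patches": 13, "img_size": 448, "vision_size": 3328,
        "full_batch_size": 4,
    },
]
```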
@@ -130,18 +161,22 @@ def get_specializations(
             specializations["lang"] = lang
             return specializations, compiler_options
         else:
+            lang[0].pop("vision_size")
+            lang[1].pop("vision_size")
             return lang, compiler_options
 
-    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False):
         # Define dynamic axes
         vision_dynamic_axes = {}
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
         vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
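With `continuous_batching=True`, the KV-cache tensors are exported against a separate `full_batch_size` axis while the streaming inputs keep `batch_size`. A sketch of the resolved axes (only the first `past_key`/`past_value` pair written out):

```python
lang_dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_len"},
    "position_ids": {0: "batch_size", 1: "seq_len"},
    "vision_embeds": {1: "vision_size"},
    "batch_index": {0: "batch_size"},
    "past_key.0": {0: "full_batch_size", 2: "ctx_len"},
    "past_value.0": {0: "full_batch_size", 2: "ctx_len"},
    # ...one key/value pair per hidden layer
}
```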
@@ -173,7 +208,7 @@ def get_output_names(self, kv_offload: bool = False):
             return lang_output_names
         return output_names
 
-    def get_dummy_inputs(self, kv_offload: bool = False):
+    def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False):
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE)
         else:
@@ -222,10 +257,13 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         )
         lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)
 
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
             config=self.language_model.config,
-            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            batch_size=fbs if continuous_batching else bs,
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
 
@@ -234,6 +272,9 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             for kv in ["key", "value"]:
                 lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
 
+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
+
         inputs = {}
         if kv_offload:
             inputs["vision"] = vision_inputs
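For the ONNX export dummies this means the KV cache is padded along dim 0 to the full cache batch, while `batch_index` follows the example batch size. A small sketch, assuming placeholder export constants (`bs=1`, `fbs=4` are stand-ins, not values read from this diff):

```python
import torch

bs, fbs = 1, 4  # stand-ins for ONNX_EXPORT_EXAMPLE_BATCH_SIZE / ONNX_EXPORT_EXAMPLE_FBS

batch_index = torch.arange(bs).view(bs, 1)
print(batch_index)        # tensor([[0]]) -- one cache-slot id per example row
print(batch_index.shape)  # torch.Size([1, 1])

# kv_cache_shape[0] is fbs (4) when continuous_batching else bs (1), so every
# past_key.* / past_value.* dummy already carries the full cache batch dimension.
```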