
logger = init_logger(__name__)

+# TODO(Jianan Ji): Is this name mapping common for all models?
-def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
+def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> tuple[list[str], int]:
-    """Transfer FX placeholder debug names to model-like dotted names.
+    """Transfer FX placeholder debug names to model-like dotted names and return them with the index of the input_ids placeholder.

    Example:
        l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_
@@ -24,13 +25,15 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
    Instead, we annotate via node.meta['logical_name'] and return the list.
    """
    converted_names = []
-    s_pattern = re.compile(r"^s\d+$")
+    s_pattern = re.compile(r"^s\d+$")  # dynamic shape symbols, e.g. s72 / s80
+    input_id = 0

-    for node in placeholders:
+    for i, node in enumerate(placeholders):
        name = node.name
        if name == 'l_input_ids_':
            final_name = 'input_ids'
            converted_names.append(final_name)
+            input_id = i
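+            # Remember which placeholder carries input_ids so callers can look up its token count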
        elif name == 'l_positions_':
            final_name = 'positions'
            converted_names.append(final_name)
@@ -49,7 +52,7 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:

        converted_names.append(final_name)

-    return converted_names
+    return converted_names, input_id

def build_model_config(
    model_config: ModelConfig,
@@ -67,7 +70,6 @@ def build_model_config(
    sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)

    position_embeddings = (cos_tensor, sin_tensor)
-    logger.info(f"[Mirage] position_embeddings: {position_embeddings[0].shape}, {position_embeddings[1].shape}")
    mirage_model_config = MirageModelConfig(
        # model architecture
        hidden_size=model_config.get_hidden_size(),
@@ -129,6 +131,43 @@ def build_mpk_metadata(
            positions_tensor = arg
        elif "cos_sin_cache" in name:
            position_embeddings = arg
+        elif "qkv" in name:
+            # Split qkv since we need to shuffle it on the Mirage side later
+            # (6144, 4096) -> (4096, 4096), (1024, 4096), (1024, 4096)
+            qkv_tensor = arg
+
+            total_dim = qkv_tensor.shape[0]
+            n_q_heads = model_config.get_num_attention_heads(parallel_config)  # 32
+            n_kv_heads = model_config.get_num_kv_heads(parallel_config)  # 8
+            n_heads = n_q_heads + n_kv_heads * 2
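+            # Assumption: the fused QKV weight stacks [Q | K | V] along dim 0 with a uniform head_dim, so it splits in proportion to head counts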
+
+            q_range = (total_dim * n_q_heads) // n_heads  # 6144 * 32 / 48 = 4096
+            k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads  # 6144 * 40 / 48 = 5120
+
+            q_tensor = qkv_tensor[:q_range, :]
+            k_tensor = qkv_tensor[q_range:k_range, :]
+            v_tensor = qkv_tensor[k_range:, :]
+
+            # Substitute qkv with q/k/v views
+            state_dict[name.replace("qkv", "q")] = q_tensor
+            state_dict[name.replace("qkv", "k")] = k_tensor
+            state_dict[name.replace("qkv", "v")] = v_tensor
+
+            state_dict[name] = qkv_tensor
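+            # The q/k/v entries above are views into qkv_tensor (no copy); the fused tensor is also kept under its original name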
+        elif "gate_up" in name:
+            # Split gate_up into gate and up
+            gate_up_tensor = arg
+            total_dim = gate_up_tensor.shape[0]
+            single_dim = total_dim // 2
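+            # Assumption: the fused weight stacks gate first, then up, each taking half of dim 0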
+
+            gate_tensor = gate_up_tensor[:single_dim, :]
+            up_tensor = gate_up_tensor[single_dim:, :]
+
+            # Substitute gate_up with gate and up
+            state_dict[name.replace("gate_up", "gate")] = gate_tensor
+            state_dict[name.replace("gate_up", "up")] = up_tensor
+
+            state_dict[name] = gate_up_tensor
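+            # As with qkv, the fused gate_up tensor is kept alongside the split views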
        else:
            state_dict[name] = arg

@@ -141,7 +180,7 @@ def build_mpk_metadata(
        parallel_config,
    )
    mpk_metadata = MPKMetadata(
-        mode="online",
+        mode="online_notoken",
        # total_num_requests
        # num_remote_schedulers: int = 0
        max_seq_length=model_config.max_model_len,
@@ -257,7 +296,9 @@ def __call__(
        placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder']
        assert len(placeholders) == len(example_inputs)

-        transfered_tensor_names = transfer_tensor_names(placeholders)
+        transfered_tensor_names, input_id = transfer_tensor_names(placeholders)
+
+        max_input_tokens = example_inputs[input_id].shape[0]
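+        # Token count of the example input_ids, used below to size the dummy output tensor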


        self._called = True
@@ -269,7 +310,8 @@ def compile_or_call(*args):
                model_config = self.vllm_config.model_config
                dtype = model_config.dtype
                hidden_size = model_config.get_hidden_size()
-                output_tensor = torch.zeros(1, hidden_size, device='cuda', dtype=dtype)
+                # TODO(Jianan Ji): We'll want to run in eager instead of doing nothing
+                output_tensor = torch.zeros(max_input_tokens, hidden_size, device='cuda', dtype=dtype)
                logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output tensor with shape [{output_tensor.shape}]!")

                return (output_tensor,)
@@ -290,7 +332,9 @@ def compile_or_call(*args):
                self.compiled = True

            logger.info(f"[Mirage] Calling the compiled result...")
-            return self.mpk()
+            result_hidden_states = self.mpk()
+
+            return (result_hidden_states,)
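+            # Wrap in a tuple to match the FX graph's output structure (same as the dummy-run return above)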

        # return VllmSerializableFunction(
        #     graph, example_inputs, self.prefix, compile_or_call