
Commit 1f54797

compatible with mpk
1 parent e08209b commit 1f54797

2 files changed: +56 −12 lines changed


vllm/compilation/mirage_backend.py

Lines changed: 53 additions & 9 deletions
@@ -11,8 +11,9 @@
 
 logger = init_logger(__name__)
 
+# TODO(Jianan Ji): Is this name mapping common for all models?
 def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
-    """Transfer FX placeholder debug names to model-like dotted names.
+    """Transfer FX placeholder debug names to model-like dotted names. Return a list of transferred names and input id.
 
     Example:
     l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_
@@ -24,13 +25,15 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
     Instead, we annotate via node.meta['logical_name'] and return the list.
     """
     converted_names = []
-    s_pattern = re.compile(r"^s\d+$")
+    s_pattern = re.compile(r"^s\d+$")  # s72 / s80
+    input_id = 0
 
-    for node in placeholders:
+    for i, node in enumerate(placeholders):
         name = node.name
         if name == 'l_input_ids_':
             final_name = 'input_ids'
             converted_names.append(final_name)
+            input_id = i
         elif name == 'l_positions_':
             final_name = 'positions'
             converted_names.append(final_name)
@@ -49,7 +52,7 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
 
         converted_names.append(final_name)
 
-    return converted_names
+    return converted_names, input_id
 
 def build_model_config(
     model_config: ModelConfig,
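
As a side note, here is a minimal sketch (not code from this commit; the helper name is hypothetical) of the kind of mapping the docstring example in transfer_tensor_names implies, assuming the only decorations to strip are the FX "l_self_modules_" wrapper, the trailing underscore, and the "_modules_"/"_parameters_" separators:

import re

def dotted_name_sketch(placeholder_name: str) -> str:
    # Hypothetical helper for illustration only: strip the FX wrapper and
    # turn the remaining separators into dots.
    name = placeholder_name.strip("_")
    name = re.sub(r"^l_self_modules_", "", name)
    return name.replace("_modules_", ".").replace("_parameters_", ".")

# Prints "layers.17.mlp.gate_up_proj.weight"
print(dotted_name_sketch(
    "l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_"))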
@@ -67,7 +70,6 @@ def build_model_config(
     sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)
 
     position_embeddings = (cos_tensor, sin_tensor)
-    logger.info(f"[Mirage] position_embeddings: {position_embeddings[0].shape}, {position_embeddings[1].shape}")
     mirage_model_config = MirageModelConfig(
         # model architecture
         hidden_size=model_config.get_hidden_size(),
@@ -129,6 +131,43 @@ def build_mpk_metadata(
             positions_tensor = arg
         elif "cos_sin_cache" in name:
             position_embeddings = arg
+        elif "qkv" in name:
+            # Split qkv since we need to shuffle them on the mirage side later
+            # (6144, 4096) -> (4096, 4096), (1024, 4096), (1024, 4096)
+            qkv_tensor = arg
+
+            total_dim = qkv_tensor.shape[0]
+            n_q_heads = model_config.get_num_attention_heads(parallel_config)  # 32
+            n_kv_heads = model_config.get_num_kv_heads(parallel_config)  # 8
+            n_heads = n_q_heads + n_kv_heads * 2
+
+            q_range = (total_dim * n_q_heads) // n_heads  # 6144 * 32 / 48 = 4096
+            k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads  # 6144 * 40 / 48 = 5120
+
+            q_tensor = qkv_tensor[:q_range, :]
+            k_tensor = qkv_tensor[q_range:k_range, :]
+            v_tensor = qkv_tensor[k_range:, :]
+
+            # substitute qkv with q/k/v views
+            state_dict[name.replace("qkv", "q")] = q_tensor
+            state_dict[name.replace("qkv", "k")] = k_tensor
+            state_dict[name.replace("qkv", "v")] = v_tensor
+
+            state_dict[name] = qkv_tensor
+        elif "gate_up" in name:
+            # Split gate_up into gate and up
+            gate_up_tensor = arg
+            total_dim = gate_up_tensor.shape[0]
+            single_dim = total_dim // 2
+
+            gate_tensor = gate_up_tensor[:single_dim, :]
+            up_tensor = gate_up_tensor[single_dim:, :]
+
+            # substitute gate_up with gate and up
+            state_dict[name.replace("gate_up", "gate")] = gate_tensor
+            state_dict[name.replace("gate_up", "up")] = up_tensor
+
+            state_dict[name] = gate_up_tensor
         else:
             state_dict[name] = arg
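
A minimal, self-contained sketch of the fused-weight splits added above, using the example shapes from the diff comments (32 query heads, 8 KV heads, a (6144, 4096) fused qkv_proj weight); the gate_up shape below is an assumed stand-in, not taken from the diff:

import torch

n_q_heads, n_kv_heads = 32, 8
n_heads = n_q_heads + n_kv_heads * 2             # 48 head "slots" along dim 0

qkv = torch.empty(6144, 4096)                    # stand-in fused qkv_proj weight
total_dim = qkv.shape[0]
q_range = (total_dim * n_q_heads) // n_heads                 # 4096
k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads  # 5120

q, k, v = qkv[:q_range, :], qkv[q_range:k_range, :], qkv[k_range:, :]
print(q.shape, k.shape, v.shape)                 # (4096, 4096), (1024, 4096), (1024, 4096)

gate_up = torch.empty(28672, 4096)               # assumed fused gate_up_proj weight
single_dim = gate_up.shape[0] // 2
gate, up = gate_up[:single_dim, :], gate_up[single_dim:, :]
print(gate.shape, up.shape)                      # (14336, 4096), (14336, 4096)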

@@ -141,7 +180,7 @@ def build_mpk_metadata(
         parallel_config,
     )
     mpk_metadata = MPKMetadata(
-        mode = "online",
+        mode = "online_notoken",
         # total_num_requests
         # num_remote_schedulers: int = 0
         max_seq_length = model_config.max_model_len,
@@ -257,7 +296,9 @@ def __call__(
         placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder']
         assert len(placeholders) == len(example_inputs)
 
-        transfered_tensor_names = transfer_tensor_names(placeholders)
+        transfered_tensor_names, input_id = transfer_tensor_names(placeholders)
+
+        max_input_tokens = example_inputs[input_id].shape[0]
 
 
         self._called = True
@@ -269,7 +310,8 @@ def compile_or_call(*args):
             model_config = self.vllm_config.model_config
             dtype = model_config.dtype
             hidden_size = model_config.get_hidden_size()
-            output_tensor = torch.zeros(1, hidden_size, device='cuda', dtype=dtype)
+            # TODO(Jianan Ji): We'll want to run in eager instead of doing nothing
+            output_tensor = torch.zeros(max_input_tokens, hidden_size, device='cuda', dtype=dtype)
             logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output tensor with shape [{output_tensor.shape}]......!")
 
             return (output_tensor,)
return (output_tensor,)
@@ -290,7 +332,9 @@ def compile_or_call(*args):
             self.compiled = True
 
             logger.info(f"[Mirage] Calling the compiled result...")
-            return self.mpk()
+            result_hidden_states = self.mpk()
+
+            return (result_hidden_states,)
 
         # return VllmSerializableFunction(
         #     graph, example_inputs, self.prefix, compile_or_call

vllm/v1/attention/backends/mirage.py

Lines changed: 3 additions & 3 deletions
@@ -255,9 +255,9 @@ def build(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             qo_indptr_gpu=common_attn_metadata.query_start_loc,
-            paged_kv_indptr_gpu=self.paged_kv_indptr,
-            paged_kv_indices_gpu=self.paged_kv_indices,
-            paged_kv_last_page_len_gpu=self.paged_kv_last_page_len,
+            paged_kv_indptr_gpu=paged_kv_indptr,
+            paged_kv_indices_gpu=paged_kv_indices,
+            paged_kv_last_page_len_gpu=paged_kv_last_page_len,
         )
 
         return attn_metadata
