Skip to content

Commit 6319558

Browse files
committed
bug fix
1 parent 045703b commit 6319558

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

vllm/compilation/mirage_backend.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,18 @@ def build_model_config(
5656
state_dict: dict[str, torch.Tensor],
5757
k_cache_tensors: list[torch.Tensor],
5858
v_cache_tensors: list[torch.Tensor],
59-
position_embeddings: torch.Tensor,
59+
position_embeddings_: torch.Tensor,
6060
parallel_config: ParallelConfig,
6161
) -> MirageModelConfig:
62+
whole_dim = position_embeddings_.shape[-1]
63+
cos_tensor_ = position_embeddings_[:, 0:whole_dim//2].unsqueeze(0)
64+
sin_tensor_ = position_embeddings_[:, whole_dim//2:].unsqueeze(0)
65+
66+
cos_tensor = torch.cat([cos_tensor_, cos_tensor_], dim=-1)
67+
sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)
68+
69+
position_embeddings = (cos_tensor, sin_tensor)
70+
logger.info(f"[Mirage] position_embeddings: {position_embeddings[0].shape}, {position_embeddings[1].shape}")
6271
mirage_model_config = MirageModelConfig(
6372
# model architecture
6473
hidden_size=model_config.get_hidden_size(),
@@ -75,6 +84,7 @@ def build_model_config(
7584
position_embeddings=position_embeddings,
7685
# model weights
7786
state_dict=state_dict,
87+
with_lm_head=False,
7888
)
7989
return mirage_model_config
8090

@@ -88,9 +98,9 @@ def build_mpk_metadata(
8898
scheduler_config = vllm_config.scheduler_config
8999
cache_config = vllm_config.cache_config
90100
parallel_config = vllm_config.parallel_config
91-
attn_metadata = forward_context.attn_metadata
92-
logger.info(f"[Mirage] Forward context: {forward_context}, attn_metadata: {attn_metadata}")
93-
101+
# For now we assume only one attention group
102+
attn_metadata = list(forward_context.attn_metadata.values())[0]
103+
94104
static_forward_context = forward_context.no_compile_layers # layer names to layers
95105
k_cache_tensors = []
96106
v_cache_tensors = []
@@ -275,7 +285,7 @@ def compile_or_call(*args):
275285
logger.info(f"[Mirage] MPK metadata: {mpk_metadata.info_as_string()}")
276286
self.mpk = MPK(mpk_metadata)
277287
self.mpk.build()
278-
self.mpk.compile()
288+
self.mpk.compile(output_dir=os.path.join(os.path.dirname(__file__), "mirage_backend_output"))
279289

280290
self.compiled = True
281291

vllm/v1/attention/backends/mirage.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
108108
class MirageAttentionMetadata:
109109
num_actual_tokens: int # Number of tokens excluding padding.
110110

111-
# The data type of the query
112-
q_data_type: torch.dtype
113-
114111
# For handling prefill decode split
115112
num_decodes: int
116113
num_decode_tokens: int
@@ -250,17 +247,14 @@ def build(
250247
)
251248

252249
# uses_spec_reorder = self.reorder_batch_threshold > 1
253-
254-
assert self.q_data_type == torch.bfloat16, "MirageAttentionBackend currently only supports bfloat16"
255250

256251
attn_metadata = MirageAttentionMetadata(
257252
num_actual_tokens=num_actual_tokens,
258-
q_data_type=self.q_data_type,
259253
num_decodes=num_decodes,
260254
num_decode_tokens=num_decode_tokens,
261255
num_prefills=num_prefills,
262256
num_prefill_tokens=num_prefill_tokens,
263-
qo_indptr_gpu=common_attn_metadata.query_start_loc_gpu,
257+
qo_indptr_gpu=common_attn_metadata.query_start_loc,
264258
paged_kv_indptr_gpu=self.paged_kv_indptr,
265259
paged_kv_indices_gpu=self.paged_kv_indices,
266260
paged_kv_last_page_len_gpu=self.paged_kv_last_page_len,

0 commit comments

Comments
 (0)