@@ -56,9 +56,18 @@ def build_model_config(
5656 state_dict : dict [str , torch .Tensor ],
5757 k_cache_tensors : list [torch .Tensor ],
5858 v_cache_tensors : list [torch .Tensor ],
59- position_embeddings : torch .Tensor ,
59+ position_embeddings_ : torch .Tensor ,
6060 parallel_config : ParallelConfig ,
6161) -> MirageModelConfig :
62+ whole_dim = position_embeddings_ .shape [- 1 ]
63+ cos_tensor_ = position_embeddings_ [:, 0 :whole_dim // 2 ].unsqueeze (0 )
64+ sin_tensor_ = position_embeddings_ [:, whole_dim // 2 :].unsqueeze (0 )
65+
66+ cos_tensor = torch .cat ([cos_tensor_ , cos_tensor_ ], dim = - 1 )
67+ sin_tensor = torch .cat ([sin_tensor_ , sin_tensor_ ], dim = - 1 )
68+
69+ position_embeddings = (cos_tensor , sin_tensor )
70+ logger .info (f"[Mirage] position_embeddings: { position_embeddings [0 ].shape } , { position_embeddings [1 ].shape } " )
6271 mirage_model_config = MirageModelConfig (
6372 # model architecture
6473 hidden_size = model_config .get_hidden_size (),
@@ -75,6 +84,7 @@ def build_model_config(
7584 position_embeddings = position_embeddings ,
7685 # model weights
7786 state_dict = state_dict ,
87+ with_lm_head = False ,
7888 )
7989 return mirage_model_config
8090
@@ -88,9 +98,9 @@ def build_mpk_metadata(
8898 scheduler_config = vllm_config .scheduler_config
8999 cache_config = vllm_config .cache_config
90100 parallel_config = vllm_config .parallel_config
91- attn_metadata = forward_context . attn_metadata
92- logger . info ( f"[Mirage] Forward context: { forward_context } , attn_metadata: { attn_metadata } " )
93-
101+ # For now we assume only one attention group
102+ attn_metadata = list ( forward_context . attn_metadata . values ())[ 0 ]
103+
94104 static_forward_context = forward_context .no_compile_layers # layer names to layers
95105 k_cache_tensors = []
96106 v_cache_tensors = []
@@ -275,7 +285,7 @@ def compile_or_call(*args):
275285 logger .info (f"[Mirage] MPK metadata: { mpk_metadata .info_as_string ()} " )
276286 self .mpk = MPK (mpk_metadata )
277287 self .mpk .build ()
278- self .mpk .compile ()
288+ self .mpk .compile (output_dir = os . path . join ( os . path . dirname ( __file__ ), "mirage_backend_output" ) )
279289
280290 self .compiled = True
281291
0 commit comments