1 change: 1 addition & 0 deletions vllm/attention/backends/registry.py
@@ -76,6 +76,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
MIRAGE = "vllm.v1.attention.backends.mirage.MirageAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""

318 changes: 318 additions & 0 deletions vllm/compilation/mirage_backend.py
@@ -0,0 +1,318 @@
import os
import re
import time
from collections import defaultdict
from typing import Any

import torch
import torch.fx as fx
from mirage import MPK, MPKMetadata, MirageModelConfig

from vllm.config import (
    CompilationConfig,
    ModelConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.config.parallel import ParallelConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index

from .counter import compilation_counter

logger = init_logger(__name__)

# TODO(Jianan Ji): Is this name mapping common for all models?
def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
    """Convert FX placeholder debug names into model-style dotted names.

    Returns the converted names as a list, in placeholder order.

    Example:
        l_self_modules_layers_modules_17_modules_mlp_\
modules_gate_up_proj_parameters_weight_
        -> model.layers.17.mlp.gate_up_proj.weight

    Notes:
        - Tailored for Qwen3-style module names seen in exported FX graphs.
        - We do NOT rename the FX node identifiers (dots are not valid in FX
          names); the converted names are only returned as a separate list.
    """
converted_names = []
s_pattern = re.compile(r"^s\d+$") # s72 / s80

for node in placeholders:
name = node.name
if name == 'l_input_ids_':
final_name = 'input_ids'
converted_names.append(final_name)
elif name == 'l_positions_':
final_name = 'positions'
converted_names.append(final_name)
elif s_pattern.match(name): # s72 / s80
converted_names.append(name)
else:
if name.startswith('l_self_modules_'):
name = name.replace('l_self_modules_', '', 1)
if name.endswith('_'):
name = name[:-1]

name = name.replace('_modules_', '.')
name = name.replace('_parameters_', '.')

final_name = 'model.' + name

converted_names.append(final_name)

return converted_names
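# Rough illustration of the mapping (hypothetical placeholder names from a
# Qwen3-style Dynamo trace):
#   l_input_ids_ -> input_ids
#   l_positions_ -> positions
#   s72          -> s72  (symbolic size, kept as-is)
#   l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_\
#   parameters_weight_ -> model.layers.0.self_attn.qkv_proj.weight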

def build_model_config(
model_config: ModelConfig,
state_dict: dict[str, torch.Tensor],
k_cache_tensors: list[torch.Tensor],
v_cache_tensors: list[torch.Tensor],
position_embeddings_: torch.Tensor,
parallel_config: ParallelConfig,
) -> MirageModelConfig:
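    # NOTE(assumption): `position_embeddings_` is the rotary cos/sin cache laid
    # out as [cos | sin] along the last dim (vLLM's RotaryEmbedding.cos_sin_cache
    # layout). Each half is split out, duplicated back to the full rotary dim,
    # and given a leading batch dim to form the (cos, sin) pair used below.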
whole_dim = position_embeddings_.shape[-1]
cos_tensor_ = position_embeddings_[:, 0:whole_dim//2].unsqueeze(0)
sin_tensor_ = position_embeddings_[:, whole_dim//2:].unsqueeze(0)

cos_tensor = torch.cat([cos_tensor_, cos_tensor_], dim=-1)
sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)

position_embeddings = (cos_tensor, sin_tensor)
mirage_model_config = MirageModelConfig(
# model architecture
hidden_size=model_config.get_hidden_size(),
intermediate_size=getattr(model_config.hf_text_config, "intermediate_size", 0),
vocab_size=model_config.get_vocab_size(),
local_num_q_heads=model_config.get_num_attention_heads(parallel_config),
local_num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_dim=model_config.get_head_size(),
num_layers=getattr(model_config.hf_text_config, "num_hidden_layers", 0),
# kv cache
k_cache=k_cache_tensors,
v_cache=v_cache_tensors,
# position embeddings
position_embeddings=position_embeddings,
# model weights
state_dict=state_dict,
with_lm_head=False,
)
return mirage_model_config

def build_mpk_metadata(
vllm_config: VllmConfig,
args: list[Any],
transfered_tensor_names: list[str],
) -> MPKMetadata:
forward_context = get_forward_context()
model_config = vllm_config.model_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
# For now we assume only one attention group
attn_metadata = list(forward_context.attn_metadata.values())[0]

static_forward_context = forward_context.no_compile_layers # layer names to layers
k_cache_tensors = []
v_cache_tensors = []
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
    for layer_name in static_forward_context:
index2name[extract_layer_index(layer_name, 1)].append(layer_name)

for layer_index in sorted(index2name):
layer_names = index2name[layer_index]

        assert len(layer_names) == 1, (
            "Multiple layers with the same layer index are not supported"
        )
layer_name = layer_names[0]
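        # Each layer's kv_cache is assumed to be indexed by virtual engine;
        # entry 0 holds the (K, V) pair for the single engine used here.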
k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0])
v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1])

state_dict = {}
input_token_ids = None
positions_tensor = None
position_embeddings = None
    for arg, name in zip(args, transferred_tensor_names):
if name == 'input_ids':
input_token_ids = arg
elif name == 'positions':
positions_tensor = arg
elif "cos_sin_cache" in name:
position_embeddings = arg
elif "qkv" in name:
# Split qkv since we need to shuffle them on mirage side later
# (6144, 4096) -> (4096, 4096), (1024, 4096), (1024, 4096)
qkv_tensor = arg

total_dim = qkv_tensor.shape[0]
n_q_heads = model_config.get_num_attention_heads(parallel_config) # 32
n_kv_heads = model_config.get_num_kv_heads(parallel_config) # 8
n_heads = n_q_heads + n_kv_heads * 2

            # e.g. 6144 * 32 / 48 = 4096,   6144 * 40 / 48 = 5120
            q_range = (total_dim * n_q_heads) // n_heads
            k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads

q_tensor = qkv_tensor[:q_range, :]
k_tensor = qkv_tensor[q_range:k_range, :]
v_tensor = qkv_tensor[k_range:, :]


            # register separate q/k/v views alongside the fused qkv weight
state_dict[name.replace("qkv", "q")] = q_tensor
state_dict[name.replace("qkv", "k")] = k_tensor
state_dict[name.replace("qkv", "v")] = v_tensor

state_dict[name] = qkv_tensor
elif "gate_up" in name:
# Split gate_up to gate and up
gate_up_tensor = arg
total_dim = gate_up_tensor.shape[0]
single_dim = total_dim // 2

gate_tensor = gate_up_tensor[:single_dim, :]
up_tensor = gate_up_tensor[single_dim:, :]

            # register separate gate/up views alongside the fused gate_up weight
state_dict[name.replace("gate_up", "gate")] = gate_tensor
state_dict[name.replace("gate_up", "up")] = up_tensor

state_dict[name] = gate_up_tensor
else:
state_dict[name] = arg

mirage_model_config = build_model_config(
model_config,
state_dict,
k_cache_tensors,
v_cache_tensors,
position_embeddings,
parallel_config,
)
    mpk_metadata = MPKMetadata(
        mode="online_notoken",
        # total_num_requests
        # num_remote_schedulers: int = 0
        max_seq_length=model_config.max_model_len,
        max_num_batched_requests=scheduler_config.max_num_seqs,
        max_num_batched_tokens=scheduler_config.max_num_batched_tokens,
        max_num_pages=cache_config.num_gpu_blocks,
        page_size=cache_config.block_size,
        # max_sm_num: int = 108
        device="cuda",
        # model
        weight_from_model=False,
        model_name=model_config.model,
        # model_path: Optional[str] = None
        # multi device support
        world_size=parallel_config.world_size,
        rank=parallel_config.rank,
        # Meta tensors
        step=positions_tensor,
        # tokens: Optional[torch.Tensor] = None
        input_tokens=input_token_ids,
        # output_tokens: Optional[torch.Tensor] = None
        # num_new_tokens: Optional[torch.Tensor] = None
        # prompt_lengths: Optional[torch.Tensor] = None
        qo_indptr_buffer=attn_metadata.qo_indptr_gpu,
        paged_kv_indptr_buffer=attn_metadata.paged_kv_indptr_gpu,
        paged_kv_indices_buffer=attn_metadata.paged_kv_indices_gpu,
        paged_kv_last_page_len_buffer=attn_metadata.paged_kv_last_page_len_gpu,
        # kv cache tensors, weights and model config
        model_config=mirage_model_config,
        # profiling
        # profiler_tensor: Optional[torch.Tensor] = None
        # trace_name: Optional[str] = None
        # spec decode config
        # spec_decode_config: Optional[object] = None
    )
    return mpk_metadata
return mpk_metadata

class MirageBackend:
"""The compilation backend for Mirage Persistent Kernel."""

vllm_config: VllmConfig
compilation_config: CompilationConfig
_called: bool = False
# the graph we compiled
graph: fx.GraphModule

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
logger.debug("[Mirage] Calling MirageBackend init!")

self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.model_config = vllm_config.model_config
self.model_name = vllm_config.model_config.model

def __call__(
self, graph: fx.GraphModule, example_inputs
) -> Any:

# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter.num_graphs_seen += 1
from .monitor import torch_compile_start_time

# TODO: remove this after debugging
# try:
# src = graph.print_readable(print_output=False)
# except Exception:
# src = str(graph)
# try:
# with open('mirage_backends_graph.txt', 'w') as f:
# logger.info('Writing readable FX graph to mirage_backends_graph.txt')
# f.write(src)
# logger.info('Readable FX graph written to mirage_backends_graph.txt')
# except Exception:
# logger.exception('Failed to write mirage_backends_graph.txt')

dynamo_time = time.time() - torch_compile_start_time
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
self.compilation_config.compilation_time += dynamo_time

# we control the compilation process, each instance can only be
# called once
assert not self._called, "MirageBackend can only be called once"

placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder']
assert len(placeholders) == len(example_inputs)

        transferred_tensor_names = transfer_tensor_names(placeholders)


# TODO(Jianan Ji): remove this after debugging
# with open('mirage_backends_graph.txt', 'w') as f:
# f.write(graph.print_readable(print_output=False))

# with open("graph_structure.txt", "w", encoding="utf-8") as f:
# f.write(str(graph.graph))


self._called = True
self.compiled = False

def compile_or_call(*args):
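            # vLLM's profiling / dummy runs enter the forward context without
            # attention metadata; in that case just run the FX graph eagerly.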
            dummy_run_called = get_forward_context().attn_metadata is None
            if dummy_run_called:
                return graph(*args)

if not self.compiled:
# Compile only at the first call -- when we get real tensors
logger.info("[Mirage] Calling compile_or_call for the first time, compiling......!")
mpk_metadata = build_mpk_metadata(
self.vllm_config,
args,
                    transferred_tensor_names,

)
logger.info(f"[Mirage] MPK metadata: {mpk_metadata.info_as_string()}")
self.mpk = MPK(mpk_metadata)
self.mpk.build()
                self.mpk.compile(
                    output_dir=os.path.join(
                        os.path.dirname(__file__), "mirage_backend_output"
                    )
                )

self.compiled = True

default_stream = torch.cuda.current_stream()

            result_hidden_states = self.mpk(default_stream=default_stream)

return (result_hidden_states,)

return compile_or_call
3 changes: 3 additions & 0 deletions vllm/config/compilation.py
@@ -811,6 +811,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
]:
if self.backend in torch_backends:
return self.backend
if self.backend == "mirage_byname":
from vllm.compilation.mirage_backend import MirageBackend
return MirageBackend(vllm_config)
return resolve_obj_by_qualname(self.backend)
Review comment (Collaborator) with a suggested change: replace
            return resolve_obj_by_qualname(self.backend)
with
            if self.backend == "mirage":
                from vllm.compilation.mirage_backend import MirageBackend
                return MirageBackend(vllm_config)
            return resolve_obj_by_qualname(self.backend)


assert self.mode == CompilationMode.VLLM_COMPILE
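For reference, a minimal usage sketch of this dispatch path (the model name and the
dict form of compilation_config are illustrative; depending on the compilation mode,
further settings may be needed):

from vllm import LLM

# Select the Mirage compilation backend by its registered name; the config checks
# in vllm/config/vllm.py below then force the MIRAGE attention backend.
llm = LLM(
    model="Qwen/Qwen3-8B",
    compilation_config={"backend": "mirage_byname"},
)
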
12 changes: 11 additions & 1 deletion vllm/config/vllm.py
@@ -691,7 +691,10 @@
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
        disable_cuda_graph = (
            self.model_config is not None and self.model_config.enforce_eager
        )
if disable_cuda_graph:
logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# override related settings when enforce eager
@@ -703,6 +706,13 @@
self._set_cudagraph_sizes()
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.compilation_config.backend == "mirage_byname":
            if envs.VLLM_ATTENTION_BACKEND is None:
                envs.VLLM_ATTENTION_BACKEND = "MIRAGE"
            elif envs.VLLM_ATTENTION_BACKEND != "MIRAGE":
                raise ValueError(
                    "The MIRAGE attention backend is required when using the "
                    "mirage compilation backend, but VLLM_ATTENTION_BACKEND "
                    f"is {envs.VLLM_ATTENTION_BACKEND}."
                )

            assert self.cache_config.block_size % 64 == 0, (
                "Block size must be a multiple of 64 for the mirage backend."
            )
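            # Illustrative invocation satisfying these constraints (flags are
            # assumed, not verified against this PR):
            #   VLLM_ATTENTION_BACKEND=MIRAGE vllm serve <model> --block-size 64 \
            #     --compilation-config '{"backend": "mirage_byname"}'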


if self.cache_config.kv_sharing_fast_prefill:
if (