From c4999066fa7516998096f4a1ed35808af4c46d19 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Sun, 19 Oct 2025 08:34:32 +0000 Subject: [PATCH 01/12] init mirage integration --- vllm/attention/backends/registry.py | 1 + vllm/compilation/mirage_backend.py | 270 ++++++++ vllm/v1/attention/backends/mirage.py | 967 +++++++++++++++++++++++++++ 3 files changed, 1238 insertions(+) create mode 100644 vllm/compilation/mirage_backend.py create mode 100755 vllm/v1/attention/backends/mirage.py diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 125e4e382774..87212c30d793 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -76,6 +76,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "RocmAiterUnifiedAttentionBackend" ) CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend" + MIRAGE = "vllm.v1.attention.backends.mirage.MirageBackend" # Placeholder for third-party/custom backends - must be registered before use CUSTOM = "" diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py new file mode 100644 index 000000000000..9c5ab88bc075 --- /dev/null +++ b/vllm/compilation/mirage_backend.py @@ -0,0 +1,270 @@ +from collections import defaultdict +from .backends import * +from mirage.mpk import MPK, MPKMetadata +import re +from vllm.config import get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.models.utils import extract_layer_index +import torch +def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: + """Transfer FX placeholder debug names to model-like dotted names. + + Example: + l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_ + -> model.layers.17.mlp.gate_up_proj.weight + + Notes: + - Tailored for Qwen3-style module names seen in exported FX graphs. + - We do NOT rename the FX node identifiers (dots are not valid in FX names). + Instead, we annotate via node.meta['logical_name'] and return the list. + """ + converted_names = [] + s_pattern = re.compile(r"^s\d+$") + + for node in placeholders: + name = node.name + if name == 'l_input_ids_': + final_name = 'input_ids' + converted_names.append(final_name) + elif name == 'l_positions_': + final_name = 'positions' + converted_names.append(final_name) + elif s_pattern.match(name): # s72 / s80 + converted_names.append(name) + else: + if name.startswith('l_self_modules_'): + name = name.replace('l_self_modules_', '', 1) + if name.endswith('_'): + name = name[:-1] + + name = name.replace('_modules_', '.') + name = name.replace('_parameters_', '.') + + final_name = 'model.' 
+ name + + converted_names.append(final_name) + + return converted_names + + +# @dataclass +# class MPKMetadata: +# # ---------- MPK class external state bundled here ---------- +# # args +# mode: str = "offline" +# total_num_requests: int = 1 +# num_remote_schedulers: int = 0 +# max_seq_length: int = 0 +# max_num_batched_requests: int = 0 +# max_num_batched_tokens: int = 0 +# max_num_pages: int = 0 +# page_size: int = 0 +# max_sm_num: int = 108 +# device: str = "cuda" +# # model +# weight_from_model: bool +# model_name: Optional[str] # For now, model_name must be provided +# model_path: Optional[str] = None +# # fx graph +# state_dict: Optional[dict] = None +# # Meta tensors +# step: Optional[torch.Tensor] = None +# tokens: Optional[torch.Tensor] = None +# input_tokens: Optional[torch.Tensor] = None +# output_tokens: Optional[torch.Tensor] = None +# num_new_tokens: Optional[torch.Tensor] = None +# prompt_lengths: Optional[torch.Tensor] = None +# qo_indptr_buffer: Optional[torch.Tensor] = None +# paged_kv_indptr_buffer: Optional[torch.Tensor] = None +# paged_kv_indices_buffer: Optional[torch.Tensor] = None +# paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None +# # profiling +# profiler_tensor: Optional[torch.Tensor] = None +# trace_name: Optional[str] = None +# # spec decode config +# spec_decode_config: Optional[object] = None + + +def build_mpk_metadata( + vllm_config: VllmConfig, + forward_context: ForwardContext, + state_dict: dict[str, torch.Tensor], + k_cache_tensors: list[torch.Tensor], + v_cache_tensors: list[torch.Tensor] + ) -> MPKMetadata: + model_config = vllm_config.model_config + scheduler_config = vllm_config.scheduler_config + cache_config = vllm_config.cache_config + parallel_config = vllm_config.parallel_config + attn_metadata = forward_context.attn_metadata + mpk_metadata = MPKMetadata( + mode = "online" + # total_num_requests + # num_remote_schedulers: int = 0 + max_seq_length = model_config.max_model_len, + max_num_batched_requests = scheduler_config.max_num_seqs, + max_num_batched_tokens = scheduler_config.max_num_batched_tokens, + max_num_pages = cache_config.num_gpu_blocks + page_size = cache_config.block_size + # max_sm_num: int = 108 + device: str = "cuda" + # # model + # weight_from_model: bool + model_name = model_config.model_name, + # model_path: Optional[str] = None + # multi device support + world_size = parallel_config.world_size + rank = parallel_config.rank + # # fx graph + state_dict = state_dict + # # Meta tensors + # step: Optional[torch.Tensor] = None + # tokens: Optional[torch.Tensor] = None + # input_tokens: Optional[torch.Tensor] = None + # output_tokens: Optional[torch.Tensor] = None + # num_new_tokens: Optional[torch.Tensor] = None + # prompt_lengths: Optional[torch.Tensor] = None + qo_indptr_buffer = attn_metadata.qo_indptr_gpu + paged_kv_indptr_buffer = attn_metadata.paged_kv_indptr_gpu + paged_kv_indices_buffer = attn_metadata.paged_kv_indices_gpu + paged_kv_last_page_len_buffer = attn_metadata.paged_kv_last_page_len_gpu + # kv cache tensors + k_cache_tensors = k_cache_tensors + v_cache_tensors = v_cache_tensors + # # profiling + # profiler_tensor: Optional[torch.Tensor] = None + # trace_name: Optional[str] = None + # # spec decode config + # spec_decode_config: Optional[object] = None + ) + return mpk_metadata + +class MirageBackend: + """The compilation backend for `torch.compile` with vLLM. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. 
+ + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. + """ + + vllm_config: VllmConfig + compilation_config: CompilationConfig + _called: bool = False + # the graph we compiled + graph: fx.GraphModule + # the stiching graph module for all the piecewise graphs + split_gm: fx.GraphModule + piecewise_graphs: list[SplitItem] + returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] + sym_tensor_indices: list[int] + input_buffers: list[torch.Tensor] + compiler_manager: CompilerManager + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + # if the model is initialized with a non-empty prefix, + # then usually it's enough to use that prefix, + # e.g. language_model, vision_model, etc. + # when multiple parts are initialized as independent + # models, we need to use the model_tag to distinguish + # them, e.g. backbone (default), eagle_head, etc. + self.prefix = prefix or model_tag + + # Passes to run on the graph post-grad. + self.post_grad_pass_manager = PostGradPassManager() + + self.sym_tensor_indices = [] + self.input_buffers = [] + + self.vllm_config = vllm_config + self.model_name = vllm_config.model_config.model_name + self.compilation_config = vllm_config.compilation_config + + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config + ) + + def __call__( + self, graph: fx.GraphModule, example_inputs + ) -> VllmSerializableFunction: + + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done + compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_config.compilation_time += dynamo_time + + # we control the compilation process, each instance can only be + # called once + assert not self._called, "MirageBackend can only be called once" + + placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder'] + assert len(placeholders) == len(example_inputs) + + transfered_tensor_names = transfer_tensor_names(placeholders) + + forward_context = get_forward_context() + static_forward_context = forward_context.no_compile_layers # layer names to layers + + k_cache_tensors = [] + v_cache_tensors = [] + # Convert kv_caches dict to a list of tensors in the order of layer_index. 
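+            # NOTE (illustrative, added for clarity): each per-layer kv_cache entry is
+            # assumed to follow get_kv_cache_shape(), i.e.
+            # (2, num_blocks, block_size, num_kv_heads, head_size), so index 0 below is
+            # treated as the K plane and index 1 as the V plane.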
+ index2name = defaultdict(list) + for layer_name in static_forward_context.keys(): + index2name[extract_layer_index(layer_name, 1)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + assert len(layer_names) == 1, "Multiple layers with the same layer index are not supported" + layer_name = layer_names[0] + k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0]) + v_cache_tensors.append(static_forward_context[layer_name].kv_cache[1]) + # kv_cache_tensors shape: num_layers * (2, num_blocks, block_size, num_kv_heads, head_size) + + self._called = True + self.compiled = False + + def compile_or_call(*args): + if not self.compiled: + # Compile only at the first call -- when we get real tensors + state_dict = {} + for arg, name in zip(args, transfered_tensor_names): + if name == 'input_ids': + input_tensor = arg + elif name == 'positions': + positions_tensor = arg + else: + state_dict[name] = arg + vllm_config = get_current_vllm_config() + + model_config = vllm_config.model_config + mpk_metadata = build_mpk_metadata( + vllm_config, + forward_context, + state_dict, + k_cache_tensors, + v_cache_tensors + ) + self.mpk = MPK(mpk_metadata) + self.mpk.build() + self.mpk.compile() + + self.compiled = True + + return self.mpk() + + return VllmSerializableFunction( + graph, example_inputs, self.prefix, compile_or_call + ) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py new file mode 100755 index 000000000000..d7df534bfeec --- /dev/null +++ b/vllm/v1/attention/backends/mirage.py @@ -0,0 +1,967 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashInfer.""" + +from dataclasses import dataclass +from typing import ClassVar + +import numpy as np +import torch +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, +) +from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache +from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, +) +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils import cdiv, is_pin_memory_available +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + use_trtllm_attention, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + +logger = init_logger(__name__) + +class MirageAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + # TODO: 
(Jianan) Make sure these are correct. + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [128] + + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [16, 32, 64, 4096] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes." + ) + + @staticmethod + def get_name() -> str: + return "MPK_ATTENTION" + + @staticmethod + def get_impl_cls() -> type["MirageAttentionImpl"]: + return MirageAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["MirageAttentionMetadata"]: + return MirageAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["MirageAttentionMetadataBuilder"]: + return MirageAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets us from + # `get_kv_cache_shape` to the actual memory layout we want. + cache_layout = get_kv_cache_layout() + if cache_layout == "NHD": + stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND": + stride_order = (0, 1, 3, 2, 4) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + +@dataclass +class MirageAttentionMetadata: + num_actual_tokens: int # Number of tokens excluding padding. 
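+    # NOTE (illustrative, added for clarity): the paged_kv_* buffers below use a
+    # FlashInfer-style CSR layout. For example, with page_size=4 and two requests
+    # holding 5 and 9 cached tokens, paged_kv_indptr = [0, 2, 5],
+    # paged_kv_indices lists the 5 physical page ids, and
+    # paged_kv_last_page_len = [1, 1] (a completely full last page is recorded
+    # as page_size, never 0).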
+ + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor + + # For flashinfer trtllm batch decode + max_q_len: int + max_q_len_prefill: int + max_seq_len: int + seq_lens: torch.Tensor + block_table_tensor: torch.Tensor + prefill_use_trtllm: bool + decode_use_trtllm: bool + + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + # Meta tensors + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None + paged_kv_indices_gpu: torch.Tensor | None = None + paged_kv_last_page_len_gpu: torch.Tensor | None = None + + +class MirageAttentionMetadataBuilder(AttentionMetadataBuilder[MirageAttentionMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + + reorder_batch_threshold: int = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config + self._workspace_buffer = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._decode_wrapper = None # Wrapper for decode (general shape) + + if vllm_is_batch_invariant(): + self.decode_fixed_split_size = 2048 + self.prefill_fixed_split_size = 4096 + self.disable_split_kv = True + else: + self.decode_fixed_split_size = -1 + self.prefill_fixed_split_size = -1 + self.disable_split_kv = False + + self.compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config + ) + self.num_kv_heads = self.kv_cache_spec.num_kv_heads + self.head_dim = self.kv_cache_spec.head_size + self.page_size = self.kv_cache_spec.block_size + + + # Preparing persistent buffers (device-side) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, # max num pages possible + dtype=torch.int32, + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) + # host-side buffer + pin_memory = is_pin_memory_available() + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() + self.paged_kv_indptr_buffer = torch.zeros_like( + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> MirageAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( 
+ common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) + + page_size = self.page_size + max_q_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() + block_table_tensor = common_attn_metadata.block_table_tensor + + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size + + # write self.paged_kv_indptr_cpu inplace (0-index is always 0) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1 : num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr_np[num_reqs] + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs,)]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) + + # write self.paged_kv_last_page_len_cpu inplace + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) + + uses_spec_reorder = self.reorder_batch_threshold > 1 + + assert self.q_data_type == torch.float16, "MirageAttentionBackend currently only supports float16" + + if not (prefill_use_trtllm and decode_use_trtllm): + if self.has_sinks: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention " + "on earlier GPUs." + ) + + if not self.global_hyperparameters.has_same_window_lefts: + raise ValueError( + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) + + assert self.global_hyperparameters.has_same_all_params, ( + "FlashInfer backend currently only supports models in which " + "all layers share the same values for the following " + "hyperparameters: `window_left`, `logits_soft_cap`, " + "`sm_scale`." + ) + + # The q quantization is not supported for non-trtllm attention, + # fall back to model dtype. + self.q_data_type = self.model_config.dtype + + attn_metadata = MirageAttentionMetadata( + num_actual_tokens=num_actual_tokens, + q_data_type=self.q_data_type, + slot_mapping=common_attn_metadata.slot_mapping, + max_q_len=max_q_len, + max_q_len_prefill=max_q_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + prefill_use_trtllm=prefill_use_trtllm, + decode_use_trtllm=decode_use_trtllm, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + ) + + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + + # Regular attention (common case). + # Decodes are at the front and prefills are at the back. 
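+        # NOTE (illustrative, added for clarity): with reorder_batch_threshold=1,
+        # requests whose query length is 1 are counted as decodes and reordered to
+        # the front of the batch, e.g. query lens [1, 1, 7, 5] give num_decodes=2,
+        # num_decode_tokens=2, num_prefills=2, num_prefill_tokens=12.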
+ num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes + if num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + + # Recompute max_q_len for the slice of requests we are using + # for prefills. This can be different from max_q_len when + # we have a non-uniform batch with some short decodes offloaded + # to the prefill pathway + query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) + + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( + self.device, non_blocking=True + ) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device, non_blocking=True + ) + + if num_decodes > 0: + pure_decode = num_prefills == 0 + # possible required padding for cudagraph replay + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decode_tokens <= self._decode_cudagraph_max_bs + ) + if use_cudagraph: + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_decode_tokens + ) + # Carefully fulfill the padding region with reasonable value + # on cpu. + # Make sure paged_kv_indptr_cpu is not decreasing + self.paged_kv_indptr_cpu[ + 1 + num_decodes : 1 + num_input_tokens + ].fill_(paged_kv_indptr_cpu[-1]) + # Fill the remaining paged_kv_last_page_len_cpu with 1. + # This is because flashinfer treats 0 as a full page + # instead of empty. + self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( + 1 + ) + + else: + num_input_tokens = num_decode_tokens + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens, use_cudagraph + ) + if not attn_metadata.decode_use_trtllm: + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + attn_metadata.decode_wrapper, + self.paged_kv_indptr_cpu[: num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len_cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. 
+ pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.decode_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + return attn_metadata + + +class MirageAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl" + ) + + self.sinks: torch.Tensor | None = None + if sinks is not None: + if sinks.shape[0] != num_heads: + raise ValueError( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Expected {num_heads}, but got " + f"{sinks.shape[0]}." + ) + self.sinks = sinks + + self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None + + def fused_output_quant_supported(self, quant_key: QuantKey): + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) + + def supports_quant_query_input(self) -> bool: + if flashinfer_disable_q_quantization(): + return False + + return self.support_trtllm_attn + + # FlashInfer requires attention sinks to be float32 + def process_weights_after_loading(self, act_dtype: torch.dtype): + if self.sinks is not None and self.sinks.dtype != torch.float32: + self.sinks = self.sinks.to(torch.float32) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashInfer. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: KV cache tensor with different possible shapes: + - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + attn_metadata: Metadata for attention. 
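+            output: preallocated output buffer of shape
+                [num_tokens, num_heads * head_size]; it must be provided because
+                accept_output_buffer is True for this backend.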
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + # Ensure query dtype matches the expected dtype from attention metadata + assert attn_metadata.q_data_type == query.dtype, ( + f"Query dtype mismatch: expected {attn_metadata.q_data_type}, " + f"got {query.dtype}" + ) + + if self.bmm1_scale is None: + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is None: + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) + else: + assert attn_metadata.q_data_type == FP8_DTYPE, ( + "Query must be FP8 when attn+quant fusion happened." + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" + + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, ( + "output_block_scale should not be provided for fp8 output" + ) + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, ( + "output_block_scale is required for nvfp4 output" + ) + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires to scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype + ) + kv_cache = kv_cache.view(torch_dtype) + + # Inputs and outputs may be padded for CUDA graphs + query = query[:num_actual_tokens] + output_padded = output + output = output[:num_actual_tokens] + + if attn_metadata.use_cascade: + # Cascade attention (rare case). 
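+            # NOTE (added for clarity): cascade attention is FlashInfer's shared-prefix
+            # path, where attention over a prefix common to the whole batch is computed
+            # once and combined with the per-request suffixes.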
+ assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) + return output + + # When using spec decoding, num_decodes can be < num_decode_tokens + # because some decode requests may have more than one query token. + num_decodes = attn_metadata.num_decodes + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + stride_order = FlashInferBackend.get_kv_cache_stride_order() + kv_cache_permute = kv_cache.permute(*stride_order) + # Regular attention (common case). + # Decodes are at the front and prefills are at the back. + if num_prefill_tokens > 0: + prefill_wrapper = attn_metadata.prefill_wrapper + prefill_query = query[num_decode_tokens:] + assert prefill_query.shape[0] == num_prefill_tokens + assert prefill_wrapper is not None + + if not attn_metadata.prefill_use_trtllm: + assert prefill_wrapper._causal + assert prefill_wrapper._window_left == self.window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) + else: + # prefill_query may be non-contiguous + prefill_query = prefill_query.contiguous() + workspace_buffer = _get_trtllm_gen_workspace_buffer() + block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] + seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] + + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND + assert get_kv_cache_layout() == "HND" + assert prefill_query.is_contiguous() + assert kv_cache_permute.is_contiguous() + assert workspace_buffer.is_contiguous() + assert block_tables_prefill.is_contiguous() + assert seq_lens_prefill.is_contiguous() + + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor( + data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape, + ) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + + if ( + attn_metadata.q_data_type != FP8_DTYPE + and self.kv_cache_dtype.startswith("fp8") + ): + # TRTLLM prefill attention does not support BF16 Q + # and fp8 kv cache. 
So to enable prefill attention + # with fp8 kv cache, we can construct a mock block + # and mock kv cache with BF16 KV involved in the prefill + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + ) + else: + mock_kv_cache = kv_cache_permute + mock_block_table = block_tables_prefill + + trtllm_batch_context_with_kv_cache( + query=prefill_query, + kv_cache=mock_kv_cache, + workspace_buffer=workspace_buffer, + block_tables=mock_block_table, + seq_lens=seq_lens_prefill, + max_q_len=attn_metadata.max_q_len_prefill, + max_kv_len=attn_metadata.max_seq_len, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + batch_size=attn_metadata.num_prefills, + cum_seq_lens_q=attn_metadata.qo_indptr_gpu, + cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, + window_left=self.window_left, + sinks=self.sinks, + o_sf_scale=self.o_sf_scale, + out=out, + ) + + if num_decode_tokens > 0: + decode_wrapper = attn_metadata.decode_wrapper + decode_query = query[:num_decode_tokens] + assert decode_query.shape[0] == num_decode_tokens + assert decode_wrapper is not None + + if not attn_metadata.decode_use_trtllm: + assert decode_wrapper._window_left == self.window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + else: + # decode_query may be non-contiguous + decode_query = decode_query.contiguous() + workspace_buffer = _get_trtllm_gen_workspace_buffer() + block_tables_decode = attn_metadata.block_table_tensor[ + :num_decode_tokens + ] + seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] + + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND + assert get_kv_cache_layout() == "HND" + assert decode_query.is_contiguous() + assert kv_cache_permute.is_contiguous() + assert workspace_buffer.is_contiguous() + assert block_tables_decode.is_contiguous() + assert seq_lens_decode.is_contiguous() + + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor( + data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape, + ) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + + if num_decode_tokens % attn_metadata.num_decodes != 0: + # This gets triggered when the dummy_run forces + # attention to be initialized with q_len = 0 + q_len_per_req = 1 + else: + q_len_per_req = num_decode_tokens // attn_metadata.num_decodes + + trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache_permute, + workspace_buffer=workspace_buffer, + block_tables=block_tables_decode, + seq_lens=seq_lens_decode, + max_seq_len=attn_metadata.max_seq_len, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + window_left=self.window_left, + sinks=self.sinks, + o_sf_scale=self.o_sf_scale, + out=out, + q_len_per_req=q_len_per_req, + ) + return output_padded + + +def fast_plan_decode( + self, # decode wrapper + indptr_cpu: torch.Tensor, + indices: torch.Tensor, + last_page_len_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + logits_soft_cap: float | None = None, + q_data_type: str | torch.dtype | None 
= "float16", + kv_data_type: str | torch.dtype | None = None, + data_type: str | torch.dtype | None = None, + sm_scale: float | None = None, + rope_scale: float | None = None, + rope_theta: float | None = None, + non_blocking: bool = True, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> None: + """ + A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for + cudagraph capture/replay, while the no cudagraph version turns back + to the original plan. + using original plan after passing host-side buffers: + - only host-to-device copy of indptr and last_page_len buffers + Modifications for cudagraph: + - only host-to-device copy of indptr and last_page_len buffers. + - avoid device-to-device copy of indices buffer. + + Part of the code get inspiration from the original plan from FlashInfer repo + and the implementation of fast_decode_plan for FlashInfer in SGlang repo. + """ + # Warm up with the original plan if it is first call, and always run the + # original plan if we run for dynamic shape. For fixed shape (cudagraph), + # this warm up is to generate the _cached_module for the decode wrapper. + if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): + self.plan( + indptr_cpu, + indices, + last_page_len_cpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode, + window_left, + logits_soft_cap, + q_data_type, + kv_data_type, + data_type, + sm_scale, + rope_scale, + rope_theta, + non_blocking, + None, # block_tables + None, # seq_lens + fixed_split_size, + disable_split_kv, + ) + self.vllm_first_call = False + return + + assert self.is_cuda_graph_enabled, "Should be cudagraph only here" + + batch_size = len(last_page_len_cpu) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + + # Handle data types consistently + if data_type is not None: + if q_data_type is None: + q_data_type = data_type + if kv_data_type is None: + kv_data_type = data_type + elif q_data_type is None: + q_data_type = "float16" + + if kv_data_type is None: + kv_data_type = q_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) + + if batch_size != self._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime " + "batch size {} mismatches the batch size set during " + "initialization {}".format(batch_size, self._fixed_batch_size) + ) + if len(indices) > len(self._paged_kv_indices_buf): + raise ValueError( + "The size of indices should be less than or equal to the allocated buffer" + ) + + # host-to-device copy for the indptr buffer + self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) + # host-to-device copy for the last_page_len buffer + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + + try: + # Make sure we pass exactly 18 arguments for tensor core version + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_cpu, + seq_lens_cpu, + batch_size, # total_num_rows + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + window_left, + fixed_split_size, + disable_split_kv, + ) + except Exception as e: + raise RuntimeError(f"Error in tensor core 
plan: {e}") from e + + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) From f41eb7f05da88f7bfd616834157534bd715c5839 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Mon, 20 Oct 2025 17:34:26 +0000 Subject: [PATCH 02/12] mirage compilation and its attention backends --- vllm/compilation/mirage_backend.py | 211 +++++---- vllm/v1/attention/backends/mirage.py | 659 +-------------------------- 2 files changed, 120 insertions(+), 750 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index 9c5ab88bc075..033fdc551ad6 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -1,11 +1,15 @@ from collections import defaultdict from .backends import * -from mirage.mpk import MPK, MPKMetadata +from mirage import MPK, MPKMetadata, MirageModelConfig import re -from vllm.config import get_current_vllm_config +from vllm.config import ModelConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.models.utils import extract_layer_index import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: """Transfer FX placeholder debug names to model-like dotted names. 
@@ -46,91 +50,113 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: return converted_names - -# @dataclass -# class MPKMetadata: -# # ---------- MPK class external state bundled here ---------- -# # args -# mode: str = "offline" -# total_num_requests: int = 1 -# num_remote_schedulers: int = 0 -# max_seq_length: int = 0 -# max_num_batched_requests: int = 0 -# max_num_batched_tokens: int = 0 -# max_num_pages: int = 0 -# page_size: int = 0 -# max_sm_num: int = 108 -# device: str = "cuda" -# # model -# weight_from_model: bool -# model_name: Optional[str] # For now, model_name must be provided -# model_path: Optional[str] = None -# # fx graph -# state_dict: Optional[dict] = None -# # Meta tensors -# step: Optional[torch.Tensor] = None -# tokens: Optional[torch.Tensor] = None -# input_tokens: Optional[torch.Tensor] = None -# output_tokens: Optional[torch.Tensor] = None -# num_new_tokens: Optional[torch.Tensor] = None -# prompt_lengths: Optional[torch.Tensor] = None -# qo_indptr_buffer: Optional[torch.Tensor] = None -# paged_kv_indptr_buffer: Optional[torch.Tensor] = None -# paged_kv_indices_buffer: Optional[torch.Tensor] = None -# paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None -# # profiling -# profiler_tensor: Optional[torch.Tensor] = None -# trace_name: Optional[str] = None -# # spec decode config -# spec_decode_config: Optional[object] = None - +def build_model_config( + model_config: ModelConfig, + state_dict: dict[str, torch.Tensor], + k_cache_tensors: list[torch.Tensor], + v_cache_tensors: list[torch.Tensor], + position_embeddings: torch.Tensor, +) -> MirageModelConfig: + mirage_model_config = MirageModelConfig( + # model architecture + hidden_size=model_config.get_hidden_size(), + intermediate_size=getattr(model_config.hf_text_config, "intermediate_size", 0), + vocab_size=model_config.get_vocab_size(), + num_q_heads=model_config.get_num_attention_heads(), + num_kv_heads=model_config.get_num_kv_heads(), + head_dim=model_config.get_head_size(), + num_layers=getattr(model_config.hf_text_config, "num_hidden_layers", 0), + # kv cache + k_cache=k_cache_tensors, + v_cache=v_cache_tensors, + # position embeddings + position_embeddings=position_embeddings, + # model weights + state_dict=state_dict, + ) + return mirage_model_config def build_mpk_metadata( vllm_config: VllmConfig, forward_context: ForwardContext, - state_dict: dict[str, torch.Tensor], - k_cache_tensors: list[torch.Tensor], - v_cache_tensors: list[torch.Tensor] + args: list[Any], + transfered_tensor_names: list[str], ) -> MPKMetadata: model_config = vllm_config.model_config scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config attn_metadata = forward_context.attn_metadata + + static_forward_context = forward_context.no_compile_layers # layer names to layers + k_cache_tensors = [] + v_cache_tensors = [] + # Convert kv_caches dict to a list of tensors in the order of layer_index. 
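+    # NOTE (added for clarity; assumes dotted layer names): extract_layer_index
+    # recovers the numeric index from names like "model.layers.17.self_attn.attn",
+    # so iterating sorted(index2name) yields the K/V cache tensors in layer order.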
+ index2name = defaultdict(list) + for layer_name in static_forward_context.keys(): + index2name[extract_layer_index(layer_name, 1)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + assert len(layer_names) == 1, "Multiple layers with the same layer index are not supported" + layer_name = layer_names[0] + logger.info(f"{layer_index} {layer_name}: attention num: {len(static_forward_context[layer_name].kv_cache)}; kv_cache.shape: {static_forward_context[layer_name].kv_cache[0].shape}") + k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0]) + v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1]) + # kv_cache_tensors shape: num_layers * (2, num_blocks, block_size, num_kv_heads, head_size) + + state_dict = {} + input_token_ids = None + positions_tensor = None + position_embeddings = None + for arg, name in zip(args, transfered_tensor_names): + if name == 'input_ids': + input_token_ids = arg + elif name == 'positions': + positions_tensor = arg + elif "cos_sin_cache" in name: + position_embeddings = arg + else: + state_dict[name] = arg + + mirage_model_config = build_model_config( + model_config, + state_dict, + k_cache_tensors, + v_cache_tensors, + position_embeddings + ) mpk_metadata = MPKMetadata( - mode = "online" + mode = "online", # total_num_requests # num_remote_schedulers: int = 0 max_seq_length = model_config.max_model_len, max_num_batched_requests = scheduler_config.max_num_seqs, max_num_batched_tokens = scheduler_config.max_num_batched_tokens, - max_num_pages = cache_config.num_gpu_blocks - page_size = cache_config.block_size + max_num_pages = cache_config.num_gpu_blocks, + page_size = cache_config.block_size, # max_sm_num: int = 108 - device: str = "cuda" + device = "cuda", # # model - # weight_from_model: bool + weight_from_model = False, model_name = model_config.model_name, # model_path: Optional[str] = None # multi device support - world_size = parallel_config.world_size - rank = parallel_config.rank - # # fx graph - state_dict = state_dict + world_size = parallel_config.world_size, + rank = parallel_config.rank, # # Meta tensors - # step: Optional[torch.Tensor] = None + step = positions_tensor, # tokens: Optional[torch.Tensor] = None - # input_tokens: Optional[torch.Tensor] = None + input_tokens = input_token_ids, # output_tokens: Optional[torch.Tensor] = None # num_new_tokens: Optional[torch.Tensor] = None # prompt_lengths: Optional[torch.Tensor] = None - qo_indptr_buffer = attn_metadata.qo_indptr_gpu - paged_kv_indptr_buffer = attn_metadata.paged_kv_indptr_gpu - paged_kv_indices_buffer = attn_metadata.paged_kv_indices_gpu - paged_kv_last_page_len_buffer = attn_metadata.paged_kv_last_page_len_gpu - # kv cache tensors - k_cache_tensors = k_cache_tensors - v_cache_tensors = v_cache_tensors + qo_indptr_buffer = attn_metadata.qo_indptr_gpu, + paged_kv_indptr_buffer = attn_metadata.paged_kv_indptr_gpu, + paged_kv_indices_buffer = attn_metadata.paged_kv_indices_gpu, + paged_kv_last_page_len_buffer = attn_metadata.paged_kv_last_page_len_gpu, + # kv cache tensors, weights and model config + model_config=mirage_model_config, # # profiling # profiler_tensor: Optional[torch.Tensor] = None # trace_name: Optional[str] = None @@ -156,15 +182,8 @@ class MirageBackend: _called: bool = False # the graph we compiled graph: fx.GraphModule - # the stiching graph module for all the piecewise graphs - split_gm: fx.GraphModule - piecewise_graphs: list[SplitItem] - returned_callable: Callable - # 
Inductor passes to run on the graph pre-defunctionalization - post_grad_passes: Sequence[Callable] - sym_tensor_indices: list[int] + input_buffers: list[torch.Tensor] - compiler_manager: CompilerManager def __init__( self, @@ -177,25 +196,23 @@ def __init__( # when multiple parts are initialized as independent # models, we need to use the model_tag to distinguish # them, e.g. backbone (default), eagle_head, etc. + logger.info("[Mirage] Calling MirageBackend init!") self.prefix = prefix or model_tag # Passes to run on the graph post-grad. self.post_grad_pass_manager = PostGradPassManager() - self.sym_tensor_indices = [] self.input_buffers = [] self.vllm_config = vllm_config - self.model_name = vllm_config.model_config.model_name self.compilation_config = vllm_config.compilation_config + self.model_config = vllm_config.model_config + self.model_name = vllm_config.model_config.model - self.compiler_manager: CompilerManager = CompilerManager( - self.compilation_config - ) def __call__( self, graph: fx.GraphModule, example_inputs - ) -> VllmSerializableFunction: + ) -> Any: # when dynamo calls the backend, it means the bytecode # transform and analysis are done @@ -216,55 +233,31 @@ def __call__( transfered_tensor_names = transfer_tensor_names(placeholders) forward_context = get_forward_context() - static_forward_context = forward_context.no_compile_layers # layer names to layers - k_cache_tensors = [] - v_cache_tensors = [] - # Convert kv_caches dict to a list of tensors in the order of layer_index. - index2name = defaultdict(list) - for layer_name in static_forward_context.keys(): - index2name[extract_layer_index(layer_name, 1)].append(layer_name) - - for layer_index in sorted(index2name.keys()): - layer_names = index2name[layer_index] - assert len(layer_names) == 1, "Multiple layers with the same layer index are not supported" - layer_name = layer_names[0] - k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0]) - v_cache_tensors.append(static_forward_context[layer_name].kv_cache[1]) - # kv_cache_tensors shape: num_layers * (2, num_blocks, block_size, num_kv_heads, head_size) - self._called = True self.compiled = False def compile_or_call(*args): if not self.compiled: # Compile only at the first call -- when we get real tensors - state_dict = {} - for arg, name in zip(args, transfered_tensor_names): - if name == 'input_ids': - input_tensor = arg - elif name == 'positions': - positions_tensor = arg - else: - state_dict[name] = arg - vllm_config = get_current_vllm_config() - - model_config = vllm_config.model_config + logger.info("[Mirage] Calling compile_or_call for the first time, compiling......!") mpk_metadata = build_mpk_metadata( - vllm_config, + self.vllm_config, forward_context, - state_dict, - k_cache_tensors, - v_cache_tensors + args, + transfered_tensor_names, ) + logger.info(f"[Mirage] MPK metadata: {mpk_metadata.info_as_string()}") self.mpk = MPK(mpk_metadata) self.mpk.build() self.mpk.compile() self.compiled = True + logger.info(f"[Mirage] Calling the compiled result...") return self.mpk() - return VllmSerializableFunction( - graph, example_inputs, self.prefix, compile_or_call - ) \ No newline at end of file + # return VllmSerializableFunction( + # graph, example_inputs, self.prefix, compile_or_call + # ) + return compile_or_call \ No newline at end of file diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index d7df534bfeec..efaacd15b2b3 100755 --- a/vllm/v1/attention/backends/mirage.py +++ 
b/vllm/v1/attention/backends/mirage.py @@ -7,14 +7,6 @@ import numpy as np import torch -from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - MultiLevelCascadeAttentionWrapper, -) -from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache -from flashinfer.prefill import trtllm_batch_context_with_kv_cache -from flashinfer.utils import FP4Tensor from vllm.attention.backends.abstract import ( AttentionBackend, @@ -22,38 +14,24 @@ AttentionType, MultipleOf, ) -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - kFp8StaticTensorSym, - kNvfp4Quant, -) + from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import ( - can_use_trtllm_attention, - flashinfer_disable_q_quantization, - use_trtllm_attention, -) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - get_per_layer_parameters, - infer_global_hyperparameters, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec -FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 - FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -133,17 +111,6 @@ class MirageAttentionMetadata: # The data type of the query q_data_type: torch.dtype - slot_mapping: torch.Tensor - - # For flashinfer trtllm batch decode - max_q_len: int - max_q_len_prefill: int - max_seq_len: int - seq_lens: torch.Tensor - block_table_tensor: torch.Tensor - prefill_use_trtllm: bool - decode_use_trtllm: bool - # For handling prefill decode split num_decodes: int num_decode_tokens: int @@ -174,9 +141,6 @@ def __init__( super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config - self._workspace_buffer = None - self._prefill_wrapper = None # Wrapper for prefill/append - self._decode_wrapper = None # Wrapper for decode (general shape) if vllm_is_batch_invariant(): self.decode_fixed_split_size = 2048 @@ -201,7 +165,6 @@ def __init__( self.head_dim = self.kv_cache_spec.head_size self.page_size = self.kv_cache_spec.block_size - # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=self.device @@ -220,9 +183,6 @@ def __init__( max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory ) self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() - self.paged_kv_indptr_buffer = torch.zeros_like( - self.paged_kv_indptr_cpu, pin_memory=pin_memory - ) self.paged_kv_indices_cpu = torch.zeros( max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory ) @@ -248,9 +208,6 @@ def build( ) page_size = self.page_size - max_q_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor @@ -263,15 +220,10 @@ def build( dtype=np.int32, out=self.paged_kv_indptr_np[1 : num_reqs + 1], ) - # 
NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified - # after this line (e.g., for cuda graphs), we need to copy the data to - # self.paged_kv_indptr_buffer to avoid race condition. - self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ - : num_reqs + 1 - ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] paged_kv_indptr.copy_( - self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + self.paged_kv_indptr_cpu[: num_reqs + 1], non_blocking=True ) # write self.paged_kv_indices inplace @@ -292,147 +244,28 @@ def build( page_size, paged_kv_last_page_len_np, ) + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + paged_kv_last_page_len.copy_( + self.paged_kv_last_page_len_cpu[:num_reqs], non_blocking=True + ) - uses_spec_reorder = self.reorder_batch_threshold > 1 + # uses_spec_reorder = self.reorder_batch_threshold > 1 - assert self.q_data_type == torch.float16, "MirageAttentionBackend currently only supports float16" - - if not (prefill_use_trtllm and decode_use_trtllm): - if self.has_sinks: - raise NotImplementedError( - "FlashInfer backend currently does not support attention " - "sinks, please use trtllm on blackwell or flash attention " - "on earlier GPUs." - ) - - if not self.global_hyperparameters.has_same_window_lefts: - raise ValueError( - "Window left is not the same for all layers. " - "One potential fix is to set disable_sliding_window=True" - ) - - assert self.global_hyperparameters.has_same_all_params, ( - "FlashInfer backend currently only supports models in which " - "all layers share the same values for the following " - "hyperparameters: `window_left`, `logits_soft_cap`, " - "`sm_scale`." - ) - - # The q quantization is not supported for non-trtllm attention, - # fall back to model dtype. - self.q_data_type = self.model_config.dtype + assert self.q_data_type == torch.bfloat16, "MirageAttentionBackend currently only supports bfloat16" attn_metadata = MirageAttentionMetadata( num_actual_tokens=num_actual_tokens, q_data_type=self.q_data_type, - slot_mapping=common_attn_metadata.slot_mapping, - max_q_len=max_q_len, - max_q_len_prefill=max_q_len, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table_tensor=block_table_tensor, - prefill_use_trtllm=prefill_use_trtllm, - decode_use_trtllm=decode_use_trtllm, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, + qo_indptr_gpu=common_attn_metadata.query_start_loc_gpu, + paged_kv_indptr_gpu=self.paged_kv_indptr, + paged_kv_indices_gpu=self.paged_kv_indices, + paged_kv_last_page_len_gpu=self.paged_kv_last_page_len, ) - qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] - paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] - - # Regular attention (common case). - # Decodes are at the front and prefills are at the back. 
- num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decodes - if num_prefills > 0: - # Decodes are first so prefills start after the last decode - prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 - assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 - assert ( - paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills - ) - # Since prefill_wrapper.run() will be called with - # query[num_decode_tokens:] we need to adjust the qo_indptr - # to be relative to the start of the prefill queries. - qo_indptr_cpu = ( - qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] - ) - paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] - - # Recompute max_q_len for the slice of requests we are using - # for prefills. This can be different from max_q_len when - # we have a non-uniform batch with some short decodes offloaded - # to the prefill pathway - query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) - - attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( - self.device, non_blocking=True - ) - attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device, non_blocking=True - ) - - if num_decodes > 0: - pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay - use_cudagraph = ( - self.enable_cuda_graph - and pure_decode - and num_decode_tokens <= self._decode_cudagraph_max_bs - ) - if use_cudagraph: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_decode_tokens - ) - # Carefully fulfill the padding region with reasonable value - # on cpu. - # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[ - 1 + num_decodes : 1 + num_input_tokens - ].fill_(paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( - 1 - ) - - else: - num_input_tokens = num_decode_tokens - - attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph - ) - if not attn_metadata.decode_use_trtllm: - # Use the persistent buffer with padding length, - # instead of the same address but chunked version - # in atten_metadata when using cudagraph. - fast_plan_decode( - attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[: num_input_tokens + 1], - paged_kv_indices, - self.paged_kv_last_page_len_cpu[:num_input_tokens], - seq_lens_cpu[:num_input_tokens], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. 
- pos_encoding_mode="NONE", - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.decode_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) return attn_metadata @@ -451,66 +284,7 @@ def __init__( kv_sharing_target_layer_name: int | None = None, sinks: torch.Tensor | None = None, ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - else: - self.sliding_window = (sliding_window - 1, 0) - self.window_left = ( - self.sliding_window[0] if self.sliding_window is not None else -1 - ) - self.kv_cache_dtype = kv_cache_dtype - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - if attn_type != AttentionType.DECODER: - raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl" - ) - - self.sinks: torch.Tensor | None = None - if sinks is not None: - if sinks.shape[0] != num_heads: - raise ValueError( - "Sinks must have the same number of heads as the number of " - f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}." - ) - self.sinks = sinks - - self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) - self.bmm1_scale: float | None = None - self.bmm2_scale: float | None = None - self.o_sf_scale: float | None = None - - def fused_output_quant_supported(self, quant_key: QuantKey): - return ( - self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) - ) - - def supports_quant_query_input(self) -> bool: - if flashinfer_disable_q_quantization(): - return False - - return self.support_trtllm_attn - - # FlashInfer requires attention sinks to be float32 - def process_weights_after_loading(self, act_dtype: torch.dtype): - if self.sinks is not None and self.sinks.dtype != torch.float32: - self.sinks = self.sinks.to(torch.float32) + pass def forward( self, @@ -519,12 +293,12 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashInferMetadata, + attn_metadata: MirageAttentionMetadata, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: - """Forward pass with FlashInfer. + """Forward pass that do nothing. Args: query: shape = [num_tokens, num_heads, head_size] @@ -543,404 +317,7 @@ def forward( # Profiling run. return output.fill_(0) - # Ensure query dtype matches the expected dtype from attention metadata - assert attn_metadata.q_data_type == query.dtype, ( - f"Query dtype mismatch: expected {attn_metadata.q_data_type}, " - f"got {query.dtype}" - ) - - if self.bmm1_scale is None: - self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale - - if self.bmm2_scale is None: - self.bmm2_scale = layer._v_scale_float - - # The attn+quant fusion happens when output_scale is provided. 
- if output_scale is None: - assert output_block_scale is None, ( - "output_block_scale is not supported when fusion has not happened" - ) - else: - assert attn_metadata.q_data_type == FP8_DTYPE, ( - "Query must be FP8 when attn+quant fusion happened." - ) - assert ( - attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm - ), "Must use TRT-LLM attn" - - if output.dtype == FP8_DTYPE: - assert output_block_scale is None, ( - "output_block_scale should not be provided for fp8 output" - ) - elif output.dtype == FP4_DTYPE: - assert output_block_scale is not None, ( - "output_block_scale is required for nvfp4 output" - ) - else: - raise ValueError(f"Unsupported output dtype: {output.dtype}") - - # TRTLLM attn kernel requires to scale to pass as a host scalar, - # store the o scale as a host scalar in warmup run with cuda graph - # not enabled - if layer._o_scale_float is None: - layer._o_scale_float = output_scale.cpu().item() - if output.dtype == FP8_DTYPE: - self.bmm2_scale = self.bmm2_scale / layer._o_scale_float - elif output.dtype == FP4_DTYPE: - self.o_sf_scale = layer._o_scale_float - - # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. - # Minimize the PyTorch ops in this method as much as possible. - # Whenever making a change in this method, please benchmark the - # performance to make sure it does not introduce any overhead. - - num_actual_tokens = attn_metadata.num_actual_tokens - - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if self.kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype - ) - kv_cache = kv_cache.view(torch_dtype) - - # Inputs and outputs may be padded for CUDA graphs - query = query[:num_actual_tokens] - output_padded = output - output = output[:num_actual_tokens] - - if attn_metadata.use_cascade: - # Cascade attention (rare case). - assert attn_metadata.cascade_wrapper is not None - output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) - return output - - # When using spec decoding, num_decodes can be < num_decode_tokens - # because some decode requests may have more than one query token. - num_decodes = attn_metadata.num_decodes - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - - stride_order = FlashInferBackend.get_kv_cache_stride_order() - kv_cache_permute = kv_cache.permute(*stride_order) - # Regular attention (common case). - # Decodes are at the front and prefills are at the back. 
- if num_prefill_tokens > 0: - prefill_wrapper = attn_metadata.prefill_wrapper - prefill_query = query[num_decode_tokens:] - assert prefill_query.shape[0] == num_prefill_tokens - assert prefill_wrapper is not None - - if not attn_metadata.prefill_use_trtllm: - assert prefill_wrapper._causal - assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) - assert prefill_wrapper._sm_scale == self.scale - prefill_wrapper.run( - prefill_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[num_decode_tokens:], - ) - else: - # prefill_query may be non-contiguous - prefill_query = prefill_query.contiguous() - workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] - seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] - - # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - assert get_kv_cache_layout() == "HND" - assert prefill_query.is_contiguous() - assert kv_cache_permute.is_contiguous() - assert workspace_buffer.is_contiguous() - assert block_tables_prefill.is_contiguous() - assert seq_lens_prefill.is_contiguous() - - if output.dtype == FP4_DTYPE: - assert self.o_sf_scale is not None - out = FP4Tensor( - data=output[num_decode_tokens:], - scale=output_block_scale, - scale_start_index=num_decode_tokens, - original_shape=prefill_query.shape, - ) - else: - assert self.o_sf_scale is None - out = output[num_decode_tokens:] - - if ( - attn_metadata.q_data_type != FP8_DTYPE - and self.kv_cache_dtype.startswith("fp8") - ): - # TRTLLM prefill attention does not support BF16 Q - # and fp8 kv cache. So to enable prefill attention - # with fp8 kv cache, we can construct a mock block - # and mock kv cache with BF16 KV involved in the prefill - mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( - kv_cache_permute, - block_tables_prefill, - layer._k_scale, - layer._v_scale, - attn_metadata.q_data_type, - ) - else: - mock_kv_cache = kv_cache_permute - mock_block_table = block_tables_prefill - - trtllm_batch_context_with_kv_cache( - query=prefill_query, - kv_cache=mock_kv_cache, - workspace_buffer=workspace_buffer, - block_tables=mock_block_table, - seq_lens=seq_lens_prefill, - max_q_len=attn_metadata.max_q_len_prefill, - max_kv_len=attn_metadata.max_seq_len, - bmm1_scale=self.bmm1_scale, - bmm2_scale=self.bmm2_scale, - batch_size=attn_metadata.num_prefills, - cum_seq_lens_q=attn_metadata.qo_indptr_gpu, - cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, - window_left=self.window_left, - sinks=self.sinks, - o_sf_scale=self.o_sf_scale, - out=out, - ) - - if num_decode_tokens > 0: - decode_wrapper = attn_metadata.decode_wrapper - decode_query = query[:num_decode_tokens] - assert decode_query.shape[0] == num_decode_tokens - assert decode_wrapper is not None - - if not attn_metadata.decode_use_trtllm: - assert decode_wrapper._window_left == self.window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) - assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) - else: - # decode_query may be non-contiguous - decode_query = decode_query.contiguous() - workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_decode = attn_metadata.block_table_tensor[ - :num_decode_tokens - ] - seq_lens_decode = 
attn_metadata.seq_lens[:num_decode_tokens] - - # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - assert get_kv_cache_layout() == "HND" - assert decode_query.is_contiguous() - assert kv_cache_permute.is_contiguous() - assert workspace_buffer.is_contiguous() - assert block_tables_decode.is_contiguous() - assert seq_lens_decode.is_contiguous() - - if output.dtype == FP4_DTYPE: - assert self.o_sf_scale is not None - out = FP4Tensor( - data=output[:num_decode_tokens], - scale=output_block_scale, - scale_start_index=0, - original_shape=decode_query.shape, - ) - else: - assert self.o_sf_scale is None - out = output[:num_decode_tokens] - - if num_decode_tokens % attn_metadata.num_decodes != 0: - # This gets triggered when the dummy_run forces - # attention to be initialized with q_len = 0 - q_len_per_req = 1 - else: - q_len_per_req = num_decode_tokens // attn_metadata.num_decodes - - trtllm_batch_decode_with_kv_cache( - query=decode_query, - kv_cache=kv_cache_permute, - workspace_buffer=workspace_buffer, - block_tables=block_tables_decode, - seq_lens=seq_lens_decode, - max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=self.bmm1_scale, - bmm2_scale=self.bmm2_scale, - window_left=self.window_left, - sinks=self.sinks, - o_sf_scale=self.o_sf_scale, - out=out, - q_len_per_req=q_len_per_req, - ) - return output_padded - - -def fast_plan_decode( - self, # decode wrapper - indptr_cpu: torch.Tensor, - indices: torch.Tensor, - last_page_len_cpu: torch.Tensor, - seq_lens_cpu: torch.Tensor, - num_qo_heads: int, - num_kv_heads: int, - head_dim: int, - page_size: int, - pos_encoding_mode: str = "NONE", - window_left: int = -1, - logits_soft_cap: float | None = None, - q_data_type: str | torch.dtype | None = "float16", - kv_data_type: str | torch.dtype | None = None, - data_type: str | torch.dtype | None = None, - sm_scale: float | None = None, - rope_scale: float | None = None, - rope_theta: float | None = None, - non_blocking: bool = True, - fixed_split_size: int = -1, - disable_split_kv: bool = False, -) -> None: - """ - A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for - cudagraph capture/replay, while the no cudagraph version turns back - to the original plan. - using original plan after passing host-side buffers: - - only host-to-device copy of indptr and last_page_len buffers - Modifications for cudagraph: - - only host-to-device copy of indptr and last_page_len buffers. - - avoid device-to-device copy of indices buffer. - - Part of the code get inspiration from the original plan from FlashInfer repo - and the implementation of fast_decode_plan for FlashInfer in SGlang repo. - """ - # Warm up with the original plan if it is first call, and always run the - # original plan if we run for dynamic shape. For fixed shape (cudagraph), - # this warm up is to generate the _cached_module for the decode wrapper. 
- if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): - self.plan( - indptr_cpu, - indices, - last_page_len_cpu, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode, - window_left, - logits_soft_cap, - q_data_type, - kv_data_type, - data_type, - sm_scale, - rope_scale, - rope_theta, - non_blocking, - None, # block_tables - None, # seq_lens - fixed_split_size, - disable_split_kv, - ) - self.vllm_first_call = False - return - - assert self.is_cuda_graph_enabled, "Should be cudagraph only here" - - batch_size = len(last_page_len_cpu) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - - # Handle data types consistently - if data_type is not None: - if q_data_type is None: - q_data_type = data_type - if kv_data_type is None: - kv_data_type = data_type - elif q_data_type is None: - q_data_type = "float16" - - if kv_data_type is None: - kv_data_type = q_data_type - q_data_type = ( - getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type - ) - kv_data_type = ( - getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type - ) - - if batch_size != self._fixed_batch_size: - raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime " - "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size) - ) - if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) - - # host-to-device copy for the indptr buffer - self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) - # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) - - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - - try: - # Make sure we pass exactly 18 arguments for tensor core version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_cpu, - seq_lens_cpu, - batch_size, # total_num_rows - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - window_left, - fixed_split_size, - disable_split_kv, - ) - except Exception as e: - raise RuntimeError(f"Error in tensor core plan: {e}") from e - - self._pos_encoding_mode = pos_encoding_mode - self._window_left = window_left - self._logits_soft_cap = logits_soft_cap - self._sm_scale = sm_scale - self._rope_scale = rope_scale - self._rope_theta = rope_theta + return output @triton.jit From dbb0c2d533ce005aa41ec282721b56d5550c8369 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Mon, 20 Oct 2025 17:35:04 +0000 Subject: [PATCH 03/12] register --- vllm/config/compilation.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index da2c100dae3d..5ed7f243ff12 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -770,17 +770,21 @@ def __post_init__(self) -> None: "(where 'op' is the registered op name)" ) - # Currently only eager and inductor backend are supported. - # for piecewise compilation. Custom backends are not suppported for - # piecewise compilation. Update when more backends are supported. 
- if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [ - "", - "eager", - "inductor", - ]: - raise ValueError( - f"Invalid backend for piecewise compilation: {self.backend}" - ) + # Allow eager/inductor for piecewise compilation by default. + # Additionally allow opting into an experimental/custom backend by + # specifying its fully-qualified class path. For now we explicitly + # permit the MirageBackend. + if self.mode == CompilationMode.VLLM_COMPILE: + allowed_backends = { + "", + "eager", + "inductor", + "vllm.compilation.mirage_backend.MirageBackend", + } + if self.backend not in allowed_backends: + raise ValueError( + f"Invalid backend for piecewise compilation: {self.backend}" + ) if self.backend == "": self.backend = current_platform.get_compile_backend() @@ -817,7 +821,11 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable: if self.backend not in ["eager", "inductor"]: logger.info("Using OOT custom backend for compilation.") - from vllm.compilation.backends import VllmBackend + # Custom/experimental backend specified by fully-qualified class name. + # Currently support MirageBackend. + if self.backend == "vllm.compilation.mirage_backend.MirageBackend": + from vllm.compilation.mirage_backend import MirageBackend + return MirageBackend(vllm_config) # TODO[@lucaskabela]: See if we can forward prefix # https://github.com/vllm-project/vllm/issues/27045 From f909d355b0cf4f6eaf1564783ff7824ec85e0167 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Tue, 21 Oct 2025 05:31:42 +0000 Subject: [PATCH 04/12] correct register --- vllm/config/compilation.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5ed7f243ff12..ce4d8777d999 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -770,21 +770,17 @@ def __post_init__(self) -> None: "(where 'op' is the registered op name)" ) - # Allow eager/inductor for piecewise compilation by default. - # Additionally allow opting into an experimental/custom backend by - # specifying its fully-qualified class path. For now we explicitly - # permit the MirageBackend. - if self.mode == CompilationMode.VLLM_COMPILE: - allowed_backends = { - "", - "eager", - "inductor", - "vllm.compilation.mirage_backend.MirageBackend", - } - if self.backend not in allowed_backends: - raise ValueError( - f"Invalid backend for piecewise compilation: {self.backend}" - ) + # Currently only eager and inductor backend are supported. + # for piecewise compilation. Custom backends are not suppported for + # piecewise compilation. Update when more backends are supported. + if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [ + "", + "eager", + "inductor", + ]: + raise ValueError( + f"Invalid backend for piecewise compilation: {self.backend}" + ) if self.backend == "": self.backend = current_platform.get_compile_backend() @@ -815,17 +811,16 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable: ]: if self.backend in torch_backends: return self.backend + if self.backend == "mirage": + from vllm.compilation.mirage_backend import MirageBackend + return MirageBackend(vllm_config) return resolve_obj_by_qualname(self.backend) assert self.mode == CompilationMode.VLLM_COMPILE if self.backend not in ["eager", "inductor"]: logger.info("Using OOT custom backend for compilation.") - # Custom/experimental backend specified by fully-qualified class name. - # Currently support MirageBackend. 
- if self.backend == "vllm.compilation.mirage_backend.MirageBackend": - from vllm.compilation.mirage_backend import MirageBackend - return MirageBackend(vllm_config) + from vllm.compilation.backends import VllmBackend # TODO[@lucaskabela]: See if we can forward prefix # https://github.com/vllm-project/vllm/issues/27045 From d0ae52bce9cb00656571a5b99c5f26324208a60d Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Tue, 21 Oct 2025 05:33:55 +0000 Subject: [PATCH 05/12] backend avoid dumb run compilation. Now vllm can start (still not compile) --- vllm/compilation/mirage_backend.py | 41 ++++++++++++++++++++++------ vllm/v1/attention/backends/mirage.py | 11 ++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index 033fdc551ad6..a30264080ba0 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -3,6 +3,7 @@ from mirage import MPK, MPKMetadata, MirageModelConfig import re from vllm.config import ModelConfig, get_current_vllm_config +from vllm.config.parallel import ParallelConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.models.utils import extract_layer_index import torch @@ -56,14 +57,15 @@ def build_model_config( k_cache_tensors: list[torch.Tensor], v_cache_tensors: list[torch.Tensor], position_embeddings: torch.Tensor, + parallel_config: ParallelConfig, ) -> MirageModelConfig: mirage_model_config = MirageModelConfig( # model architecture hidden_size=model_config.get_hidden_size(), intermediate_size=getattr(model_config.hf_text_config, "intermediate_size", 0), vocab_size=model_config.get_vocab_size(), - num_q_heads=model_config.get_num_attention_heads(), - num_kv_heads=model_config.get_num_kv_heads(), + local_num_q_heads=model_config.get_num_attention_heads(parallel_config), + local_num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_dim=model_config.get_head_size(), num_layers=getattr(model_config.hf_text_config, "num_hidden_layers", 0), # kv cache @@ -78,15 +80,16 @@ def build_model_config( def build_mpk_metadata( vllm_config: VllmConfig, - forward_context: ForwardContext, args: list[Any], transfered_tensor_names: list[str], ) -> MPKMetadata: + forward_context = get_forward_context() model_config = vllm_config.model_config scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config attn_metadata = forward_context.attn_metadata + logger.info(f"[Mirage] Forward context: {forward_context}, attn_metadata: {attn_metadata}") static_forward_context = forward_context.no_compile_layers # layer names to layers k_cache_tensors = [] @@ -100,7 +103,7 @@ def build_mpk_metadata( layer_names = index2name[layer_index] assert len(layer_names) == 1, "Multiple layers with the same layer index are not supported" layer_name = layer_names[0] - logger.info(f"{layer_index} {layer_name}: attention num: {len(static_forward_context[layer_name].kv_cache)}; kv_cache.shape: {static_forward_context[layer_name].kv_cache[0].shape}") + # logger.info(f"{layer_index} {layer_name}: attention num: {len(static_forward_context[layer_name].kv_cache)}; kv_cache.shape: {static_forward_context[layer_name].kv_cache[0].shape}") k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0]) v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1]) # kv_cache_tensors shape: num_layers * (2, num_blocks, block_size, num_kv_heads, head_size) @@ -124,7 
+127,8 @@ def build_mpk_metadata( state_dict, k_cache_tensors, v_cache_tensors, - position_embeddings + position_embeddings, + parallel_config, ) mpk_metadata = MPKMetadata( mode = "online", @@ -139,7 +143,7 @@ def build_mpk_metadata( device = "cuda", # # model weight_from_model = False, - model_name = model_config.model_name, + model_name = model_config.model, # model_path: Optional[str] = None # multi device support world_size = parallel_config.world_size, @@ -218,6 +222,19 @@ def __call__( # transform and analysis are done compilation_counter.num_graphs_seen += 1 from .monitor import torch_compile_start_time + + # TODO: remove this after debugging + # try: + # src = graph.print_readable(print_output=False) + # except Exception: + # src = str(graph) + # try: + # with open('mirage_backends_graph.txt', 'w') as f: + # logger.info('Writing readable FX graph to mirage_backends_graph.txt') + # f.write(src) + # logger.info('Readable FX graph written to mirage_backends_graph.txt') + # except Exception: + # logger.exception('Failed to write mirage_backends_graph.txt') dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) @@ -232,18 +249,26 @@ def __call__( transfered_tensor_names = transfer_tensor_names(placeholders) - forward_context = get_forward_context() self._called = True self.compiled = False def compile_or_call(*args): + dumb_run_called = (get_forward_context().attn_metadata is None) + if dumb_run_called: + model_config = self.vllm_config.model_config + dtype = model_config.dtype + hidden_size = model_config.get_hidden_size() + output_tensor = torch.zeros(1, hidden_size, device='cuda', dtype=dtype) + logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output tensor with shape [{output_tensor.shape}]......!") + + return (output_tensor,) + if not self.compiled: # Compile only at the first call -- when we get real tensors logger.info("[Mirage] Calling compile_or_call for the first time, compiling......!") mpk_metadata = build_mpk_metadata( self.vllm_config, - forward_context, args, transfered_tensor_names, ) diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index efaacd15b2b3..e3bcfab952dd 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -284,6 +284,17 @@ def __init__( kv_sharing_target_layer_name: int | None = None, sinks: torch.Tensor | None = None, ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_kv_heads + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + self.sinks = sinks pass def forward( From 9da8262c661db5014c11a3b95cb0a93249e92836 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Wed, 22 Oct 2025 04:53:52 +0000 Subject: [PATCH 06/12] bug fix --- vllm/compilation/mirage_backend.py | 20 +++++++++++++++----- vllm/v1/attention/backends/mirage.py | 8 +------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index a30264080ba0..695fee7d9956 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -56,9 +56,18 @@ def build_model_config( state_dict: dict[str, torch.Tensor], k_cache_tensors: list[torch.Tensor], v_cache_tensors: list[torch.Tensor], - 
position_embeddings: torch.Tensor, + position_embeddings_: torch.Tensor, parallel_config: ParallelConfig, ) -> MirageModelConfig: + whole_dim = position_embeddings_.shape[-1] + cos_tensor_ = position_embeddings_[:, 0:whole_dim//2].unsqueeze(0) + sin_tensor_ = position_embeddings_[:, whole_dim//2:].unsqueeze(0) + + cos_tensor = torch.cat([cos_tensor_, cos_tensor_], dim=-1) + sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1) + + position_embeddings = (cos_tensor, sin_tensor) + logger.info(f"[Mirage] position_embeddings: {position_embeddings[0].shape}, {position_embeddings[1].shape}") mirage_model_config = MirageModelConfig( # model architecture hidden_size=model_config.get_hidden_size(), @@ -75,6 +84,7 @@ def build_model_config( position_embeddings=position_embeddings, # model weights state_dict=state_dict, + with_lm_head=False, ) return mirage_model_config @@ -88,9 +98,9 @@ def build_mpk_metadata( scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - attn_metadata = forward_context.attn_metadata - logger.info(f"[Mirage] Forward context: {forward_context}, attn_metadata: {attn_metadata}") - + # For now we assume only one attention group + attn_metadata = list(forward_context.attn_metadata.values())[0] + static_forward_context = forward_context.no_compile_layers # layer names to layers k_cache_tensors = [] v_cache_tensors = [] @@ -275,7 +285,7 @@ def compile_or_call(*args): logger.info(f"[Mirage] MPK metadata: {mpk_metadata.info_as_string()}") self.mpk = MPK(mpk_metadata) self.mpk.build() - self.mpk.compile() + self.mpk.compile(output_dir=os.path.join(os.path.dirname(__file__), "mirage_backend_output")) self.compiled = True diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index e3bcfab952dd..ec1f9db3fcbe 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -108,9 +108,6 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: class MirageAttentionMetadata: num_actual_tokens: int # Number of tokens excluding padding. 
- # The data type of the query - q_data_type: torch.dtype - # For handling prefill decode split num_decodes: int num_decode_tokens: int @@ -250,17 +247,14 @@ def build( ) # uses_spec_reorder = self.reorder_batch_threshold > 1 - - assert self.q_data_type == torch.bfloat16, "MirageAttentionBackend currently only supports bfloat16" attn_metadata = MirageAttentionMetadata( num_actual_tokens=num_actual_tokens, - q_data_type=self.q_data_type, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, - qo_indptr_gpu=common_attn_metadata.query_start_loc_gpu, + qo_indptr_gpu=common_attn_metadata.query_start_loc, paged_kv_indptr_gpu=self.paged_kv_indptr, paged_kv_indices_gpu=self.paged_kv_indices, paged_kv_last_page_len_gpu=self.paged_kv_last_page_len, From 84cdc9fe5a8ed6738597c2aad107ab59215f32b0 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Mon, 3 Nov 2025 04:17:15 +0000 Subject: [PATCH 07/12] compatible with mpk --- vllm/compilation/mirage_backend.py | 62 ++++++++++++++++++++++++---- vllm/v1/attention/backends/mirage.py | 6 +-- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index 695fee7d9956..feb545ac492a 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -11,8 +11,9 @@ logger = init_logger(__name__) +# TODO(Jianan Ji): Is this name mapping common for all models? def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: - """Transfer FX placeholder debug names to model-like dotted names. + """Transfer FX placeholder debug names to model-like dotted names. Return a list of transferred names and input id. Example: l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_ @@ -24,13 +25,15 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: Instead, we annotate via node.meta['logical_name'] and return the list. 
""" converted_names = [] - s_pattern = re.compile(r"^s\d+$") + s_pattern = re.compile(r"^s\d+$") # s72 / s80 + input_id = 0 - for node in placeholders: + for i, node in enumerate(placeholders): name = node.name if name == 'l_input_ids_': final_name = 'input_ids' converted_names.append(final_name) + input_id = i elif name == 'l_positions_': final_name = 'positions' converted_names.append(final_name) @@ -49,7 +52,7 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: converted_names.append(final_name) - return converted_names + return converted_names, input_id def build_model_config( model_config: ModelConfig, @@ -67,7 +70,6 @@ def build_model_config( sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1) position_embeddings = (cos_tensor, sin_tensor) - logger.info(f"[Mirage] position_embeddings: {position_embeddings[0].shape}, {position_embeddings[1].shape}") mirage_model_config = MirageModelConfig( # model architecture hidden_size=model_config.get_hidden_size(), @@ -129,6 +131,43 @@ def build_mpk_metadata( positions_tensor = arg elif "cos_sin_cache" in name: position_embeddings = arg + elif "qkv" in name: + # Split qkv since we need to shuffle them on mirage side later + # (6144, 4096) -> (4096, 4096), (1024, 4096), (1024, 4096) + qkv_tensor = arg + + total_dim = qkv_tensor.shape[0] + n_q_heads = model_config.get_num_attention_heads(parallel_config) # 32 + n_kv_heads = model_config.get_num_kv_heads(parallel_config) # 8 + n_heads = n_q_heads + n_kv_heads * 2 + + q_range = (total_dim * n_q_heads) // n_heads # 6144 * 32 / 48 = 4096 + k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads # 6144 * 40 / 48 = 5120 + + q_tensor = qkv_tensor[:q_range, :] + k_tensor = qkv_tensor[q_range:k_range, :] + v_tensor = qkv_tensor[k_range:, :] + + # substitute qkv to q/k/v views + state_dict[name.replace("qkv", "q")] = q_tensor + state_dict[name.replace("qkv", "k")] = k_tensor + state_dict[name.replace("qkv", "v")] = v_tensor + + state_dict[name] = qkv_tensor + elif "gate_up" in name: + # Split gate_up to gate and up + gate_up_tensor = arg + total_dim = gate_up_tensor.shape[0] + single_dim = total_dim // 2 + + gate_tensor = gate_up_tensor[:single_dim, :] + up_tensor = gate_up_tensor[single_dim:, :] + + # substitude gate_up to gate and up + state_dict[name.replace("gate_up", "gate")] = gate_tensor + state_dict[name.replace("gate_up", "up")] = up_tensor + + state_dict[name] = gate_up_tensor else: state_dict[name] = arg @@ -141,7 +180,7 @@ def build_mpk_metadata( parallel_config, ) mpk_metadata = MPKMetadata( - mode = "online", + mode = "online_notoken", # total_num_requests # num_remote_schedulers: int = 0 max_seq_length = model_config.max_model_len, @@ -257,7 +296,9 @@ def __call__( placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder'] assert len(placeholders) == len(example_inputs) - transfered_tensor_names = transfer_tensor_names(placeholders) + transfered_tensor_names, input_id = transfer_tensor_names(placeholders) + + max_input_tokens = example_inputs[input_id].shape[0] self._called = True @@ -269,7 +310,8 @@ def compile_or_call(*args): model_config = self.vllm_config.model_config dtype = model_config.dtype hidden_size = model_config.get_hidden_size() - output_tensor = torch.zeros(1, hidden_size, device='cuda', dtype=dtype) + # TODO(Jianan Ji): We'll want to run in eager instead of doing nothing + output_tensor = torch.zeros(max_input_tokens, hidden_size, device='cuda', dtype=dtype) logger.info(f"[Mirage] Calling dumb_run_called, returning 
dummy output tensor with shape [{output_tensor.shape}]......!") return (output_tensor,) @@ -290,7 +332,9 @@ def compile_or_call(*args): self.compiled = True logger.info(f"[Mirage] Calling the compiled result...") - return self.mpk() + result_hidden_states = self.mpk() + + return (result_hidden_states,) # return VllmSerializableFunction( # graph, example_inputs, self.prefix, compile_or_call diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index ec1f9db3fcbe..64e7068808e9 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -255,9 +255,9 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, qo_indptr_gpu=common_attn_metadata.query_start_loc, - paged_kv_indptr_gpu=self.paged_kv_indptr, - paged_kv_indices_gpu=self.paged_kv_indices, - paged_kv_last_page_len_gpu=self.paged_kv_last_page_len, + paged_kv_indptr_gpu=paged_kv_indptr, + paged_kv_indices_gpu=paged_kv_indices, + paged_kv_last_page_len_gpu=paged_kv_last_page_len, ) return attn_metadata From 830482cd38e722771a0a7f147a6e983cb3910b74 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Tue, 4 Nov 2025 20:19:37 +0000 Subject: [PATCH 08/12] fix compatibility --- vllm/compilation/mirage_backend.py | 1 - vllm/v1/attention/backends/mirage.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index feb545ac492a..a3503b71d0e5 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -331,7 +331,6 @@ def compile_or_call(*args): self.compiled = True - logger.info(f"[Mirage] Calling the compiled result...") result_hidden_states = self.mpk() return (result_hidden_states,) diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index 64e7068808e9..3d8d321ed50e 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -22,7 +22,8 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, is_pin_memory_available +from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, From 062647c6ffca6fdeab910878f55f6bcc2c50b4ed Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Tue, 11 Nov 2025 05:04:55 +0000 Subject: [PATCH 09/12] code clean --- vllm/compilation/mirage_backend.py | 76 +++++++++++----------------- vllm/config/compilation.py | 2 +- vllm/config/vllm.py | 15 +++++- vllm/v1/attention/backends/mirage.py | 24 +-------- 4 files changed, 45 insertions(+), 72 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index a3503b71d0e5..505340c3d249 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -1,14 +1,21 @@ +import os from collections import defaultdict -from .backends import * +import time from mirage import MPK, MPKMetadata, MirageModelConfig import re -from vllm.config import ModelConfig, get_current_vllm_config +from typing import Any + +import torch +import torch.fx as fx + +from vllm.config import CompilationConfig, ModelConfig, VllmConfig, get_current_vllm_config from vllm.config.parallel import ParallelConfig -from vllm.forward_context import ForwardContext, get_forward_context +from vllm.forward_context import get_forward_context from vllm.model_executor.models.utils import 
extract_layer_index -import torch from vllm.logger import init_logger +from .counter import compilation_counter + logger = init_logger(__name__) # TODO(Jianan Ji): Is this name mapping common for all models? @@ -16,7 +23,8 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: """Transfer FX placeholder debug names to model-like dotted names. Return a list of transferred names and input id. Example: - l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_ + l_self_modules_layers_modules_17_modules_mlp_\ + modules_gate_up_proj_parameters_weight_ -> model.layers.17.mlp.gate_up_proj.weight Notes: @@ -26,14 +34,12 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: """ converted_names = [] s_pattern = re.compile(r"^s\d+$") # s72 / s80 - input_id = 0 - for i, node in enumerate(placeholders): + for node in placeholders: name = node.name if name == 'l_input_ids_': final_name = 'input_ids' converted_names.append(final_name) - input_id = i elif name == 'l_positions_': final_name = 'positions' converted_names.append(final_name) @@ -52,7 +58,7 @@ def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]: converted_names.append(final_name) - return converted_names, input_id + return converted_names def build_model_config( model_config: ModelConfig, @@ -111,14 +117,12 @@ def build_mpk_metadata( for layer_name in static_forward_context.keys(): index2name[extract_layer_index(layer_name, 1)].append(layer_name) - for layer_index in sorted(index2name.keys()): + for layer_index in sorted(index2name): layer_names = index2name[layer_index] assert len(layer_names) == 1, "Multiple layers with the same layer index are not supported" layer_name = layer_names[0] - # logger.info(f"{layer_index} {layer_name}: attention num: {len(static_forward_context[layer_name].kv_cache)}; kv_cache.shape: {static_forward_context[layer_name].kv_cache[0].shape}") k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0]) v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1]) - # kv_cache_tensors shape: num_layers * (2, num_blocks, block_size, num_kv_heads, head_size) state_dict = {} input_token_ids = None @@ -219,16 +223,7 @@ def build_mpk_metadata( return mpk_metadata class MirageBackend: - """The compilation backend for `torch.compile` with vLLM. - It is used for compilation level of `CompilationLevel.PIECEWISE`, - where we customize the compilation. - - The major work of this backend is to split the graph into - piecewise graphs, and pass them to the piecewise backend. - - This backend also adds the PostGradPassManager to Inductor config, - which handles the post-grad passes. - """ + """The compilation backend for Mirage Persistent Kernel.""" vllm_config: VllmConfig compilation_config: CompilationConfig @@ -236,33 +231,18 @@ class MirageBackend: # the graph we compiled graph: fx.GraphModule - input_buffers: list[torch.Tensor] - def __init__( self, vllm_config: VllmConfig, prefix: str = "", ): - # if the model is initialized with a non-empty prefix, - # then usually it's enough to use that prefix, - # e.g. language_model, vision_model, etc. - # when multiple parts are initialized as independent - # models, we need to use the model_tag to distinguish - # them, e.g. backbone (default), eagle_head, etc. - logger.info("[Mirage] Calling MirageBackend init!") - self.prefix = prefix or model_tag - - # Passes to run on the graph post-grad. 
- self.post_grad_pass_manager = PostGradPassManager() - - self.input_buffers = [] + logger.debug("[Mirage] Calling MirageBackend init!") self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.model_config = vllm_config.model_config self.model_name = vllm_config.model_config.model - def __call__( self, graph: fx.GraphModule, example_inputs ) -> Any: @@ -296,9 +276,15 @@ def __call__( placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder'] assert len(placeholders) == len(example_inputs) - transfered_tensor_names, input_id = transfer_tensor_names(placeholders) + transfered_tensor_names = transfer_tensor_names(placeholders) - max_input_tokens = example_inputs[input_id].shape[0] + max_input_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + + # TODO(Jianan Ji): remove this after debugging + # with open('mirage_backends_graph.txt', 'w') as f: + # f.write(graph.print_readable(print_output=False)) + # with open("graph_structure.txt", "w", encoding="utf-8") as f: + # f.write(str(graph.graph)) self._called = True @@ -310,11 +296,12 @@ def compile_or_call(*args): model_config = self.vllm_config.model_config dtype = model_config.dtype hidden_size = model_config.get_hidden_size() - # TODO(Jianan Ji): We'll want to run in eager instead of doing nothing - output_tensor = torch.zeros(max_input_tokens, hidden_size, device='cuda', dtype=dtype) - logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output tensor with shape [{output_tensor.shape}]......!") + # # TODO(Jianan Ji): We'll want to run graph(*args) instead of doing nothing + output_tensor = torch.zeros(2, hidden_size, device='cuda', dtype=dtype) + # logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output tensor with shape [{output_tensor.shape}]......!") return (output_tensor,) + # return graph(*args) if not self.compiled: # Compile only at the first call -- when we get real tensors @@ -335,7 +322,4 @@ def compile_or_call(*args): return (result_hidden_states,) - # return VllmSerializableFunction( - # graph, example_inputs, self.prefix, compile_or_call - # ) return compile_or_call \ No newline at end of file diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ce4d8777d999..1e6e97c15cb9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -811,7 +811,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable: ]: if self.backend in torch_backends: return self.backend - if self.backend == "mirage": + if self.backend == "mirage_byname": from vllm.compilation.mirage_backend import MirageBackend return MirageBackend(vllm_config) return resolve_obj_by_qualname(self.backend) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 34e70e3e134b..18267d584536 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -690,8 +690,12 @@ def has_blocked_weights(): ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - # disable cudagraph when enforce eager execution - if self.model_config is not None and self.model_config.enforce_eager: + # disable cudagraph when enforce eager execution or mirage backend is used + disable_cuda_graph = ( + (self.model_config is not None and self.model_config.enforce_eager) + or (self.compilation_config.backend == "mirage_byname") + ) + if disable_cuda_graph: logger.info("Cudagraph is disabled under eager mode") self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # override related settings when enforce eager @@ -703,6 +707,13 @@ def 
has_blocked_weights(): self._set_cudagraph_sizes() else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if self.compilation_config.backend == "mirage_byname": + if envs.VLLM_ATTENTION_BACKEND is None: + envs.VLLM_ATTENTION_BACKEND = "MIRAGE" + elif envs.VLLM_ATTENTION_BACKEND != "MIRAGE": + raise ValueError(f"Have to use MIRAGE attention backend when using mirage backend. Now it is {envs.VLLM_ATTENTION_BACKEND}") + assert self.cache_config.block_size % 64 == 0, "Block size must be a multiple of 64 for mirage backend." if self.cache_config.kv_sharing_fast_prefill: if ( diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index 3d8d321ed50e..5c444bd946d8 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -29,7 +29,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -68,7 +67,7 @@ def validate_head_size(cls, head_size: int) -> None: @staticmethod def get_name() -> str: - return "MPK_ATTENTION" + return "MIRAGE" @staticmethod def get_impl_cls() -> type["MirageAttentionImpl"]: @@ -107,14 +106,6 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: @dataclass class MirageAttentionMetadata: - num_actual_tokens: int # Number of tokens excluding padding. - - # For handling prefill decode split - num_decodes: int - num_decode_tokens: int - num_prefills: int - num_prefill_tokens: int - # Meta tensors qo_indptr_gpu: torch.Tensor | None = None paged_kv_indptr_gpu: torch.Tensor | None = None @@ -196,14 +187,6 @@ def build( fast_build: bool = False, ) -> MirageAttentionMetadata: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold, - require_uniform=True, - ) - ) page_size = self.page_size seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -250,11 +233,6 @@ def build( # uses_spec_reorder = self.reorder_batch_threshold > 1 attn_metadata = MirageAttentionMetadata( - num_actual_tokens=num_actual_tokens, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, qo_indptr_gpu=common_attn_metadata.query_start_loc, paged_kv_indptr_gpu=paged_kv_indptr, paged_kv_indices_gpu=paged_kv_indices, From ed4621172f79ed52702d238ba701b3b9edf41d5e Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Sun, 30 Nov 2025 19:28:55 +0000 Subject: [PATCH 10/12] add stream and graph dumb run --- vllm/compilation/mirage_backend.py | 13 +++---------- vllm/v1/attention/backends/mirage.py | 4 ++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py index 505340c3d249..1bf2ec9b33dd 100644 --- a/vllm/compilation/mirage_backend.py +++ b/vllm/compilation/mirage_backend.py @@ -293,15 +293,7 @@ def __call__( def compile_or_call(*args): dumb_run_called = (get_forward_context().attn_metadata is None) if dumb_run_called: - model_config = self.vllm_config.model_config - dtype = model_config.dtype - hidden_size = model_config.get_hidden_size() - # # TODO(Jianan Ji): We'll want to run graph(*args) instead of doing nothing - output_tensor = torch.zeros(2, hidden_size, device='cuda', dtype=dtype) - # logger.info(f"[Mirage] Calling dumb_run_called, returning dummy output 
tensor with shape [{output_tensor.shape}]......!") - - return (output_tensor,) - # return graph(*args) + return graph(*args) if not self.compiled: # Compile only at the first call -- when we get real tensors @@ -318,7 +310,8 @@ def compile_or_call(*args): self.compiled = True - result_hidden_states = self.mpk() + default_stream = torch.cuda.current_stream() + result_hidden_states = self.mpk(default_stream = default_stream) return (result_hidden_states,) diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py index 5c444bd946d8..afecbff9d2f4 100755 --- a/vllm/v1/attention/backends/mirage.py +++ b/vllm/v1/attention/backends/mirage.py @@ -300,8 +300,8 @@ def forward( if attn_metadata is None: # Profiling run. return output.fill_(0) - - return output + else: + raise NotImplementedError("MirageAttentionImpl is never meant to be used directly.") @triton.jit From 935674200b1d0af1c97343a9754456759def98a1 Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Sun, 30 Nov 2025 19:32:27 +0000 Subject: [PATCH 11/12] use cudagraph for mpk --- vllm/config/vllm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 18267d584536..a7b425d54a2b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -690,10 +690,9 @@ def has_blocked_weights(): ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - # disable cudagraph when enforce eager execution or mirage backend is used + # disable cudagraph when enforce eager execution disable_cuda_graph = ( (self.model_config is not None and self.model_config.enforce_eager) - or (self.compilation_config.backend == "mirage_byname") ) if disable_cuda_graph: logger.info("Cudagraph is disabled under eager mode") From 065966bd2b1d9839833295dc8601ca96fa47b2ed Mon Sep 17 00:00:00 2001 From: Jianan Ji Date: Mon, 1 Dec 2025 04:02:30 +0000 Subject: [PATCH 12/12] fix registry name --- vllm/attention/backends/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 87212c30d793..0849b15bed7b 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -76,7 +76,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "RocmAiterUnifiedAttentionBackend" ) CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend" - MIRAGE = "vllm.v1.attention.backends.mirage.MirageBackend" + MIRAGE = "vllm.v1.attention.backends.mirage.MirageAttentionBackend" # Placeholder for third-party/custom backends - must be registered before use CUSTOM = ""
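
A minimal end-to-end usage sketch of the integration this series adds. The names "mirage_byname", "MIRAGE", MirageBackend, and the block_size % 64 == 0 requirement come from the diffs above; the LLM/CompilationConfig entrypoints, the CompilationMode.DYNAMO_TRACE_ONCE member name, and the Qwen3-8B model choice are assumptions about the upstream vLLM tree this series targets, not something the patches pin down. Note that after PATCH 04 the custom backend is only resolved in the non-piecewise branch of init_backend(), and __post_init__ still rejects custom backends for CompilationMode.VLLM_COMPILE, so a plain Dynamo compilation mode is used below.

import os

# vllm/config/vllm.py auto-selects the MIRAGE attention backend when the
# mirage compilation backend is chosen; setting it explicitly matches that check.
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "MIRAGE")

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.config.compilation import CompilationMode

compilation_config = CompilationConfig(
    # Custom backends are rejected for CompilationMode.VLLM_COMPILE in
    # __post_init__, so pick a plain Dynamo mode (member name is an assumption
    # about this tree's CompilationMode enum).
    mode=CompilationMode.DYNAMO_TRACE_ONCE,
    # Resolved by init_backend() to vllm.compilation.mirage_backend.MirageBackend.
    backend="mirage_byname",
)

llm = LLM(
    model="Qwen/Qwen3-8B",      # example only; the FX name mapping targets Qwen3-style graphs
    block_size=64,              # vllm/config/vllm.py asserts block_size % 64 == 0 for this backend
    compilation_config=compilation_config,
)

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)

The first real forward pass (one with attention metadata in the forward context) triggers MPK build/compile inside compile_or_call(); profiling/dummy runs fall back to the captured FX graph as of PATCH 10.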