diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 125e4e382774..0849b15bed7b 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -76,6 +76,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
         "RocmAiterUnifiedAttentionBackend"
     )
     CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
+    MIRAGE = "vllm.v1.attention.backends.mirage.MirageAttentionBackend"
     # Placeholder for third-party/custom backends - must be registered before use
     CUSTOM = ""
diff --git a/vllm/compilation/mirage_backend.py b/vllm/compilation/mirage_backend.py
new file mode 100644
index 000000000000..1bf2ec9b33dd
--- /dev/null
+++ b/vllm/compilation/mirage_backend.py
@@ -0,0 +1,318 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import re
+import time
+from collections import defaultdict
+from typing import Any
+
+import torch
+import torch.fx as fx
+from mirage import MPK, MPKMetadata, MirageModelConfig
+
+from vllm.config import CompilationConfig, ModelConfig, VllmConfig, get_current_vllm_config
+from vllm.config.parallel import ParallelConfig
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.models.utils import extract_layer_index
+
+from .counter import compilation_counter
+
+logger = init_logger(__name__)
+
+
+# TODO(Jianan Ji): Is this name mapping common for all models?
+def transfer_tensor_names(placeholders: list[torch.fx.node.Node]) -> list[str]:
+    """Convert FX placeholder debug names into model-style dotted names.
+
+    Returns the converted names (including input ids, positions and symbolic
+    shape placeholders) in the same order as the placeholders.
+
+    Example:
+        l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_
+        -> model.layers.17.mlp.gate_up_proj.weight
+
+    Notes:
+        - Tailored for Qwen3-style module names seen in exported FX graphs.
+        - The FX node identifiers themselves are NOT renamed (dots are not
+          valid in FX names); only the list of converted names is returned.
+    """
+    converted_names = []
+    s_pattern = re.compile(r"^s\d+$")  # symbolic shape placeholders, e.g. s72 / s80
+
+    for node in placeholders:
+        name = node.name
+        if name == 'l_input_ids_':
+            final_name = 'input_ids'
+            converted_names.append(final_name)
+        elif name == 'l_positions_':
+            final_name = 'positions'
+            converted_names.append(final_name)
+        elif s_pattern.match(name):  # s72 / s80
+            converted_names.append(name)
+        else:
+            if name.startswith('l_self_modules_'):
+                name = name.replace('l_self_modules_', '', 1)
+            if name.endswith('_'):
+                name = name[:-1]
+
+            name = name.replace('_modules_', '.')
+            name = name.replace('_parameters_', '.')
+
+            final_name = 'model.' + name
+
+            converted_names.append(final_name)
+
+    return converted_names
+
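Editor's note (not part of the patch): the name mapping above can be sanity-checked in isolation. A minimal sketch, assuming a Qwen3-style placeholder name:

import re

def _convert(name: str) -> str:
    # Mirrors the fallback branch of transfer_tensor_names.
    if name.startswith('l_self_modules_'):
        name = name.replace('l_self_modules_', '', 1)
    if name.endswith('_'):
        name = name[:-1]
    name = name.replace('_modules_', '.').replace('_parameters_', '.')
    return 'model.' + name

assert _convert(
    'l_self_modules_layers_modules_17_modules_mlp_modules_gate_up_proj_parameters_weight_'
) == 'model.layers.17.mlp.gate_up_proj.weight'
assert re.match(r"^s\d+$", 's72')  # symbolic-shape placeholders pass through unchanged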
+def build_model_config(
+    model_config: ModelConfig,
+    state_dict: dict[str, torch.Tensor],
+    k_cache_tensors: list[torch.Tensor],
+    v_cache_tensors: list[torch.Tensor],
+    position_embeddings_: torch.Tensor,
+    parallel_config: ParallelConfig,
+) -> MirageModelConfig:
+    whole_dim = position_embeddings_.shape[-1]
+    cos_tensor_ = position_embeddings_[:, 0:whole_dim // 2].unsqueeze(0)
+    sin_tensor_ = position_embeddings_[:, whole_dim // 2:].unsqueeze(0)
+
+    cos_tensor = torch.cat([cos_tensor_, cos_tensor_], dim=-1)
+    sin_tensor = torch.cat([sin_tensor_, sin_tensor_], dim=-1)
+
+    position_embeddings = (cos_tensor, sin_tensor)
+    mirage_model_config = MirageModelConfig(
+        # model architecture
+        hidden_size=model_config.get_hidden_size(),
+        intermediate_size=getattr(model_config.hf_text_config, "intermediate_size", 0),
+        vocab_size=model_config.get_vocab_size(),
+        local_num_q_heads=model_config.get_num_attention_heads(parallel_config),
+        local_num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+        head_dim=model_config.get_head_size(),
+        num_layers=getattr(model_config.hf_text_config, "num_hidden_layers", 0),
+        # kv cache
+        k_cache=k_cache_tensors,
+        v_cache=v_cache_tensors,
+        # position embeddings
+        position_embeddings=position_embeddings,
+        # model weights
+        state_dict=state_dict,
+        with_lm_head=False,
+    )
+    return mirage_model_config
+
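Editor's note (illustration only): build_model_config assumes the captured cos_sin_cache stores half-width cos and sin halves side by side and widens them back to head_dim by duplication. A quick shape check with made-up sizes:

import torch

max_pos, head_dim = 4096, 128
cos_sin_cache = torch.randn(max_pos, head_dim)  # [cos | sin], each head_dim // 2 wide

half = head_dim // 2
cos = cos_sin_cache[:, :half].unsqueeze(0)   # (1, max_pos, 64)
sin = cos_sin_cache[:, half:].unsqueeze(0)   # (1, max_pos, 64)
cos_full = torch.cat([cos, cos], dim=-1)     # (1, max_pos, 128)
sin_full = torch.cat([sin, sin], dim=-1)

assert cos_full.shape == sin_full.shape == (1, max_pos, head_dim)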
+def build_mpk_metadata(
+    vllm_config: VllmConfig,
+    args: list[Any],
+    transferred_tensor_names: list[str],
+) -> MPKMetadata:
+    forward_context = get_forward_context()
+    model_config = vllm_config.model_config
+    scheduler_config = vllm_config.scheduler_config
+    cache_config = vllm_config.cache_config
+    parallel_config = vllm_config.parallel_config
+    # For now we assume only one attention group
+    attn_metadata = list(forward_context.attn_metadata.values())[0]
+
+    static_forward_context = forward_context.no_compile_layers  # layer names to layers
+    k_cache_tensors = []
+    v_cache_tensors = []
+    # Convert the kv_caches dict into per-layer lists ordered by layer index.
+    index2name = defaultdict(list)
+    for layer_name in static_forward_context.keys():
+        index2name[extract_layer_index(layer_name, 1)].append(layer_name)
+
+    for layer_index in sorted(index2name):
+        layer_names = index2name[layer_index]
+        assert len(layer_names) == 1, (
+            "Multiple layers with the same layer index are not supported")
+        layer_name = layer_names[0]
+        k_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][0])
+        v_cache_tensors.append(static_forward_context[layer_name].kv_cache[0][1])
+
+    state_dict = {}
+    input_token_ids = None
+    positions_tensor = None
+    position_embeddings = None
+    for arg, name in zip(args, transferred_tensor_names):
+        if name == 'input_ids':
+            input_token_ids = arg
+        elif name == 'positions':
+            positions_tensor = arg
+        elif "cos_sin_cache" in name:
+            position_embeddings = arg
+        elif "qkv" in name:
+            # Split qkv since we need to shuffle them on the Mirage side later
+            # e.g. (6144, 4096) -> (4096, 4096), (1024, 4096), (1024, 4096)
+            qkv_tensor = arg
+
+            total_dim = qkv_tensor.shape[0]
+            n_q_heads = model_config.get_num_attention_heads(parallel_config)  # e.g. 32
+            n_kv_heads = model_config.get_num_kv_heads(parallel_config)  # e.g. 8
+            n_heads = n_q_heads + n_kv_heads * 2
+
+            q_range = (total_dim * n_q_heads) // n_heads  # 6144 * 32 / 48 = 4096
+            k_range = (total_dim * (n_q_heads + n_kv_heads)) // n_heads  # 6144 * 40 / 48 = 5120
+
+            q_tensor = qkv_tensor[:q_range, :]
+            k_tensor = qkv_tensor[q_range:k_range, :]
+            v_tensor = qkv_tensor[k_range:, :]
+
+            # substitute qkv with separate q/k/v views
+            state_dict[name.replace("qkv", "q")] = q_tensor
+            state_dict[name.replace("qkv", "k")] = k_tensor
+            state_dict[name.replace("qkv", "v")] = v_tensor
+
+            state_dict[name] = qkv_tensor
+        elif "gate_up" in name:
+            # Split gate_up into gate and up
+            gate_up_tensor = arg
+            total_dim = gate_up_tensor.shape[0]
+            single_dim = total_dim // 2
+
+            gate_tensor = gate_up_tensor[:single_dim, :]
+            up_tensor = gate_up_tensor[single_dim:, :]
+
+            # substitute gate_up with separate gate and up views
+            state_dict[name.replace("gate_up", "gate")] = gate_tensor
+            state_dict[name.replace("gate_up", "up")] = up_tensor
+
+            state_dict[name] = gate_up_tensor
+        else:
+            state_dict[name] = arg
+
+    mirage_model_config = build_model_config(
+        model_config,
+        state_dict,
+        k_cache_tensors,
+        v_cache_tensors,
+        position_embeddings,
+        parallel_config,
+    )
+    mpk_metadata = MPKMetadata(
+        mode="online_notoken",
+        # total_num_requests
+        # num_remote_schedulers: int = 0
+        max_seq_length=model_config.max_model_len,
+        max_num_batched_requests=scheduler_config.max_num_seqs,
+        max_num_batched_tokens=scheduler_config.max_num_batched_tokens,
+        max_num_pages=cache_config.num_gpu_blocks,
+        page_size=cache_config.block_size,
+        # max_sm_num: int = 108
+        device="cuda",
+        # model
+        weight_from_model=False,
+        model_name=model_config.model,
+        # model_path: Optional[str] = None
+        # multi-device support
+        world_size=parallel_config.world_size,
+        rank=parallel_config.rank,
+        # meta tensors
+        step=positions_tensor,
+        # tokens: Optional[torch.Tensor] = None
+        input_tokens=input_token_ids,
+        # output_tokens: Optional[torch.Tensor] = None
+        # num_new_tokens: Optional[torch.Tensor] = None
+        # prompt_lengths: Optional[torch.Tensor] = None
+        qo_indptr_buffer=attn_metadata.qo_indptr_gpu,
+        paged_kv_indptr_buffer=attn_metadata.paged_kv_indptr_gpu,
+        paged_kv_indices_buffer=attn_metadata.paged_kv_indices_gpu,
+        paged_kv_last_page_len_buffer=attn_metadata.paged_kv_last_page_len_gpu,
+        # kv cache tensors, weights and model config
+        model_config=mirage_model_config,
+        # profiling
+        # profiler_tensor: Optional[torch.Tensor] = None
+        # trace_name: Optional[str] = None
+        # spec decode config
+        # spec_decode_config: Optional[object] = None
+    )
+    return mpk_metadata
+
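Editor's note (illustrative sketch, numbers taken from the comments above): the fused-QKV split depends only on the head counts, so it can be checked with dummy tensors:

import torch

n_q_heads, n_kv_heads, head_dim, hidden = 32, 8, 128, 4096
qkv_weight = torch.empty(head_dim * (n_q_heads + 2 * n_kv_heads), hidden)  # (6144, 4096)

n_heads = n_q_heads + n_kv_heads * 2
total_dim = qkv_weight.shape[0]
q_range = total_dim * n_q_heads // n_heads                  # 4096
k_range = total_dim * (n_q_heads + n_kv_heads) // n_heads   # 5120

q_w, k_w, v_w = qkv_weight[:q_range], qkv_weight[q_range:k_range], qkv_weight[k_range:]
assert (q_w.shape[0], k_w.shape[0], v_w.shape[0]) == (4096, 1024, 1024)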
+class MirageBackend:
+    """The compilation backend for the Mirage Persistent Kernel (MPK)."""
+
+    vllm_config: VllmConfig
+    compilation_config: CompilationConfig
+    _called: bool = False
+    # the graph we compiled
+    graph: fx.GraphModule
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ):
+        logger.debug("[Mirage] Calling MirageBackend init!")
+
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.model_config = vllm_config.model_config
+        self.model_name = vllm_config.model_config.model
+
+    def __call__(
+        self, graph: fx.GraphModule, example_inputs
+    ) -> Any:
+        # when dynamo calls the backend, it means the bytecode
+        # transform and analysis are done
+        compilation_counter.num_graphs_seen += 1
+        from .monitor import torch_compile_start_time
+
+        # TODO: remove this after debugging
+        # try:
+        #     src = graph.print_readable(print_output=False)
+        # except Exception:
+        #     src = str(graph)
+        # try:
+        #     with open('mirage_backends_graph.txt', 'w') as f:
+        #         logger.info('Writing readable FX graph to mirage_backends_graph.txt')
+        #         f.write(src)
+        #     logger.info('Readable FX graph written to mirage_backends_graph.txt')
+        # except Exception:
+        #     logger.exception('Failed to write mirage_backends_graph.txt')
+
+        dynamo_time = time.time() - torch_compile_start_time
+        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
+        self.compilation_config.compilation_time += dynamo_time
+
+        # we control the compilation process, each instance can only be
+        # called once
+        assert not self._called, "MirageBackend can only be called once"
+
+        placeholders = [node for node in graph.graph.nodes if node.op == 'placeholder']
+        assert len(placeholders) == len(example_inputs)
+
+        transferred_tensor_names = transfer_tensor_names(placeholders)
+
+        max_input_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
+
+        # TODO(Jianan Ji): remove this after debugging
+        # with open('mirage_backends_graph.txt', 'w') as f:
+        #     f.write(graph.print_readable(print_output=False))
+        # with open("graph_structure.txt", "w", encoding="utf-8") as f:
+        #     f.write(str(graph.graph))
+
+        self._called = True
+        self.compiled = False
+
+        def compile_or_call(*args):
+            dummy_run_called = get_forward_context().attn_metadata is None
+            if dummy_run_called:
+                return graph(*args)
+
+            if not self.compiled:
+                # Compile only at the first call -- when we get real tensors
+                logger.info("[Mirage] Calling compile_or_call for the first time, compiling...")
+                mpk_metadata = build_mpk_metadata(
+                    self.vllm_config,
+                    args,
+                    transferred_tensor_names,
+                )
+                logger.info(f"[Mirage] MPK metadata: {mpk_metadata.info_as_string()}")
+                self.mpk = MPK(mpk_metadata)
+                self.mpk.build()
+                self.mpk.compile(
+                    output_dir=os.path.join(os.path.dirname(__file__), "mirage_backend_output"))
+
+                self.compiled = True
+
+            default_stream = torch.cuda.current_stream()
+            result_hidden_states = self.mpk(default_stream=default_stream)
+
+            return (result_hidden_states,)
+
+        return compile_or_call
\ No newline at end of file
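Editor's note: compile_or_call implements a deferred-compilation pattern: dummy/profiling runs fall back to the captured FX graph, and the MPK engine is built once from the first batch of real tensors. A stripped-down, framework-free sketch of that pattern (the names here are illustrative, not vLLM or Mirage APIs):

from typing import Any, Callable

def deferred_backend(eager_fn: Callable[..., Any],
                     build_engine: Callable[..., Callable[..., Any]],
                     is_dummy_run: Callable[[], bool]) -> Callable[..., Any]:
    engine: Callable[..., Any] | None = None

    def run(*args: Any) -> Any:
        nonlocal engine
        if is_dummy_run():
            # Profiling / warm-up calls never reach the compiled engine.
            return eager_fn(*args)
        if engine is None:
            # Compile exactly once, using the first real inputs.
            engine = build_engine(*args)
        return engine(*args)

    return run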
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index da2c100dae3d..1e6e97c15cb9 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -811,6 +811,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
             ]:
                 if self.backend in torch_backends:
                     return self.backend
+                if self.backend == "mirage_byname":
+                    from vllm.compilation.mirage_backend import MirageBackend
+                    return MirageBackend(vllm_config)
                 return resolve_obj_by_qualname(self.backend)

         assert self.mode == CompilationMode.VLLM_COMPILE
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 34e70e3e134b..a7b425d54a2b 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -691,7 +691,10 @@ def has_blocked_weights():
             self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

         # disable cudagraph when enforce eager execution
-        if self.model_config is not None and self.model_config.enforce_eager:
+        disable_cuda_graph = (
+            self.model_config is not None and self.model_config.enforce_eager
+        )
+        if disable_cuda_graph:
             logger.info("Cudagraph is disabled under eager mode")
             self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
             # override related settings when enforce eager
@@ -703,6 +706,13 @@ def has_blocked_weights():
             self._set_cudagraph_sizes()
         else:
             self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
+        if self.compilation_config.backend == "mirage_byname":
+            if envs.VLLM_ATTENTION_BACKEND is None:
+                envs.VLLM_ATTENTION_BACKEND = "MIRAGE"
+            elif envs.VLLM_ATTENTION_BACKEND != "MIRAGE":
+                raise ValueError(f"The MIRAGE attention backend is required when the mirage_byname compilation backend is used, but got {envs.VLLM_ATTENTION_BACKEND}.")
+            assert self.cache_config.block_size % 64 == 0, "Block size must be a multiple of 64 for the Mirage backend."

         if self.cache_config.kv_sharing_fast_prefill:
             if (
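Editor's note: with the registry, compilation and config changes above in place, the Mirage path is opted into by name. A possible way to enable it from Python (a sketch under assumptions: the model name is arbitrary and the exact constructor arguments may differ across vLLM versions):

import os

# vllm.py above forces VLLM_ATTENTION_BACKEND to MIRAGE and requires a
# block size that is a multiple of 64 when backend="mirage_byname".
os.environ["VLLM_ATTENTION_BACKEND"] = "MIRAGE"

from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="Qwen/Qwen3-8B",  # hypothetical model choice
    compilation_config=CompilationConfig(backend="mirage_byname"),
    block_size=64,
)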
diff --git a/vllm/v1/attention/backends/mirage.py b/vllm/v1/attention/backends/mirage.py
new file mode 100755
index 000000000000..afecbff9d2f4
--- /dev/null
+++ b/vllm/v1/attention/backends/mirage.py
@@ -0,0 +1,328 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention backend stub for the Mirage Persistent Kernel (MPK)."""
+
+from dataclasses import dataclass
+from typing import ClassVar
+
+import numpy as np
+import torch
+
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionType,
+    MultipleOf,
+)
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
+from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+from vllm.utils.math_utils import cdiv
+from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.v1.attention.backends.utils import (
+    AttentionCGSupport,
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+    get_kv_cache_layout,
+)
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+FP8_DTYPE = current_platform.fp8_dtype()
+FP4_DTYPE = torch.uint8
+
+logger = init_logger(__name__)
+
+
+class MirageAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
+
+    # TODO(Jianan Ji): Make sure these are correct.
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
+        return [128]
+
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [16, 32, 64, 4096]
+
+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes."
+            )
+
+    @staticmethod
+    def get_name() -> str:
+        return "MIRAGE"
+
+    @staticmethod
+    def get_impl_cls() -> type["MirageAttentionImpl"]:
+        return MirageAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["MirageAttentionMetadata"]:
+        return MirageAttentionMetadata
+
+    @staticmethod
+    def get_builder_cls() -> type["MirageAttentionMetadataBuilder"]:
+        return MirageAttentionMetadataBuilder
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        cache_dtype_str: str = "auto",
+    ) -> tuple[int, ...]:
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
+
+@dataclass
+class MirageAttentionMetadata:
+    # Meta tensors
+    qo_indptr_gpu: torch.Tensor | None = None
+    paged_kv_indptr_gpu: torch.Tensor | None = None
+    paged_kv_indices_gpu: torch.Tensor | None = None
+    paged_kv_last_page_len_gpu: torch.Tensor | None = None
+
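Editor's note (low-level illustration with made-up sizes): get_kv_cache_shape gives the logical KV-cache shape and get_kv_cache_stride_order the permutation toward the physical layout; for "HND" the permutation simply swaps the block_size and num_kv_heads axes, so applying it twice is the identity:

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 64, 4, 128
logical_shape = (2, num_blocks, block_size, num_kv_heads, head_size)

hnd_order = (0, 1, 3, 2, 4)  # stride order returned for the "HND" layout
physical_shape = tuple(logical_shape[i] for i in hnd_order)

kv_cache = torch.empty(physical_shape, dtype=torch.bfloat16)
assert kv_cache.permute(hnd_order).shape == logical_shape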
+class MirageAttentionMetadataBuilder(AttentionMetadataBuilder[MirageAttentionMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = (
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    )
+
+    reorder_batch_threshold: int = 1
+
+    def __init__(
+        self,
+        kv_cache_spec: AttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self.cache_config = vllm_config.cache_config
+        self.model_config = vllm_config.model_config
+
+        if vllm_is_batch_invariant():
+            self.decode_fixed_split_size = 2048
+            self.prefill_fixed_split_size = 4096
+            self.disable_split_kv = True
+        else:
+            self.decode_fixed_split_size = -1
+            self.prefill_fixed_split_size = -1
+            self.disable_split_kv = False
+
+        self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(
+            self.model_config.max_model_len, self.kv_cache_spec.block_size
+        )
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_reqs * max_num_pages_per_req
+
+        self.num_qo_heads = self.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config
+        )
+        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
+        self.head_dim = self.kv_cache_spec.head_size
+        self.page_size = self.kv_cache_spec.block_size
+
+        # Preparing persistent buffers (device-side)
+        self.paged_kv_indptr = torch.zeros(
+            max_num_reqs + 1, dtype=torch.int32, device=self.device
+        )
+        self.paged_kv_indices = torch.zeros(
+            max_num_pages,  # max num pages possible
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.paged_kv_last_page_len = torch.zeros(
+            max_num_reqs, dtype=torch.int32, device=self.device
+        )
+        # host-side buffers
+        pin_memory = is_pin_memory_available()
+        self.paged_kv_indptr_cpu = torch.zeros(
+            max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
+        )
+        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
+        self.paged_kv_indices_cpu = torch.zeros(
+            max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
+        )
+        self.paged_kv_last_page_len_cpu = torch.zeros(
+            max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
+        )
+        self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
+
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> MirageAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+
+        page_size = self.page_size
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        seq_lens_np = seq_lens_cpu.numpy()
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
+        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
+
+        # write self.paged_kv_indptr_cpu in place (index 0 is always 0)
+        np.cumsum(
+            num_blocks_np,
+            dtype=np.int32,
+            out=self.paged_kv_indptr_np[1 : num_reqs + 1],
+        )
+
+        paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
+        paged_kv_indptr.copy_(
+            self.paged_kv_indptr_cpu[: num_reqs + 1], non_blocking=True
+        )
+
+        # write self.paged_kv_indices in place
+        num_actual_pages = self.paged_kv_indptr_np[num_reqs]
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        _copy_page_indices_kernel[(num_reqs,)](
+            paged_kv_indices,
+            block_table_tensor,
+            block_table_tensor.stride(0),
+            paged_kv_indptr,
+            BLOCK_SIZE=1024,
+        )
+
+        # write self.paged_kv_last_page_len_cpu in place
+        paged_kv_last_page_len_np = seq_lens_np % page_size
+        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
+            paged_kv_last_page_len_np == 0,
+            page_size,
+            paged_kv_last_page_len_np,
+        )
+        paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
+        paged_kv_last_page_len.copy_(
+            self.paged_kv_last_page_len_cpu[:num_reqs], non_blocking=True
+        )
+
+        # uses_spec_reorder = self.reorder_batch_threshold > 1
+
+        attn_metadata = MirageAttentionMetadata(
+            qo_indptr_gpu=common_attn_metadata.query_start_loc,
+            paged_kv_indptr_gpu=paged_kv_indptr,
+            paged_kv_indices_gpu=paged_kv_indices,
+            paged_kv_last_page_len_gpu=paged_kv_last_page_len,
+        )
+
+        return attn_metadata
+
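Editor's note: build() emits a CSR-style page table per batch: paged_kv_indptr holds cumulative page counts, paged_kv_indices the flattened page IDs, and paged_kv_last_page_len how full each request's final page is. The indptr / last-page math can be reproduced with NumPy alone (made-up sequence lengths):

import numpy as np

page_size = 64
seq_lens = np.array([1, 64, 65, 200], dtype=np.int32)

num_blocks = (seq_lens + page_size - 1) // page_size  # [1, 1, 2, 4]
indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
np.cumsum(num_blocks, dtype=np.int32, out=indptr[1:])  # indptr -> [0, 1, 2, 4, 8]

last_page_len = seq_lens % page_size
last_page_len = np.where(last_page_len == 0, page_size, last_page_len)

assert indptr.tolist() == [0, 1, 2, 4, 8]
assert last_page_len.tolist() == [1, 64, 1, 8]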
+class MirageAttentionImpl(AttentionImpl):
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
+        kv_cache_dtype: str,
+        logits_soft_cap: float | None = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+        sinks: torch.Tensor | None = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = scale
+        self.num_kv_heads = num_kv_heads
+        self.alibi_slopes = alibi_slopes
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
+        self.logits_soft_cap = logits_soft_cap
+        self.attn_type = attn_type
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+        self.sinks = sinks
+
+    def forward(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: MirageAttentionMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Forward pass that does nothing; attention runs inside MPK instead.
+
+        Args:
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache: KV cache tensor with different possible shapes:
+                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
+                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert output is not None, "Output tensor must be provided."
+
+        if attn_metadata is None:
+            # Profiling run.
+            return output.fill_(0)
+        else:
+            raise NotImplementedError(
+                "MirageAttentionImpl is never meant to be called directly.")
+
+
+@triton.jit
+def _copy_page_indices_kernel(
+    page_indices,
+    block_table,
+    block_table_stride,
+    cu_num_blocks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = block_table + req_idx * block_table_stride
+    start_idx = tl.load(cu_num_blocks + req_idx)
+    end_idx = tl.load(cu_num_blocks + req_idx + 1)
+    num_blocks = end_idx - start_idx
+
+    offset = tl.arange(0, BLOCK_SIZE)
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):
+        block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
+        tl.store(
+            page_indices + start_idx + i + offset,
+            block_ids,
+            mask=i + offset < num_blocks,
+        )
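Editor's note: the Triton kernel above only gathers each request's rows of the block table into the flat page_indices buffer, one program per request. A functionally equivalent (unfused) PyTorch sketch for reference:

import torch

def copy_page_indices_reference(block_table: torch.Tensor,
                                cu_num_blocks: torch.Tensor) -> torch.Tensor:
    # block_table: (num_reqs, max_blocks_per_req); cu_num_blocks: (num_reqs + 1,)
    page_indices = block_table.new_empty(int(cu_num_blocks[-1]))
    for req_idx in range(block_table.shape[0]):
        start = int(cu_num_blocks[req_idx])
        end = int(cu_num_blocks[req_idx + 1])
        page_indices[start:end] = block_table[req_idx, : end - start]
    return page_indices

block_table = torch.tensor([[3, 7, 0], [5, 0, 0]], dtype=torch.int32)
cu_num_blocks = torch.tensor([0, 2, 3], dtype=torch.int32)
assert copy_page_indices_reference(block_table, cu_num_blocks).tolist() == [3, 7, 5]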