From 4bea700d9ee1da4ac23eb0722f33f0377bc11081 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 19 Nov 2025 05:26:41 -0500 Subject: [PATCH 01/13] separate attention kernel launches for prefill and decode Signed-off-by: Jan van Lunteren --- .../test_triton_unified_attention.py | 10 + .../attention/ops/triton_unified_attention.py | 227 ++++++++++++------ vllm/v1/attention/backends/triton_attn.py | 69 +++++- 3 files changed, 234 insertions(+), 72 deletions(-) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index bf4d2179af5f..a4cca93f445d 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -22,6 +22,10 @@ # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +# 0: use 2D kernel for decode +# 8: use 3D kernel for decode +SEQ_THRESHOLD_3D_VALUES = [0, 8] + def ref_paged_attn( query: torch.Tensor, @@ -92,6 +96,7 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) +@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -103,6 +108,7 @@ def test_triton_unified_attn( soft_cap: float | None, num_blocks: int, q_dtype: torch.dtype | None, + seq_threshold_3D: int, ) -> None: torch.set_default_device("cuda") @@ -152,6 +158,8 @@ def test_triton_unified_attn( k_descale = torch.rand(scale_shape, dtype=torch.float32) v_descale = torch.rand(scale_shape, dtype=torch.float32) + num_decodes = num_seqs if max_query_len == 1 else query_lens.count(1) + unified_attention( q=maybe_quantized_query, k=maybe_quantized_key_cache, @@ -161,6 +169,7 @@ def test_triton_unified_attn( seqused_k=kv_lens, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, + num_decodes=num_decodes, softmax_scale=scale, causal=True, window_size=window_size, @@ -169,6 +178,7 @@ def test_triton_unified_attn( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + seq_threshold_3D=seq_threshold_3D, ) ref_output = ref_paged_attn( diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39bec..d2c5be91ceb4 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -98,6 +98,8 @@ def kernel_unified_attention_2d( BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int + seq_idx_offset, # int + is_prefill: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, @@ -105,11 +107,15 @@ def kernel_unified_attention_2d( q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) - - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + if is_prefill: + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + else: + seq_idx = q_block_global_idx + q_block_start_idx = seq_idx + seq_idx = seq_idx + seq_idx_offset q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -398,16 +404,22 @@ def kernel_unified_attention_3d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int 
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + seq_idx_offset, # int + is_prefill: tl.constexpr, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) - - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + if is_prefill: + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + else: + seq_idx = q_block_global_idx + q_block_start_idx = seq_idx + seq_idx = seq_idx + seq_idx_offset q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -662,6 +674,8 @@ def reduce_segments( query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + seq_idx_offset, # int + is_prefill: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, @@ -669,9 +683,13 @@ def reduce_segments( query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False - ) + if is_prefill: + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) + else: + seq_idx = query_token_idx + seq_idx = seq_idx + seq_idx_offset # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -741,6 +759,7 @@ def unified_attention( max_seqlen_q, seqused_k, max_seqlen_k, + num_decodes, softmax_scale, causal, window_size, @@ -749,6 +768,7 @@ def unified_attention( q_descale, k_descale, v_descale, + seq_threshold_3D, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -790,11 +810,12 @@ def unified_attention( # Assigning default tile sizes for prefill and decode. # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) # and at least 16 for all other data types. 
- TILE_SIZE_PREFILL = 32 - TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + TILE_SIZE_2D_PREFILL = 32 + TILE_SIZE_2D_DECODE = 32 + TILE_SIZE_3D_DECODE = 16 if q.element_size() >= 2 else 32 - # if batch contains a prefill - if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + # prefills + if num_seqs > num_decodes: # or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -824,7 +845,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_PREFILL, + TILE_SIZE=TILE_SIZE_2D_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -844,40 +865,20 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + seq_idx_offset=num_decodes - num_decodes, + is_prefill=True, USE_FP8=output_scale is not None, ) - else: - # for initial version, NUM_SEGMENTS = 16 is chosen as a default - # value that showed good performance in tests - NUM_SEGMENTS = 16 - - segm_output = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - triton.next_power_of_2(head_size), - dtype=torch.float32, - device=q.device, - ) - segm_max = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - segm_expsum = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, + # decodes + if num_decodes > seq_threshold_3D: + kernel_unified_attention_2d[ + ( + num_decodes, + num_kv_heads, + ) + ]( + output_ptr=out, query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, @@ -889,15 +890,18 @@ def unified_attention( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, block_table_stride=block_table.stride(0), query_stride_0=q.stride(0), query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, + TILE_SIZE=TILE_SIZE_2D_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -915,27 +919,108 @@ def unified_attention( stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, + num_seqs=num_decodes, BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) - reduce_segments[(q.shape[0], num_query_heads)]( - output_ptr=out, - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - seq_lens_ptr=seqused_k, - num_seqs=num_seqs, - num_query_heads=num_query_heads, - out_scale_inv=1 / output_scale if output_scale is not None else 1.0, - output_stride_0=out.stride(0), - output_stride_1=out.stride(1), - block_table_stride=block_table.stride(0), - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + seq_idx_offset=0, + is_prefill=False, USE_FP8=output_scale is not None, ) + else: + if num_decodes > 0: + # for initial version, NUM_SEGMENTS = 16 is chosen as a 
default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + + segm_output = torch.empty( + num_decodes, + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + num_decodes, + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + num_decodes, + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + kernel_unified_attention_3d[(num_decodes, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_3D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_decodes, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + seq_idx_offset=0, + is_prefill=False, + ) + + reduce_segments[(num_decodes, num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_decodes, + num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + TILE_SIZE=TILE_SIZE_3D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + seq_idx_offset=0, + is_prefill=False, + USE_FP8=output_scale is not None, + ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 889c79db18ef..f1fd58a48fda 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -17,7 +17,7 @@ triton_reshape_and_cache_flash, ) from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -30,12 +30,17 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +# constants +MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel + + @dataclass class TritonAttentionMetadata: # NOTE(sang): 
Definition of context_len, query_len, and seq_len. @@ -51,9 +56,12 @@ class TritonAttentionMetadata: query_start_loc: torch.Tensor max_seq_len: int seq_lens: torch.Tensor + num_decodes: int block_table: torch.Tensor slot_mapping: torch.Tensor + seq_threshold_3D: int + # For cascade attention. use_cascade: bool common_prefix_len: int @@ -68,6 +76,7 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + reorder_batch_threshold: int = 1 def __init__( self, @@ -87,6 +96,49 @@ def __init__( self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() + # Check if CUDA Graphs are enabled for decode + self.decode_cudagraph_enabled = ( + self.vllm_config.compilation_config.cudagraph_mode + in ( + CUDAGraphMode.FULL_AND_PIECEWISE, + CUDAGraphMode.FULL_DECODE_ONLY, + CUDAGraphMode.FULL, + ) + ) + + # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv). + # A lower bound for num_q_blocks is the number of sequences. + # To ensure the minimum launch grid size is achieved, the number of sequences + # must be at least equal to the threshold below. + # If this threshold is not reached (i.e., the batch size is not large enough), + # the 3D kernel will be selected instead. + self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv + + # Modify the threshold if needed. + if self.decode_cudagraph_enabled: + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + if not capture_sizes: + # If no CUDA Graph capture sizes are specified, the threshold + # is reset to zero, forcing the 2D kernel to be used. + self.seq_threshold_3D = 0 + else: + # Select the CUDA Graph capture size closest to self.seq_threshold_3D + # as threshold. This ensures that each captured graph covers the + # correct execution path. + upd_seq_threshold_3D = min( + capture_sizes, + key=lambda x: abs(x - self.seq_threshold_3D), + ) + + # If the updated threshold becomes significantly larger than the + # initial value, it is reset to zero. This enforces the use of the + # 2D kernel only and ensures that the size of the allocated + # intermediate structures remains bounded. 
+ if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D: + self.seq_threshold_3D = upd_seq_threshold_3D + else: + self.seq_threshold_3D = 0 + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -112,6 +164,14 @@ def build( block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) + use_cascade = common_prefix_len > 0 if use_cascade: @@ -135,6 +195,7 @@ def build( query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, + num_decodes=num_decodes, block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, @@ -143,6 +204,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + seq_threshold_3D=self.seq_threshold_3D, ) return attn_metadata @@ -341,11 +403,14 @@ def forward( ) cu_seqlens_q = attn_metadata.query_start_loc + num_decodes = attn_metadata.num_decodes seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + seq_threshold_3D = attn_metadata.seq_threshold_3D + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) unified_attention( @@ -355,6 +420,7 @@ def forward( out=output[:num_actual_tokens], cu_seqlens_q=cu_seqlens_q, max_seqlen_q=max_seqlen_q, + num_decodes=num_decodes, seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, @@ -366,6 +432,7 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + seq_threshold_3D=seq_threshold_3D, sinks=self.sinks, output_scale=output_scale, ) From 1d6bb225cc123ade9c233f335feaf42ac48aa8b0 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 19 Nov 2025 06:12:28 -0500 Subject: [PATCH 02/13] remove comment Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index d2c5be91ceb4..16f5037b9788 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -815,7 +815,7 @@ def unified_attention( TILE_SIZE_3D_DECODE = 16 if q.element_size() >= 2 else 32 # prefills - if num_seqs > num_decodes: # or total_num_q_blocks * num_kv_heads > 128: + if num_seqs > num_decodes: kernel_unified_attention_2d[ ( total_num_q_blocks, From 1386b3749b86629b110012083303b4ad9ced99dd Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 19 Nov 2025 11:48:38 -0500 Subject: [PATCH 03/13] various modifications Signed-off-by: Jan van Lunteren --- .../test_triton_unified_attention.py | 12 +- .../attention/ops/triton_unified_attention.py | 223 +++++++++++------- vllm/v1/attention/backends/triton_attn.py | 11 + 3 files changed, 165 insertions(+), 81 deletions(-) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index a4cca93f445d..933045ab3d19 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -26,6 +26,8 @@ # 8: use 3D 
kernel for decode SEQ_THRESHOLD_3D_VALUES = [0, 8] +SPLIT_LAUNCH_VALUES = [False, True] + def ref_paged_attn( query: torch.Tensor, @@ -86,7 +88,12 @@ def ref_paged_attn( @pytest.mark.parametrize( - "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] + "seq_lens", + [ + [(1, 1328), (5, 18), (129, 463)], # mixed batch + [(1, 523), (1, 37), (1, 2011)], # decode-only batch + [(5, 18), (129, 463)], # prefill-only batch + ], ) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -97,6 +104,7 @@ def ref_paged_attn( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) @pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES) +@pytest.mark.parametrize("split_launch", SPLIT_LAUNCH_VALUES) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -109,6 +117,7 @@ def test_triton_unified_attn( num_blocks: int, q_dtype: torch.dtype | None, seq_threshold_3D: int, + split_launch: bool, ) -> None: torch.set_default_device("cuda") @@ -179,6 +188,7 @@ def test_triton_unified_attn( k_descale=k_descale, v_descale=v_descale, seq_threshold_3D=seq_threshold_3D, + split_launch=split_launch, ) ref_output = ref_paged_attn( diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 16f5037b9788..c185009f3157 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -98,16 +98,16 @@ def kernel_unified_attention_2d( BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int - seq_idx_offset, # int - is_prefill: tl.constexpr, + q_block_offset, # int + decode_only: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - q_block_global_idx = tl.program_id(0) + q_block_global_idx = tl.program_id(0) + q_block_offset kv_head_idx = tl.program_id(1) - if is_prefill: + if not decode_only: seq_idx = find_seq_idx( query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True ) @@ -115,7 +115,6 @@ def kernel_unified_attention_2d( else: seq_idx = q_block_global_idx q_block_start_idx = seq_idx - seq_idx = seq_idx + seq_idx_offset q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -405,13 +404,13 @@ def kernel_unified_attention_3d( BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int seq_idx_offset, # int - is_prefill: tl.constexpr, + decode_only: tl.constexpr, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - if is_prefill: + if not decode_only: seq_idx = find_seq_idx( query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True ) @@ -675,7 +674,7 @@ def reduce_segments( BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int seq_idx_offset, # int - is_prefill: tl.constexpr, + decode_only: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, @@ -683,7 +682,7 @@ def reduce_segments( query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - if is_prefill: + if not decode_only: seq_idx = find_seq_idx( query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False ) @@ -769,6 +768,7 @@ def unified_attention( k_descale, v_descale, seq_threshold_3D, + split_launch, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -796,17 +796,6 @@ def unified_attention( ) BLOCK_Q = BLOCK_M // 
num_queries_per_kv - # Ideally we would launch with kernel with: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. - # However, it is slow to realize the query_lens on cpu. - # Instead we use upper-bound: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] - # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] - # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs - # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs - # = floor(q.shape[0] / BLOCK_Q) + num_seqs - total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - # Assigning default tile sizes for prefill and decode. # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) # and at least 16 for all other data types. @@ -814,8 +803,24 @@ def unified_attention( TILE_SIZE_2D_DECODE = 32 TILE_SIZE_3D_DECODE = 16 if q.element_size() >= 2 else 32 - # prefills if num_seqs > num_decodes: + # batch contains prefills + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = ( + (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs + if split_launch + else q.shape[0] // BLOCK_Q + num_seqs + ) + kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -865,68 +870,126 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, - seq_idx_offset=num_decodes - num_decodes, - is_prefill=True, + q_block_offset=num_decodes if split_launch else 0, + decode_only=False, USE_FP8=output_scale is not None, ) - # decodes - if num_decodes > seq_threshold_3D: - kernel_unified_attention_2d[ - ( - num_decodes, - num_kv_heads, + if num_decodes > 0 and split_launch: + # batch contains decodes that are not processed in unified fashion + kernel_unified_attention_2d[ + ( + num_decodes, + num_kv_heads, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_2D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_decodes, + BLOCK_M=BLOCK_M, + q_block_offset=0, + decode_only=True, + USE_FP8=output_scale is not None, ) - ]( - output_ptr=out, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - 
sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - out_scale=1 / output_scale if output_scale is not None else 1.0, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - output_stride_0=out.stride(0), - output_stride_1=out.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_2D_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_decodes, - BLOCK_M=BLOCK_M, - seq_idx_offset=0, - is_prefill=False, - USE_FP8=output_scale is not None, - ) else: - if num_decodes > 0: + # decode-only batch + + if num_decodes > seq_threshold_3D: + # use 2D kernel for decode-only batch + kernel_unified_attention_2d[ + ( + num_decodes, + num_kv_heads, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_2D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_decodes, + BLOCK_M=BLOCK_M, + q_block_offset=0, + decode_only=True, + USE_FP8=output_scale is not None, + ) + else: + # use 3D kernel for decode-only batch # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests NUM_SEGMENTS = 16 @@ -999,7 +1062,7 @@ def unified_attention( BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, seq_idx_offset=0, - is_prefill=False, + decode_only=True, ) reduce_segments[(num_decodes, num_query_heads)]( @@ -1021,6 +1084,6 @@ def unified_attention( BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, seq_idx_offset=0, - is_prefill=False, + decode_only=True, USE_FP8=output_scale is not None, ) diff --git a/vllm/v1/attention/backends/triton_attn.py 
b/vllm/v1/attention/backends/triton_attn.py index f1fd58a48fda..825df7390d4a 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -61,6 +61,7 @@ class TritonAttentionMetadata: slot_mapping: torch.Tensor seq_threshold_3D: int + split_launch: bool # For cascade attention. use_cascade: bool @@ -106,6 +107,13 @@ def __init__( ) ) + # Check if CUDA Graphs are enabled for prefill + self.prefill_cudagraph_enabled = ( + self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) + ) + + self.split_launch = self.prefill_cudagraph_enabled + # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv). # A lower bound for num_q_blocks is the number of sequences. # To ensure the minimum launch grid size is achieved, the number of sequences @@ -205,6 +213,7 @@ def build( suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, seq_threshold_3D=self.seq_threshold_3D, + split_launch=self.split_launch, ) return attn_metadata @@ -410,6 +419,7 @@ def forward( block_table = attn_metadata.block_table seq_threshold_3D = attn_metadata.seq_threshold_3D + split_launch = attn_metadata.split_launch descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -433,6 +443,7 @@ def forward( k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), seq_threshold_3D=seq_threshold_3D, + split_launch=split_launch, sinks=self.sinks, output_scale=output_scale, ) From 00e011a8740ddc14414efe4921ed266824defe2d Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 19 Nov 2025 12:48:33 -0500 Subject: [PATCH 04/13] small modifications Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index c185009f3157..759f143af4f1 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -403,10 +403,10 @@ def kernel_unified_attention_3d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int - seq_idx_offset, # int + q_block_offset, # int decode_only: tl.constexpr, ): - q_block_global_idx = tl.program_id(0) + q_block_global_idx = tl.program_id(0) + q_block_offset kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) @@ -418,7 +418,6 @@ def kernel_unified_attention_3d( else: seq_idx = q_block_global_idx q_block_start_idx = seq_idx - seq_idx = seq_idx + seq_idx_offset q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -673,13 +672,13 @@ def reduce_segments( query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int - seq_idx_offset, # int + query_token_idx_offset, # int decode_only: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - query_token_idx = tl.program_id(0) + query_token_idx = tl.program_id(0) + query_token_idx_offset query_head_idx = tl.program_id(1) if not decode_only: @@ -688,7 +687,6 @@ def reduce_segments( ) else: seq_idx = query_token_idx - seq_idx = seq_idx + seq_idx_offset # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -1061,7 +1059,7 @@ def unified_attention( num_seqs=num_decodes, BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - seq_idx_offset=0, + q_block_offset=0, decode_only=True, ) 
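As a reading aid, the decode-path selection that this series introduces can be summarized outside of Triton. The sketch below is illustrative only and is not part of the patches: the builder derives `seq_threshold_3D` from `MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv` and, when full CUDA graphs are captured for decode, snaps it to the closest capture size; `unified_attention` then launches the 2D kernel for decode batches above that threshold, or the 3D segmented kernel followed by `reduce_segments` below it. The helper names `pick_seq_threshold_3d` and `plan_decode_launch` are invented for this example, and the edge-case handling around capture sizes (which later patches in the series revise) is omitted.

```python
# Plain-Python sketch of the decode-kernel selection added by this series.
# Only the constants and comparisons mirror the diff; the helpers are
# hypothetical and exist purely for illustration.

MIN_LAUNCH_GRID_SIZE_2D = 128  # minimum launch grid size targeted for the 2D kernel
NUM_SEGMENTS = 16              # segments per sequence used by the 3D kernel


def pick_seq_threshold_3d(num_kv_heads: int, capture_sizes: list[int] | None) -> int:
    """Builder-side choice of the 2D/3D switch-over point for decode batches."""
    threshold = MIN_LAUNCH_GRID_SIZE_2D // num_kv_heads
    if capture_sizes:
        # Align the threshold with a captured graph size so that every
        # captured graph replays a single, fixed execution path.
        threshold = min(capture_sizes, key=lambda x: abs(x - threshold))
    return threshold


def plan_decode_launch(num_decodes: int, num_kv_heads: int, threshold: int) -> str:
    """Runtime-side choice made for the decode portion of a batch."""
    if num_decodes > threshold:
        # Enough sequences to fill the GPU with a (num_decodes, num_kv_heads) grid.
        return f"2D kernel, grid=({num_decodes}, {num_kv_heads})"
    # Small decode batch: split each sequence's KV range into segments to
    # recover parallelism, then combine partial results with reduce_segments.
    return (
        f"3D kernel, grid=({num_decodes}, {num_kv_heads}, {NUM_SEGMENTS}) "
        f"+ reduce_segments over (num_decodes, num_query_heads)"
    )


if __name__ == "__main__":
    threshold = pick_seq_threshold_3d(num_kv_heads=8, capture_sizes=[1, 2, 4, 8, 16, 32])
    print("seq_threshold_3D =", threshold)      # 128 // 8 = 16 -> nearest capture size 16
    print(plan_decode_launch(4, 8, threshold))   # small decode batch -> 3D kernel
    print(plan_decode_launch(64, 8, threshold))  # large decode batch -> 2D kernel
```

Prefills (and mixed batches when launches are not split) continue to go through the 2D kernel with the launch-grid upper bound `q.shape[0] // BLOCK_Q + num_seqs` shown in the diff above; the sketch covers only the decode side, which is where the 2D/3D choice applies.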
@@ -1083,7 +1081,7 @@ def unified_attention( query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - seq_idx_offset=0, + query_token_idx_offset=0, decode_only=True, USE_FP8=output_scale is not None, ) From ab68bf1a3b9e082d8ded8c4062b32bcffbf9971e Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Thu, 20 Nov 2025 02:32:04 -0500 Subject: [PATCH 05/13] address gemini-code-assist feedback Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 80 ++++--------------- 1 file changed, 17 insertions(+), 63 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 759f143af4f1..603664a260c8 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -109,7 +109,11 @@ def kernel_unified_attention_2d( if not decode_only: seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + query_start_len_ptr, + q_block_global_idx, + num_seqs + q_block_offset, + BLOCK_Q, + True, ) q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx else: @@ -412,7 +416,11 @@ def kernel_unified_attention_3d( if not decode_only: seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + query_start_len_ptr, + q_block_global_idx, + num_seqs + q_block_offset, + BLOCK_Q, + True, ) q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx else: @@ -814,7 +822,7 @@ def unified_attention( # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = ( - (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs + (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs - num_decodes if split_launch else q.shape[0] // BLOCK_Q + num_seqs ) @@ -866,73 +874,18 @@ def unified_attention( stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, + num_seqs=num_seqs - num_decodes if split_launch else num_seqs, BLOCK_M=BLOCK_M, q_block_offset=num_decodes if split_launch else 0, decode_only=False, USE_FP8=output_scale is not None, ) - if num_decodes > 0 and split_launch: - # batch contains decodes that are not processed in unified fashion - kernel_unified_attention_2d[ - ( - num_decodes, - num_kv_heads, - ) - ]( - output_ptr=out, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - out_scale=1 / output_scale if output_scale is not None else 1.0, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - output_stride_0=out.stride(0), - output_stride_1=out.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_2D_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - 
stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_decodes, - BLOCK_M=BLOCK_M, - q_block_offset=0, - decode_only=True, - USE_FP8=output_scale is not None, - ) - else: - # decode-only batch + if num_decodes > 0 or (num_seqs > num_decodes and split_launch): + # batch contains decodes that are not processed in unified fashion if num_decodes > seq_threshold_3D: - # use 2D kernel for decode-only batch + # use 2D kernel kernel_unified_attention_2d[ ( num_decodes, @@ -987,7 +940,8 @@ def unified_attention( USE_FP8=output_scale is not None, ) else: - # use 3D kernel for decode-only batch + # use 3D kernel + # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests NUM_SEGMENTS = 16 From 41ab5213c00829330eab029fda6a136cf9d354d7 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Sat, 29 Nov 2025 09:45:08 -0500 Subject: [PATCH 06/13] partial code reorganisation Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 325 +++++++++++++++--- vllm/v1/attention/backends/triton_attn.py | 40 +-- 2 files changed, 300 insertions(+), 65 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 603664a260c8..8844fae51d50 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -53,8 +53,14 @@ def find_seq_idx( return left - 1 + @triton.jit -def kernel_unified_attention_2d( +def unified_attention_2d( + kv_head_idx, # int + seq_idx, # int + q_block_local_idx, # int + cur_batch_in_all_start_index, # int + cur_batch_query_len, # int output_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] @@ -99,37 +105,10 @@ def kernel_unified_attention_2d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int q_block_offset, # int - decode_only: tl.constexpr, USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - q_block_global_idx = tl.program_id(0) + q_block_offset - kv_head_idx = tl.program_id(1) - - if not decode_only: - seq_idx = find_seq_idx( - query_start_len_ptr, - q_block_global_idx, - num_seqs + q_block_offset, - BLOCK_Q, - True, - ) - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx - else: - seq_idx = q_block_global_idx - q_block_start_idx = seq_idx - - q_block_local_idx = q_block_global_idx - q_block_start_idx - - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index - - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return - offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_t = tl.arange(0, TILE_SIZE) @@ -361,6 +340,254 @@ def kernel_unified_attention_2d( ) +@triton.jit +def kernel_mixed_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # 
[num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + q_block_offset, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): + q_block_global_idx = tl.program_id(0) + q_block_offset + kv_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx( + query_start_len_ptr, + q_block_global_idx, + num_seqs + q_block_offset, + BLOCK_Q, + True, + ) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + unified_attention_2d( + kv_head_idx=kv_head_idx, + seq_idx=seq_idx, + q_block_local_idx=q_block_local_idx, + cur_batch_in_all_start_index=cur_batch_in_all_start_index, + cur_batch_query_len=cur_batch_query_len, + output_ptr=output_ptr, + query_ptr=query_ptr, + key_cache_ptr=key_cache_ptr, + value_cache_ptr=value_cache_ptr, + sink_ptr=sink_ptr, + block_tables_ptr=block_tables_ptr, + seq_lens_ptr=seq_lens_ptr, + alibi_slopes_ptr=alibi_slopes_ptr, + qq_bias_ptr=qq_bias_ptr, + scale=scale, + k_scale=k_scale, + v_scale=v_scale, + out_scale=out_scale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table_stride, + query_stride_0=query_stride_0, + query_stride_1=query_stride_1, + output_stride_0=output_stride_0, + output_stride_1=output_stride_1, + qq_bias_stride_0=qq_bias_stride_0, + BLOCK_SIZE=BLOCK_SIZE, + TILE_SIZE=TILE_SIZE, + HEAD_SIZE=HEAD_SIZE, + HEAD_SIZE_PADDED=HEAD_SIZE_PADDED, + USE_ALIBI_SLOPES=USE_ALIBI_SLOPES, + USE_QQ_BIAS=USE_QQ_BIAS, + USE_SOFTCAP=USE_SOFTCAP, + USE_SINKS=USE_SINKS, + SLIDING_WINDOW=SLIDING_WINDOW, + stride_k_cache_0=stride_k_cache_0, + stride_k_cache_1=stride_k_cache_1, + stride_k_cache_2=stride_k_cache_2, + stride_k_cache_3=stride_k_cache_3, + stride_v_cache_0=stride_v_cache_0, + stride_v_cache_1=stride_v_cache_1, + stride_v_cache_2=stride_v_cache_2, + stride_v_cache_3=stride_v_cache_3, + query_start_len_ptr=query_start_len_ptr, + 
BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + q_block_offset=q_block_offset, + USE_FP8=USE_FP8, + FP8_MIN=FP8_MIN, + FP8_MAX=FP8_MAX, + ) + + +@triton.jit +def kernel_decode_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + q_block_offset, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): + q_block_global_idx = tl.program_id(0) + q_block_offset + kv_head_idx = tl.program_id(1) + + seq_idx = q_block_global_idx + q_block_start_idx = seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + unified_attention_2d( + kv_head_idx=kv_head_idx, + seq_idx=seq_idx, + q_block_local_idx=q_block_local_idx, + cur_batch_in_all_start_index=cur_batch_in_all_start_index, + cur_batch_query_len=cur_batch_query_len, + output_ptr=output_ptr, + query_ptr=query_ptr, + key_cache_ptr=key_cache_ptr, + value_cache_ptr=value_cache_ptr, + sink_ptr=sink_ptr, + block_tables_ptr=block_tables_ptr, + seq_lens_ptr=seq_lens_ptr, + alibi_slopes_ptr=alibi_slopes_ptr, + qq_bias_ptr=qq_bias_ptr, + scale=scale, + k_scale=k_scale, + v_scale=v_scale, + out_scale=out_scale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table_stride, + query_stride_0=query_stride_0, + query_stride_1=query_stride_1, + output_stride_0=output_stride_0, + output_stride_1=output_stride_1, + qq_bias_stride_0=qq_bias_stride_0, + BLOCK_SIZE=BLOCK_SIZE, + TILE_SIZE=TILE_SIZE, + HEAD_SIZE=HEAD_SIZE, + HEAD_SIZE_PADDED=HEAD_SIZE_PADDED, + 
USE_ALIBI_SLOPES=USE_ALIBI_SLOPES, + USE_QQ_BIAS=USE_QQ_BIAS, + USE_SOFTCAP=USE_SOFTCAP, + USE_SINKS=USE_SINKS, + SLIDING_WINDOW=SLIDING_WINDOW, + stride_k_cache_0=stride_k_cache_0, + stride_k_cache_1=stride_k_cache_1, + stride_k_cache_2=stride_k_cache_2, + stride_k_cache_3=stride_k_cache_3, + stride_v_cache_0=stride_v_cache_0, + stride_v_cache_1=stride_v_cache_1, + stride_v_cache_2=stride_v_cache_2, + stride_v_cache_3=stride_v_cache_3, + query_start_len_ptr=query_start_len_ptr, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + q_block_offset=q_block_offset, + USE_FP8=USE_FP8, + FP8_MIN=FP8_MIN, + FP8_MAX=FP8_MAX, + ) + @triton.jit def kernel_unified_attention_3d( segm_output_ptr, @@ -764,7 +991,6 @@ def unified_attention( max_seqlen_q, seqused_k, max_seqlen_k, - num_decodes, softmax_scale, causal, window_size, @@ -773,8 +999,9 @@ def unified_attention( q_descale, k_descale, v_descale, - seq_threshold_3D, - split_launch, + num_decodes=None, + seq_threshold_3D=None, + split_launch=None, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -797,10 +1024,20 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = ( - 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) - ) - BLOCK_Q = BLOCK_M // num_queries_per_kv + # Assign the following variables if they are not assigned in the attention metadata. + # This ensures backward compatibility with callers using an earlier version of this + # function. However, it is recommended to include these assignments in the + # attention metadata itself, as performing them here may negatively impact + # performance. + if ( + seq_threshold_3D is None + or split_launch is None + or num_decodes is None + ): + seq_threshold_3D = 128 // num_kv_heads + split_launch = False + seq_lens = torch.diff(cu_seqlens_q) + num_decodes = (seq_lens == 1).sum().item() # Assigning default tile sizes for prefill and decode. # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) @@ -812,6 +1049,11 @@ def unified_attention( if num_seqs > num_decodes: # batch contains prefills + BLOCK_M = ( + 64 if num_queries_per_kv <= 64 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + # Ideally we would launch with kernel with: # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. # However, it is slow to realize the query_lens on cpu. 
@@ -827,7 +1069,7 @@ def unified_attention( else q.shape[0] // BLOCK_Q + num_seqs ) - kernel_unified_attention_2d[ + kernel_mixed_attention_2d[ ( total_num_q_blocks, num_kv_heads, @@ -877,16 +1119,21 @@ def unified_attention( num_seqs=num_seqs - num_decodes if split_launch else num_seqs, BLOCK_M=BLOCK_M, q_block_offset=num_decodes if split_launch else 0, - decode_only=False, + #decode_only=False, USE_FP8=output_scale is not None, ) if num_decodes > 0 or (num_seqs > num_decodes and split_launch): # batch contains decodes that are not processed in unified fashion + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + if num_decodes > seq_threshold_3D: # use 2D kernel - kernel_unified_attention_2d[ + kernel_decode_attention_2d[ ( num_decodes, num_kv_heads, @@ -936,7 +1183,7 @@ def unified_attention( num_seqs=num_decodes, BLOCK_M=BLOCK_M, q_block_offset=0, - decode_only=True, + #decode_only=True, USE_FP8=output_scale is not None, ) else: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6345b4f7c85d..be9c6a341a5b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -56,10 +56,10 @@ class TritonAttentionMetadata: query_start_loc: torch.Tensor max_seq_len: int seq_lens: torch.Tensor - num_decodes: int block_table: torch.Tensor slot_mapping: torch.Tensor + num_decodes: int seq_threshold_3D: int split_launch: bool @@ -107,7 +107,7 @@ def __init__( ) ) - # Check if CUDA Graphs are enabled for prefill + # Check if CUDA Graphs are enabled for prefill. self.prefill_cudagraph_enabled = ( self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) ) @@ -125,27 +125,15 @@ def __init__( # Modify the threshold if needed. if self.decode_cudagraph_enabled: capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - if not capture_sizes: - # If no CUDA Graph capture sizes are specified, the threshold - # is reset to zero, forcing the 2D kernel to be used. - self.seq_threshold_3D = 0 - else: - # Select the CUDA Graph capture size closest to self.seq_threshold_3D - # as threshold. This ensures that each captured graph covers the - # correct execution path. - upd_seq_threshold_3D = min( - capture_sizes, - key=lambda x: abs(x - self.seq_threshold_3D), - ) - - # If the updated threshold becomes significantly larger than the - # initial value, it is reset to zero. This enforces the use of the - # 2D kernel only and ensures that the size of the allocated - # intermediate structures remains bounded. - if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D: - self.seq_threshold_3D = upd_seq_threshold_3D - else: - self.seq_threshold_3D = 0 + assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified." + + # Select the CUDA Graph capture size closest to self.seq_threshold_3D + # as threshold. This ensures that each captured graph covers the + # correct execution path. 
+ self.seq_threshold_3D = min( + capture_sizes, + key=lambda x: abs(x - self.seq_threshold_3D), + ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata @@ -203,7 +191,6 @@ def build( query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - num_decodes=num_decodes, block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, @@ -212,6 +199,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + num_decodes=num_decodes, seq_threshold_3D=self.seq_threshold_3D, split_launch=self.split_launch, ) @@ -416,12 +404,12 @@ def forward( ) cu_seqlens_q = attn_metadata.query_start_loc - num_decodes = attn_metadata.num_decodes seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + num_decodes = attn_metadata.num_decodes seq_threshold_3D = attn_metadata.seq_threshold_3D split_launch = attn_metadata.split_launch @@ -434,7 +422,6 @@ def forward( out=output[:num_actual_tokens], cu_seqlens_q=cu_seqlens_q, max_seqlen_q=max_seqlen_q, - num_decodes=num_decodes, seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, @@ -446,6 +433,7 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + num_decodes=num_decodes, seq_threshold_3D=seq_threshold_3D, split_launch=split_launch, sinks=self.sinks, From a158ca9d0cd1df6c6286ec8e1846be971b9203a1 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Mon, 1 Dec 2025 01:31:48 -0500 Subject: [PATCH 07/13] formatting Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 8844fae51d50..49261fccdcc2 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -53,7 +53,6 @@ def find_seq_idx( return left - 1 - @triton.jit def unified_attention_2d( kv_head_idx, # int @@ -588,6 +587,7 @@ def kernel_decode_attention_2d( FP8_MAX=FP8_MAX, ) + @triton.jit def kernel_unified_attention_3d( segm_output_ptr, @@ -1029,11 +1029,7 @@ def unified_attention( # function. However, it is recommended to include these assignments in the # attention metadata itself, as performing them here may negatively impact # performance. 
- if ( - seq_threshold_3D is None - or split_launch is None - or num_decodes is None - ): + if seq_threshold_3D is None or split_launch is None or num_decodes is None: seq_threshold_3D = 128 // num_kv_heads split_launch = False seq_lens = torch.diff(cu_seqlens_q) @@ -1050,7 +1046,9 @@ def unified_attention( # batch contains prefills BLOCK_M = ( - 64 if num_queries_per_kv <= 64 else triton.next_power_of_2(num_queries_per_kv) + 64 + if num_queries_per_kv <= 64 + else triton.next_power_of_2(num_queries_per_kv) ) BLOCK_Q = BLOCK_M // num_queries_per_kv @@ -1119,7 +1117,6 @@ def unified_attention( num_seqs=num_seqs - num_decodes if split_launch else num_seqs, BLOCK_M=BLOCK_M, q_block_offset=num_decodes if split_launch else 0, - #decode_only=False, USE_FP8=output_scale is not None, ) @@ -1127,7 +1124,9 @@ def unified_attention( # batch contains decodes that are not processed in unified fashion BLOCK_M = ( - 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + 16 + if num_queries_per_kv <= 16 + else triton.next_power_of_2(num_queries_per_kv) ) BLOCK_Q = BLOCK_M // num_queries_per_kv @@ -1183,7 +1182,6 @@ def unified_attention( num_seqs=num_decodes, BLOCK_M=BLOCK_M, q_block_offset=0, - #decode_only=True, USE_FP8=output_scale is not None, ) else: From 3c57c70d5fe2a5730e7e86f3c8395c95b127c054 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Mon, 1 Dec 2025 06:19:20 -0500 Subject: [PATCH 08/13] partial code reorganisation Signed-off-by: Jan van Lunteren --- .../test_triton_unified_attention.py | 4 +- .../attention/ops/triton_unified_attention.py | 348 ++++-------------- vllm/v1/attention/backends/triton_attn.py | 4 + 3 files changed, 72 insertions(+), 284 deletions(-) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 933045ab3d19..ed8048bc3b82 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -168,6 +168,7 @@ def test_triton_unified_attn( v_descale = torch.rand(scale_shape, dtype=torch.float32) num_decodes = num_seqs if max_query_len == 1 else query_lens.count(1) + num_prefills = num_seqs - num_decodes unified_attention( q=maybe_quantized_query, @@ -178,7 +179,6 @@ def test_triton_unified_attn( seqused_k=kv_lens, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, - num_decodes=num_decodes, softmax_scale=scale, causal=True, window_size=window_size, @@ -187,6 +187,8 @@ def test_triton_unified_attn( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + num_prefills=num_prefills, + num_decodes=num_decodes, seq_threshold_3D=seq_threshold_3D, split_launch=split_launch, ) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 49261fccdcc2..c76f3d7c5d0d 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -54,12 +54,7 @@ def find_seq_idx( @triton.jit -def unified_attention_2d( - kv_head_idx, # int - seq_idx, # int - q_block_local_idx, # int - cur_batch_in_all_start_index, # int - cur_batch_query_len, # int +def kernel_unified_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] @@ -104,10 +99,37 @@ def unified_attention_2d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int q_block_offset, # int + decode_only: tl.constexpr, USE_FP8: tl.constexpr, 
# bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): + q_block_global_idx = tl.program_id(0) + q_block_offset + kv_head_idx = tl.program_id(1) + + if not decode_only: + seq_idx = find_seq_idx( + query_start_len_ptr, + q_block_global_idx, + num_seqs + q_block_offset, + BLOCK_Q, + True, + ) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + else: + seq_idx = q_block_global_idx + q_block_start_idx = seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_t = tl.arange(0, TILE_SIZE) @@ -339,255 +361,6 @@ def unified_attention_2d( ) -@triton.jit -def kernel_mixed_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - out_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int must be power of 2 - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - q_block_offset, # int - USE_FP8: tl.constexpr, # bool - FP8_MIN: tl.constexpr = float8_info.min, - FP8_MAX: tl.constexpr = float8_info.max, -): - q_block_global_idx = tl.program_id(0) + q_block_offset - kv_head_idx = tl.program_id(1) - - seq_idx = find_seq_idx( - query_start_len_ptr, - q_block_global_idx, - num_seqs + q_block_offset, - BLOCK_Q, - True, - ) - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx - q_block_local_idx = q_block_global_idx - q_block_start_idx - - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index - - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return - - 
unified_attention_2d( - kv_head_idx=kv_head_idx, - seq_idx=seq_idx, - q_block_local_idx=q_block_local_idx, - cur_batch_in_all_start_index=cur_batch_in_all_start_index, - cur_batch_query_len=cur_batch_query_len, - output_ptr=output_ptr, - query_ptr=query_ptr, - key_cache_ptr=key_cache_ptr, - value_cache_ptr=value_cache_ptr, - sink_ptr=sink_ptr, - block_tables_ptr=block_tables_ptr, - seq_lens_ptr=seq_lens_ptr, - alibi_slopes_ptr=alibi_slopes_ptr, - qq_bias_ptr=qq_bias_ptr, - scale=scale, - k_scale=k_scale, - v_scale=v_scale, - out_scale=out_scale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table_stride, - query_stride_0=query_stride_0, - query_stride_1=query_stride_1, - output_stride_0=output_stride_0, - output_stride_1=output_stride_1, - qq_bias_stride_0=qq_bias_stride_0, - BLOCK_SIZE=BLOCK_SIZE, - TILE_SIZE=TILE_SIZE, - HEAD_SIZE=HEAD_SIZE, - HEAD_SIZE_PADDED=HEAD_SIZE_PADDED, - USE_ALIBI_SLOPES=USE_ALIBI_SLOPES, - USE_QQ_BIAS=USE_QQ_BIAS, - USE_SOFTCAP=USE_SOFTCAP, - USE_SINKS=USE_SINKS, - SLIDING_WINDOW=SLIDING_WINDOW, - stride_k_cache_0=stride_k_cache_0, - stride_k_cache_1=stride_k_cache_1, - stride_k_cache_2=stride_k_cache_2, - stride_k_cache_3=stride_k_cache_3, - stride_v_cache_0=stride_v_cache_0, - stride_v_cache_1=stride_v_cache_1, - stride_v_cache_2=stride_v_cache_2, - stride_v_cache_3=stride_v_cache_3, - query_start_len_ptr=query_start_len_ptr, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - q_block_offset=q_block_offset, - USE_FP8=USE_FP8, - FP8_MIN=FP8_MIN, - FP8_MAX=FP8_MAX, - ) - - -@triton.jit -def kernel_decode_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - out_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int must be power of 2 - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - q_block_offset, # int - USE_FP8: tl.constexpr, # bool - FP8_MIN: tl.constexpr = float8_info.min, - FP8_MAX: tl.constexpr = float8_info.max, -): - q_block_global_idx = tl.program_id(0) + 
q_block_offset - kv_head_idx = tl.program_id(1) - - seq_idx = q_block_global_idx - q_block_start_idx = seq_idx - - q_block_local_idx = q_block_global_idx - q_block_start_idx - - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index - - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return - - unified_attention_2d( - kv_head_idx=kv_head_idx, - seq_idx=seq_idx, - q_block_local_idx=q_block_local_idx, - cur_batch_in_all_start_index=cur_batch_in_all_start_index, - cur_batch_query_len=cur_batch_query_len, - output_ptr=output_ptr, - query_ptr=query_ptr, - key_cache_ptr=key_cache_ptr, - value_cache_ptr=value_cache_ptr, - sink_ptr=sink_ptr, - block_tables_ptr=block_tables_ptr, - seq_lens_ptr=seq_lens_ptr, - alibi_slopes_ptr=alibi_slopes_ptr, - qq_bias_ptr=qq_bias_ptr, - scale=scale, - k_scale=k_scale, - v_scale=v_scale, - out_scale=out_scale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table_stride, - query_stride_0=query_stride_0, - query_stride_1=query_stride_1, - output_stride_0=output_stride_0, - output_stride_1=output_stride_1, - qq_bias_stride_0=qq_bias_stride_0, - BLOCK_SIZE=BLOCK_SIZE, - TILE_SIZE=TILE_SIZE, - HEAD_SIZE=HEAD_SIZE, - HEAD_SIZE_PADDED=HEAD_SIZE_PADDED, - USE_ALIBI_SLOPES=USE_ALIBI_SLOPES, - USE_QQ_BIAS=USE_QQ_BIAS, - USE_SOFTCAP=USE_SOFTCAP, - USE_SINKS=USE_SINKS, - SLIDING_WINDOW=SLIDING_WINDOW, - stride_k_cache_0=stride_k_cache_0, - stride_k_cache_1=stride_k_cache_1, - stride_k_cache_2=stride_k_cache_2, - stride_k_cache_3=stride_k_cache_3, - stride_v_cache_0=stride_v_cache_0, - stride_v_cache_1=stride_v_cache_1, - stride_v_cache_2=stride_v_cache_2, - stride_v_cache_3=stride_v_cache_3, - query_start_len_ptr=query_start_len_ptr, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - q_block_offset=q_block_offset, - USE_FP8=USE_FP8, - FP8_MIN=FP8_MIN, - FP8_MAX=FP8_MAX, - ) - - @triton.jit def kernel_unified_attention_3d( segm_output_ptr, @@ -999,6 +772,7 @@ def unified_attention( q_descale, k_descale, v_descale, + num_prefills=None, num_decodes=None, seq_threshold_3D=None, split_launch=None, @@ -1029,12 +803,32 @@ def unified_attention( # function. However, it is recommended to include these assignments in the # attention metadata itself, as performing them here may negatively impact # performance. - if seq_threshold_3D is None or split_launch is None or num_decodes is None: + if ( + seq_threshold_3D is None + or split_launch is None + or num_prefills is None + or num_decodes is None + ): seq_threshold_3D = 128 // num_kv_heads split_launch = False seq_lens = torch.diff(cu_seqlens_q) + num_prefills = (seq_lens > 1).sum().item() num_decodes = (seq_lens == 1).sum().item() + # Assigning Q Block dimensions for prefill and decode. + BLOCK_M_2D_PREFILL = ( + 64 if num_queries_per_kv <= 64 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_M_2D_DECODE = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_M_3D_DECODE = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q_2D_PREFILL = BLOCK_M_2D_PREFILL // num_queries_per_kv + BLOCK_Q_2D_DECODE = BLOCK_M_2D_DECODE // num_queries_per_kv + BLOCK_Q_3D_DECODE = BLOCK_M_3D_DECODE // num_queries_per_kv + # Assigning default tile sizes for prefill and decode. 
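# Worked example for the Q-block dimensions just assigned (illustrative head
# counts, not taken from the patch): with 32 query heads and 8 KV heads,
# num_queries_per_kv == 4, so
#   BLOCK_M_2D_PREFILL = 64, BLOCK_Q_2D_PREFILL = 64 // 4 = 16
#   BLOCK_M_2D_DECODE  = 16, BLOCK_Q_2D_DECODE  = 16 // 4 = 4
#   BLOCK_M_3D_DECODE  = 16, BLOCK_Q_3D_DECODE  = 16 // 4 = 4
# i.e. each prefill program (per KV head) covers up to 16 query tokens and
# each decode program up to 4. For very wide GQA ratios
# (num_queries_per_kv > 64), BLOCK_M grows to the next power of two and
# BLOCK_Q drops to 1, so one program handles the heads of a single query token.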
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) # and at least 16 for all other data types. @@ -1042,16 +836,9 @@ def unified_attention( TILE_SIZE_2D_DECODE = 32 TILE_SIZE_3D_DECODE = 16 if q.element_size() >= 2 else 32 - if num_seqs > num_decodes: + if num_prefills > 0: # batch contains prefills - BLOCK_M = ( - 64 - if num_queries_per_kv <= 64 - else triton.next_power_of_2(num_queries_per_kv) - ) - BLOCK_Q = BLOCK_M // num_queries_per_kv - # Ideally we would launch with kernel with: # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. # However, it is slow to realize the query_lens on cpu. @@ -1062,12 +849,12 @@ def unified_attention( # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = ( - (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs - num_decodes + (q.shape[0] - num_decodes) // BLOCK_Q_2D_PREFILL + num_seqs - num_decodes if split_launch - else q.shape[0] // BLOCK_Q + num_seqs + else q.shape[0] // BLOCK_Q_2D_PREFILL + num_seqs ) - kernel_mixed_attention_2d[ + kernel_unified_attention_2d[ ( total_num_q_blocks, num_kv_heads, @@ -1113,26 +900,20 @@ def unified_attention( stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, + BLOCK_Q=BLOCK_Q_2D_PREFILL, num_seqs=num_seqs - num_decodes if split_launch else num_seqs, - BLOCK_M=BLOCK_M, + BLOCK_M=BLOCK_M_2D_PREFILL, q_block_offset=num_decodes if split_launch else 0, + decode_only=False, USE_FP8=output_scale is not None, ) - if num_decodes > 0 or (num_seqs > num_decodes and split_launch): + if (num_decodes > 0) and ((num_prefills == 0) or split_launch): # batch contains decodes that are not processed in unified fashion - BLOCK_M = ( - 16 - if num_queries_per_kv <= 16 - else triton.next_power_of_2(num_queries_per_kv) - ) - BLOCK_Q = BLOCK_M // num_queries_per_kv - if num_decodes > seq_threshold_3D: # use 2D kernel - kernel_decode_attention_2d[ + kernel_unified_attention_2d[ ( num_decodes, num_kv_heads, @@ -1178,10 +959,11 @@ def unified_attention( stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, + BLOCK_Q=BLOCK_Q_2D_DECODE, num_seqs=num_decodes, - BLOCK_M=BLOCK_M, + BLOCK_M=BLOCK_M_2D_DECODE, q_block_offset=0, + decode_only=True, USE_FP8=output_scale is not None, ) else: @@ -1254,9 +1036,9 @@ def unified_attention( stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, + BLOCK_Q=BLOCK_Q_3D_DECODE, num_seqs=num_decodes, - BLOCK_M=BLOCK_M, + BLOCK_M=BLOCK_M_3D_DECODE, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, q_block_offset=0, decode_only=True, @@ -1278,7 +1060,7 @@ def unified_attention( HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, + BLOCK_Q=BLOCK_Q_3D_DECODE, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, query_token_idx_offset=0, decode_only=True, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index be9c6a341a5b..1dfaa4597e32 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -59,6 +59,7 @@ class TritonAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor + num_prefills: int num_decodes: int seq_threshold_3D: int split_launch: bool @@ -199,6 +200,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + 
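# Scheduling fields for the split prefill/decode launch: the prefill/decode
# counts, the 2D/3D decode threshold, and the split-launch flag are computed
# once at metadata-build time so unified_attention does not have to derive
# them on every call (see the fallback path in triton_unified_attention.py).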
num_prefills=num_prefills, num_decodes=num_decodes, seq_threshold_3D=self.seq_threshold_3D, split_launch=self.split_launch, @@ -409,6 +411,7 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + num_prefills = attn_metadata.num_prefills num_decodes = attn_metadata.num_decodes seq_threshold_3D = attn_metadata.seq_threshold_3D split_launch = attn_metadata.split_launch @@ -433,6 +436,7 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + num_prefills=num_prefills, num_decodes=num_decodes, seq_threshold_3D=seq_threshold_3D, split_launch=split_launch, From 90bea7b753929281a6a913c3e30e9e4152a6efe5 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Tue, 2 Dec 2025 04:44:47 -0500 Subject: [PATCH 09/13] modified _cudagraph_support Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 1dfaa4597e32..f5ef894eb9eb 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -77,7 +77,7 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: int = 1 def __init__( From ff46b621e1879aca83369761c1b88a127c42878b Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Tue, 2 Dec 2025 05:59:40 -0500 Subject: [PATCH 10/13] replace _cudagraph_support modification by assert statement Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index f5ef894eb9eb..97b0e74fb82d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -77,7 +77,7 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS reorder_batch_threshold: int = 1 def __init__( @@ -112,6 +112,16 @@ def __init__( self.prefill_cudagraph_enabled = ( self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) ) + speculative_config = vllm_config.speculative_config + self.num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) + assert not (self.prefill_cudagraph_enabled and (self.num_spec_tokens > 0)), ( + "Triton Attention Backend does currently not support FULL CUDA Graph mode " + "when combined with speculative decoding." 
+ ) self.split_launch = self.prefill_cudagraph_enabled From 91cb2bb578b7d23543dc890172758f4a13ec7d8b Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Thu, 4 Dec 2025 03:54:40 -0500 Subject: [PATCH 11/13] override get_cudagraph_support() Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 35 +++++++++++++++-------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 97b0e74fb82d..debeb36cbce3 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -77,7 +77,6 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS reorder_batch_threshold: int = 1 def __init__( @@ -112,17 +111,6 @@ def __init__( self.prefill_cudagraph_enabled = ( self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) ) - speculative_config = vllm_config.speculative_config - self.num_spec_tokens = ( - speculative_config.num_speculative_tokens - if speculative_config is not None - else 0 - ) - assert not (self.prefill_cudagraph_enabled and (self.num_spec_tokens > 0)), ( - "Triton Attention Backend does currently not support FULL CUDA Graph mode " - "when combined with speculative decoding." - ) - self.split_launch = self.prefill_cudagraph_enabled # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv). @@ -146,6 +134,29 @@ def __init__( key=lambda x: abs(x - self.seq_threshold_3D), ) + @classmethod + def get_cudagraph_support( + cls: type["TritonAttentionMetadataBuilder"], + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + + # Check if CUDA Graphs are enabled for prefill. + prefill_cudagraph_enabled = ( + vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) + ) + speculative_config = vllm_config.speculative_config + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) + + if prefill_cudagraph_enabled and (num_spec_tokens > 0): + return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + else: + return AttentionCGSupport.ALWAYS + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: From 3aad07802daea0ac9a4f2fdba75581e4e1074437 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Thu, 4 Dec 2025 04:08:25 -0500 Subject: [PATCH 12/13] Add comment Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index debeb36cbce3..2fd2a2a68c56 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -145,6 +145,8 @@ def get_cudagraph_support( prefill_cudagraph_enabled = ( vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) ) + + # Determine number of speculative tokens. speculative_config = vllm_config.speculative_config num_spec_tokens = ( speculative_config.num_speculative_tokens @@ -152,6 +154,7 @@ def get_cudagraph_support( else 0 ) + # Select the appropriate CUDA graph support mode. 
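# Summary of the branch below (a restatement, no additional behavior):
#   cudagraph_mode == FULL and num_spec_tokens > 0 -> UNIFORM_SINGLE_TOKEN_DECODE
#   any other combination                          -> ALWAYS
# The first case is the combination the earlier assert rejected outright; it
# is now handled by restricting full-graph capture to uniform single-token
# decode batches instead of failing.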
if prefill_cudagraph_enabled and (num_spec_tokens > 0): return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE else: From 64f5d578655a274024afd7dae2b9663cb5ab3000 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Thu, 4 Dec 2025 04:43:41 -0500 Subject: [PATCH 13/13] formatting Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 2fd2a2a68c56..204b92fa3c18 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -140,10 +140,9 @@ def get_cudagraph_support( vllm_config: VllmConfig, kv_cache_spec: AttentionSpec, ) -> AttentionCGSupport: - # Check if CUDA Graphs are enabled for prefill. - prefill_cudagraph_enabled = ( - vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL,) + prefill_cudagraph_enabled = vllm_config.compilation_config.cudagraph_mode in ( + CUDAGraphMode.FULL, ) # Determine number of speculative tokens.
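For illustration, a minimal host-side sketch (not part of the patch series) of how the new scheduling arguments can be derived when the attention metadata does not carry them, assuming only PyTorch; the helper name plan_unified_attention_launch and the example head count are hypothetical:

import torch

def plan_unified_attention_launch(
    cu_seqlens_q: torch.Tensor, num_kv_heads: int, split_launch: bool = False
) -> dict:
    # Per-sequence query lengths; decode requests contribute exactly one token.
    seq_lens = torch.diff(cu_seqlens_q)
    num_prefills = int((seq_lens > 1).sum())
    num_decodes = int((seq_lens == 1).sum())
    # Default switch point between the 2D and 3D decode kernels.
    seq_threshold_3D = 128 // num_kv_heads
    return {
        "num_prefills": num_prefills,
        "num_decodes": num_decodes,
        "seq_threshold_3D": seq_threshold_3D,
        "split_launch": split_launch,
    }

# Decode requests are reordered to the front of the batch
# (reorder_batch_threshold = 1), so with three decodes followed by two
# prefills of 5 and 3 query tokens:
cu_seqlens_q = torch.tensor([0, 1, 2, 3, 8, 11], dtype=torch.int32)
print(plan_unified_attention_launch(cu_seqlens_q, num_kv_heads=8))
# {'num_prefills': 2, 'num_decodes': 3, 'seq_threshold_3D': 16, 'split_launch': False}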