[Kernel] Separate Triton Attention Kernel Launches for Prefill and Decode for FULL CUDA Graph mode #29020
Conversation
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Code Review
This pull request effectively separates Triton attention kernel launches for prefill and decode operations in FULL CUDA Graph mode, laying a good foundation for future performance tuning. The changes are generally well-structured. However, I've identified a critical bug in the grid size calculation for the prefill kernel, which could lead to performance issues or errors. Additionally, there's significant code duplication in the kernel launch logic that should be refactored to improve maintainability. My detailed comments are below.
```python
total_num_q_blocks = (
    (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs
    if split_launch
    else q.shape[0] // BLOCK_Q + num_seqs
)
```
The calculation for total_num_q_blocks for the prefill part of a mixed batch appears to be incorrect when split_launch is true. The number of sequences added to the upper bound should be the number of prefill sequences (num_seqs - num_decodes), not the total number of sequences (num_seqs). This overestimates the number of required blocks by num_decodes, which can lead to wasted computation and potentially exceed grid size limits.
```diff
-total_num_q_blocks = (
-    (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs
-    if split_launch
-    else q.shape[0] // BLOCK_Q + num_seqs
-)
+total_num_q_blocks = (
+    (q.shape[0] - num_decodes) // BLOCK_Q + (num_seqs - num_decodes)
+    if split_launch
+    else q.shape[0] // BLOCK_Q + num_seqs
+)
```
This is addressed by the new modification.
```python
if num_decodes > 0 and split_launch:
    # batch contains decodes that are not processed in unified fashion
    kernel_unified_attention_2d[
        (
            num_decodes,
            num_kv_heads,
        )
    ](
        output_ptr=out,
        query_ptr=q,
        key_cache_ptr=k,
        value_cache_ptr=v,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        k_scale=k_descale,
        v_scale=v_descale,
        out_scale=1 / output_scale if output_scale is not None else 1.0,
        softcap=softcap,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
        BLOCK_SIZE=block_size,
        TILE_SIZE=TILE_SIZE_2D_DECODE,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        USE_ALIBI_SLOPES=use_alibi_slopes,
        USE_QQ_BIAS=use_qq_bias,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(1 + window_size[0]),
        stride_k_cache_0=k.stride(0),
        stride_k_cache_1=k.stride(1),
        stride_k_cache_2=k.stride(2),
        stride_k_cache_3=k.stride(3),
        stride_v_cache_0=v.stride(0),
        stride_v_cache_1=v.stride(1),
        stride_v_cache_2=v.stride(2),
        stride_v_cache_3=v.stride(3),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_decodes,
        BLOCK_M=BLOCK_M,
        q_block_offset=0,
        decode_only=True,
        USE_FP8=output_scale is not None,
    )
else:
    # for initial version, NUM_SEGMENTS = 16 is chosen as a default
    # value that showed good performance in tests
    NUM_SEGMENTS = 16

# ...

segm_output = torch.empty(
    q.shape[0],
    num_query_heads,
    NUM_SEGMENTS,
    triton.next_power_of_2(head_size),
    dtype=torch.float32,
    device=q.device,
)
segm_max = torch.empty(
    q.shape[0],
    num_query_heads,
    NUM_SEGMENTS,
    dtype=torch.float32,
    device=q.device,
)
segm_expsum = torch.empty(
    q.shape[0],
    num_query_heads,
    NUM_SEGMENTS,
    dtype=torch.float32,
    device=q.device,
)
# decode-only batch

# ...

if num_decodes > seq_threshold_3D:
    # use 2D kernel for decode-only batch
    kernel_unified_attention_2d[
        (
            num_decodes,
            num_kv_heads,
        )
    ](
        output_ptr=out,
        query_ptr=q,
        key_cache_ptr=k,
        value_cache_ptr=v,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        k_scale=k_descale,
        v_scale=v_descale,
        out_scale=1 / output_scale if output_scale is not None else 1.0,
        softcap=softcap,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
        BLOCK_SIZE=block_size,
        TILE_SIZE=TILE_SIZE_2D_DECODE,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        USE_ALIBI_SLOPES=use_alibi_slopes,
        USE_QQ_BIAS=use_qq_bias,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(1 + window_size[0]),
        stride_k_cache_0=k.stride(0),
        stride_k_cache_1=k.stride(1),
        stride_k_cache_2=k.stride(2),
        stride_k_cache_3=k.stride(3),
        stride_v_cache_0=v.stride(0),
        stride_v_cache_1=v.stride(1),
        stride_v_cache_2=v.stride(2),
        stride_v_cache_3=v.stride(3),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_decodes,
        BLOCK_M=BLOCK_M,
        q_block_offset=0,
        decode_only=True,
        USE_FP8=output_scale is not None,
    )
```
There is significant code duplication for launching the kernel_unified_attention_2d for decode operations. The kernel call block is repeated for the split_launch case in a mixed batch (lines 878-932) and for the 2D path in a decode-only batch (lines 936-990). This makes the code harder to read and maintain, as any changes would need to be synchronized in multiple places.
I suggest refactoring this duplicated logic into a helper function. This would improve code clarity and make future modifications easier. For example, you could create a helper function like _launch_decode_2d_kernel(...) that encapsulates the kernel launch and its parameters.
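A minimal sketch of that idea, reusing the variable names from the launch code above (illustrative only, not code from this PR; the helper name is the one suggested in the comment):

```python
# Sketch of the suggested refactor: assemble the shared launch arguments once,
# then call one helper from both the split-launch path and the decode-only
# 2D path. Names mirror the locals used in the duplicated blocks above.
def _launch_decode_2d_kernel(num_decodes, num_kv_heads, **kernel_args):
    """Launch kernel_unified_attention_2d over the decode portion only."""
    kernel_unified_attention_2d[(num_decodes, num_kv_heads)](
        num_seqs=num_decodes,
        q_block_offset=0,
        decode_only=True,
        **kernel_args,  # pointers, strides, scales, and constexpr flags
    )

# At both call sites, something like:
# _launch_decode_2d_kernel(num_decodes, num_kv_heads, output_ptr=out,
#                          query_ptr=q, ..., TILE_SIZE=TILE_SIZE_2D_DECODE)
```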
This is addressed by the new modification.
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
q_block_offset=num_decodes if split_launch else 0,
decode_only=False,
```
Offset prefills by decode count instead of decode blocks
When launching the second 2D kernel for the prefill portion, the code offsets the program id by num_decodes and subtracts the same value from the q‑block estimate. num_decodes counts requests, not q blocks. This only works if every decode request occupies exactly one BLOCK_Q block (e.g. one token). For speculative decoding or any configuration where decodes can contain multiple tokens, the first kernel will start in the middle of the decode blocks and the second kernel only launches num_decodes blocks, so some decode blocks are skipped or processed with prefill parameters. In FULL CUDA graph mode this yields incorrect outputs for multi‑token decodes. The offset should be based on the number of decode q blocks (∑⌈len_i/BLOCK_Q⌉) rather than the number of decode requests.
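For reference, the decode q-block count described here could be computed along these lines (a hedged sketch; the function name is hypothetical, while `cu_seqlens_q` and `BLOCK_Q` are the names used in the launch code above):

```python
import torch

def num_decode_q_blocks(cu_seqlens_q: torch.Tensor,
                        num_decodes: int,
                        BLOCK_Q: int) -> int:
    """Illustrative sketch (not code from this PR): sum of
    ceil(len_i / BLOCK_Q) over the decode requests at the front of the batch."""
    # cu_seqlens_q holds cumulative query-token offsets per request.
    lens = cu_seqlens_q[1 : num_decodes + 1] - cu_seqlens_q[:num_decodes]
    return int(torch.sum((lens + BLOCK_Q - 1) // BLOCK_Q).item())
```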
That is correct. Speculative decoding is not yet covered by the optimizations introduced in this pull request. To handle this properly, reorder_batch_threshold has been set to 1 token, which ensures speculative decode steps are processed by the 2D kernel in the same way as prefills. Furthermore, FULL CUDA Graph mode in combination with speculative decoding is not supported by this pull request and will be blocked using an assert statement. A follow-up pull request will extend these optimizations to also support speculative decoding.
The above mentioned assert statement has been replaced by an override of the get_cudagraph_support() class method in the TritonAttentionMetadataBuilder class. This override adjusts the _cudagraph_support value from ALWAYS to UNIFORM_SINGLE_TOKEN_DECODE when the combination of FULL CUDA Graph mode and speculative decoding is detected.
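Roughly, the behavior described in the two replies above might look as follows (a hedged sketch; the base class, import path, and exact signature of `get_cudagraph_support()` are assumptions, not code copied from this PR):

```python
# Sketch only: approximates vLLM's v1 attention metadata-builder interface.
from vllm.v1.attention.backends.utils import AttentionCGSupport


class TritonAttentionMetadataBuilderSketch:
    # Decodes are treated as single-token requests, so speculative multi-token
    # steps are routed to the 2D kernel like prefills (assumed attribute name).
    reorder_batch_threshold: int = 1

    _cudagraph_support = AttentionCGSupport.ALWAYS

    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        # FULL CUDA graphs combined with speculative decoding are not yet
        # supported by the split launch: downgrade from ALWAYS to
        # uniform single-token decode capture in that case.
        if vllm_config.speculative_config is not None:
            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
        return AttentionCGSupport.ALWAYS
```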
Signed-off-by: Jan van Lunteren <[email protected]>
> The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.

IMO it's best practice to hold off until we can see the benefits; we can land this as a chain once we know the final PR is delivering perf. This gives reviewers full visibility into the planned changes and the ultimate benefit.
Signed-off-by: Jan van Lunteren <[email protected]>
That is a fair point. I will see if I can include some experimental results that demonstrate performance improvements without requiring additional PRs.
An example where this pull request provides a clear advantage is in handling mixed prefill/decode batches that include a few sequences with long decode phases. In the current upstream implementation, the entire batch is processed using the 2D kernel. In contrast, this PR (when running in FULL CUDA Graph mode) optimizes the workflow by processing the prefill operations with the 2D kernel and the decode operations with the 3D kernel. To illustrate this, I created the following script, which generates one sequence with a long prompt along with multiple short sequences. By setting […], the results are:
These results demonstrate that this PR clearly outperforms the current upstream version in this scenario, achieving nearly a 5x improvement for a 64K prompt size. The performance gain in practical workloads naturally depends on how frequently such situations occur.
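For context, a workload along these lines could be reproduced roughly as follows (a hypothetical sketch using the offline `vllm.LLM` API, not the author's actual script; the compilation-config flag and the prompt/output sizes are assumptions):

```python
import os

# Select the Triton attention backend before importing vLLM.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config={"cudagraph_mode": "FULL"},  # assumed way to enable FULL mode
)

long_prompt = "word " * 16000       # one sequence with a long prompt
short_prompts = ["Hello"] * 32      # many short sequences

# Long decode phases for the short sequences keep decode work in the batch
# while the long prompt is still being prefilled.
params = [SamplingParams(max_tokens=4)] + [SamplingParams(max_tokens=512)] * 32

outputs = llm.generate([long_prompt] + short_prompts, params)
```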
Signed-off-by: jvlunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Signed-off-by: Jan van Lunteren <[email protected]>
Purpose
This pull request introduces separate Triton attention kernel launches for prefill and decode operations within mixed prefill/decode batches when the CUDA Graph mode is set to FULL. Splitting the launches enables tuning the parameters for prefill and decode independently, improving flexibility and potential performance in CUDA graph mode.
When the CUDA Graph mode is not FULL, the existing approach is preserved: a single attention kernel launch for the entire batch. This avoids additional kernel launch overhead, which would not be (partially) compensated by CUDA Graphs in these modes.
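In effect, the dispatch described above boils down to a decision like the following (a simplified, illustrative model, not the actual control flow in unified_attention.py):

```python
# Illustrative only: models which launches are issued for a given batch.
def plan_launches(cudagraph_mode_full: bool,
                  num_decodes: int,
                  num_prefills: int) -> list[str]:
    split_launch = cudagraph_mode_full and num_decodes > 0 and num_prefills > 0
    if split_launch:
        # Decode and prefill portions get their own launches, so tile sizes
        # (and, in follow-up PRs, kernel choice) can be tuned independently.
        return ["decode-tuned launch (first num_decodes requests)",
                "prefill-tuned launch (remaining requests)"]
    # Other CUDA graph modes keep the single unified launch.
    return ["single unified launch (whole batch)"]


print(plan_launches(cudagraph_mode_full=True, num_decodes=3, num_prefills=2))
```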
The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.
Performance
The following results were obtained for `meta-llama/Llama-3.1-8B-Instruct` on an NVIDIA H100 GPU, by running

```bash
$ VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm bench latency \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --input-len <input-length> --output-len 4 \
    --batch-size 1
```

both for the default CUDA Graph mode and for FULL CUDA Graph mode.
Results are shown in the following graph. The input (prompt) length (in tokens) was varied in these experiments across the following values: 500, 1000, 1500, 2000, 4000, 8000, and 16000. The numbers of warmup and measurement iterations were left at the default values of 10 and 30, respectively.
As illustrated in the graph above, this PR improves the performance of the Triton Unified Attention Kernel by approximately 2 times for a batch size of 1 and an input length of 16000 tokens.
Further performance improvements are discussed below: #29020 (comment)
Test Plan
The unit test `./tests/kernels/attention/test_triton_unified_attention.py` has been updated to reflect the new behavior. It now covers both scenarios for mixed batches: a single kernel launch and separate kernel launches.
Other test suites, such as lm_eval, remain fully compatible and require no modifications.
Test Result
Unit test results for the updated Triton unified attention kernel (this PR):

lm_eval results for the updated Triton unified attention kernel (this PR) with a single kernel launch for mixed batches:

```bash
VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'

lm_eval --model local-completions \
    --model_args base_url=http://localhost:8000/v1/completions,tokenizer=meta-llama/Llama-3.1-8B-Instruct \
    --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
```

|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.798|± |0.0180|
| | |strict-match | 5|exact_match|↑ |0.786|± |0.0184|

lm_eval results for the updated Triton unified attention kernel (this PR) with separate kernel launches for mixed batches:

```bash
VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --compilation-config '{"cudagraph_mode": "FULL"}'

lm_eval --model local-completions \
    --model_args base_url=http://localhost:8000/v1/completions,tokenizer=meta-llama/Llama-3.1-8B-Instruct \
    --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
```

|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.794|± |0.0181|
| | |strict-match | 5|exact_match|↑ |0.780|± |0.0185|

This yields similar lm_eval results as FlashAttention.

@tdoublep @bringlein