
Conversation

@jvlunteren (Contributor) commented Nov 19, 2025

Purpose

This pull request introduces separate Triton attention kernel launches for the prefill and decode operations within mixed prefill/decode batches when the CUDA Graph mode is set to FULL. Splitting the launches allows the kernel parameters for prefill and decode to be tuned independently, improving flexibility and opening up potential performance gains in CUDA Graph mode.

When the CUDA Graph mode is not FULL, the existing approach is preserved: a single attention kernel launch for the entire batch. This avoids additional kernel launch overhead that would not be offset, even partially, by CUDA Graphs in these modes.

The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.
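As a mental model of the dispatch described above, the standalone sketch below shows the one-launch-vs-two-launch decision. It is illustrative only, not the PR's actual code; `Batch`, `run_attention`, and the printed launch names are placeholders.

```python
from dataclasses import dataclass


@dataclass
class Batch:
    num_prefill_tokens: int  # query tokens belonging to prefill requests
    num_decode_tokens: int   # query tokens belonging to decode requests


def run_attention(batch: Batch, cudagraph_mode: str) -> list[str]:
    """Schematic of the launch strategy: split only in FULL CUDA Graph mode."""
    split_launch = (
        cudagraph_mode == "FULL"
        and batch.num_prefill_tokens > 0
        and batch.num_decode_tokens > 0
    )
    if split_launch:
        # Prefill and decode get independent launches, so block/tile sizes
        # and other kernel parameters can be tuned per phase.
        return [
            f"prefill_kernel({batch.num_prefill_tokens} tokens)",
            f"decode_kernel({batch.num_decode_tokens} tokens)",
        ]
    # Other modes keep the single unified launch to avoid extra launch
    # overhead that CUDA Graphs would not (even partially) hide.
    total = batch.num_prefill_tokens + batch.num_decode_tokens
    return [f"unified_kernel({total} tokens)"]


print(run_attention(Batch(512, 4), "FULL"))                # two launches
print(run_attention(Batch(512, 4), "FULL_AND_PIECEWISE"))  # one launch
```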

Performance

The following results were obtained for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU, by running

$ VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm bench latency \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --input-len <input-length> --output-len 4 \
    --batch-size 1

for

  1. the current upstream version of the Triton attention kernel, and
  2. the updated Triton attention kernel (this PR)

both for the default CUDA Graph mode and for FULL CUDA Graph mode.

Results are shown in the following graph. The input (prompt) length was varied across 500, 1000, 1500, 2000, 4000, 8000, and 16000 tokens. The numbers of warmup and measurement iterations were left at their default values of 10 and 30, respectively.

[Graph: latency of the upstream and updated Triton attention kernels versus input length, for the default and FULL CUDA Graph modes]

As illustrated in the graph above, this PR improves the performance of the Triton Unified Attention Kernel by approximately 2 times for a batch size of 1 and an input length of 16000 tokens.

Further performance improvements are discussed below: #29020 (comment)

Test Plan

The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to reflect the new behavior. It now covers both scenarios for mixed batches:

  • Single kernel launch (current approach)
  • Separate kernel launches for prefill and decode (new approach for FULL CUDA Graph mode)

Other test suites, such as lm_eval, remain fully compatible and require no modifications.
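As a rough illustration of how a single parametrized test can exercise both paths (the real file tests the actual Triton kernels; `attention_ref` and `attention_impl` below are hypothetical stand-ins):

```python
import pytest
import torch


def attention_ref(q, k, v):
    # naive reference: softmax(q @ k^T / sqrt(d)) @ v
    scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v


def attention_impl(q, k, v, split_launch: bool):
    # stand-in for the kernel under test; the real test runs the Triton
    # unified attention path with or without the separate decode launch
    return attention_ref(q, k, v)


@pytest.mark.parametrize("split_launch", [False, True])
def test_mixed_batch(split_launch):
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 8, 64) for _ in range(3))
    out = attention_impl(q, k, v, split_launch=split_launch)
    torch.testing.assert_close(out, attention_ref(q, k, v), rtol=1e-4, atol=1e-4)
```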

Test Result

Unit test results for the updated Triton unified attention kernel (this PR):

python3 -m pytest tests/kernels/attention/test_triton_unified_attention.py

================================================ 1536 passed in 229.01s (0:03:49) ================================================


lm_eval results for the updated Triton unified attention kernel (this PR) with a single kernel launch for mixed batches:

VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'
lm_eval --model local-completions --model_args base_url=http://localhost:8000/v1/completions,tokenizer=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.786|±  |0.0184|

lm_eval results for the updated Triton unified attention kernel (this PR) with separate kernel launches for mixed batches:

VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"cudagraph_mode": "FULL"}'
lm_eval --model local-completions --model_args base_url=http://localhost:8000/v1/completions,tokenizer=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.780|±  |0.0185|

Both configurations yield lm_eval results similar to those obtained with FlashAttention:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.772|±  |0.0188|

@tdoublep @bringlein

@gemini-code-assist (bot) left a comment:


Code Review

This pull request effectively separates Triton attention kernel launches for prefill and decode operations in FULL CUDA Graph mode, laying a good foundation for future performance tuning. The changes are generally well-structured. However, I've identified a critical bug in the grid size calculation for the prefill kernel, which could lead to performance issues or errors. Additionally, there's significant code duplication in the kernel launch logic that should be refactored to improve maintainability. My detailed comments are below.

Comment on lines 818 to 822
total_num_q_blocks = (
    (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs
    if split_launch
    else q.shape[0] // BLOCK_Q + num_seqs
)

critical

The calculation for total_num_q_blocks for the prefill part of a mixed batch appears to be incorrect when split_launch is true. The number of sequences added to the upper bound should be the number of prefill sequences (num_seqs - num_decodes), not the total number of sequences (num_seqs). This overestimates the number of required blocks by num_decodes, which can lead to wasted computation and potentially exceed grid size limits.

Suggested change
total_num_q_blocks = (
    (q.shape[0] - num_decodes) // BLOCK_Q + num_seqs
    if split_launch
    else q.shape[0] // BLOCK_Q + num_seqs
)
total_num_q_blocks = (
    (q.shape[0] - num_decodes) // BLOCK_Q + (num_seqs - num_decodes)
    if split_launch
    else q.shape[0] // BLOCK_Q + num_seqs
)

Reply from @jvlunteren (author):

This is addressed by a new modification.

Comment on lines 878 to 990
if num_decodes > 0 and split_launch:
    # batch contains decodes that are not processed in unified fashion
    kernel_unified_attention_2d[
        (
            num_decodes,
            num_kv_heads,
        )
    ](
        output_ptr=out,
        query_ptr=q,
        key_cache_ptr=k,
        value_cache_ptr=v,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        k_scale=k_descale,
        v_scale=v_descale,
        out_scale=1 / output_scale if output_scale is not None else 1.0,
        softcap=softcap,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
        BLOCK_SIZE=block_size,
        TILE_SIZE=TILE_SIZE_2D_DECODE,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        USE_ALIBI_SLOPES=use_alibi_slopes,
        USE_QQ_BIAS=use_qq_bias,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(1 + window_size[0]),
        stride_k_cache_0=k.stride(0),
        stride_k_cache_1=k.stride(1),
        stride_k_cache_2=k.stride(2),
        stride_k_cache_3=k.stride(3),
        stride_v_cache_0=v.stride(0),
        stride_v_cache_1=v.stride(1),
        stride_v_cache_2=v.stride(2),
        stride_v_cache_3=v.stride(3),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_decodes,
        BLOCK_M=BLOCK_M,
        q_block_offset=0,
        decode_only=True,
        USE_FP8=output_scale is not None,
    )
else:
    # for initial version, NUM_SEGMENTS = 16 is chosen as a default
    # value that showed good performance in tests
    NUM_SEGMENTS = 16

    segm_output = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        triton.next_power_of_2(head_size),
        dtype=torch.float32,
        device=q.device,
    )
    segm_max = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        dtype=torch.float32,
        device=q.device,
    )
    segm_expsum = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        dtype=torch.float32,
        device=q.device,
    )
    # decode-only batch

    if num_decodes > seq_threshold_3D:
        # use 2D kernel for decode-only batch
        kernel_unified_attention_2d[
            (
                num_decodes,
                num_kv_heads,
            )
        ](
            output_ptr=out,
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
            sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
            qq_bias_ptr=qq_bias,
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
            out_scale=1 / output_scale if output_scale is not None else 1.0,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            block_table_stride=block_table.stride(0),
            query_stride_0=q.stride(0),
            query_stride_1=q.stride(1),
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
            BLOCK_SIZE=block_size,
            TILE_SIZE=TILE_SIZE_2D_DECODE,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
            stride_k_cache_2=k.stride(2),
            stride_k_cache_3=k.stride(3),
            stride_v_cache_0=v.stride(0),
            stride_v_cache_1=v.stride(1),
            stride_v_cache_2=v.stride(2),
            stride_v_cache_3=v.stride(3),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            num_seqs=num_decodes,
            BLOCK_M=BLOCK_M,
            q_block_offset=0,
            decode_only=True,
            USE_FP8=output_scale is not None,
        )

high

There is significant code duplication for launching the kernel_unified_attention_2d for decode operations. The kernel call block is repeated for the split_launch case in a mixed batch (lines 878-932) and for the 2D path in a decode-only batch (lines 936-990). This makes the code harder to read and maintain, as any changes would need to be synchronized in multiple places.

I suggest refactoring this duplicated logic into a helper function. This would improve code clarity and make future modifications easier. For example, you could create a helper function like _launch_decode_2d_kernel(...) that encapsulates the kernel launch and its parameters.
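For illustration only, the kind of helper the review suggests could look like the sketch below. The kernel and its argument list are placeholders: a dummy object and two keyword arguments stand in for `kernel_unified_attention_2d` and its full parameter set.

```python
def _launch_decode_2d_kernel(kernel, grid, **common_kwargs):
    """Single place that owns the decode 2D launch, so the mixed-batch split
    launch and the decode-only 2D launch cannot drift apart."""
    kernel[grid](**common_kwargs)


# Dummy stand-in for the Triton kernel, just to make the sketch runnable.
class _DummyKernel:
    def __getitem__(self, grid):
        def _launch(**kwargs):
            print(f"launch grid={grid} with {sorted(kwargs)}")
        return _launch


kernel_unified_attention_2d = _DummyKernel()
decode_kwargs = dict(num_seqs=4, decode_only=True)  # subset of the real arguments

# Mixed batch with split launch, and decode-only 2D path, share one call site:
_launch_decode_2d_kernel(kernel_unified_attention_2d, (4, 8), **decode_kwargs)
_launch_decode_2d_kernel(kernel_unified_attention_2d, (4, 8), **decode_kwargs)
```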

Reply from @jvlunteren (author):

This is addressed by a new modification.

@chatgpt-codex-connector (bot) left a comment:


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 844 to 874
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
q_block_offset=num_decodes if split_launch else 0,
decode_only=False,


P1: Offset prefills by decode count instead of decode blocks

When launching the second 2D kernel for the prefill portion, the code offsets the program id by num_decodes and subtracts the same value from the q‑block estimate. num_decodes counts requests, not q blocks. This only works if every decode request occupies exactly one BLOCK_Q block (e.g. one token). For speculative decoding or any configuration where decodes can contain multiple tokens, the first kernel will start in the middle of the decode blocks and the second kernel only launches num_decodes blocks, so some decode blocks are skipped or processed with prefill parameters. In FULL CUDA graph mode this yields incorrect outputs for multi‑token decodes. The offset should be based on the number of decode q blocks (∑⌈len_i/BLOCK_Q⌉) rather than the number of decode requests.
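To make the request-count vs. q-block distinction concrete, here is a small illustration with hypothetical lengths (not code from the PR):

```python
import math

BLOCK_Q = 16  # illustrative q-block size
decode_query_lens = [1, 1, 4, 20]  # speculative decodes can span multiple tokens

num_decodes = len(decode_query_lens)  # 4 decode requests
num_decode_q_blocks = sum(math.ceil(n / BLOCK_Q) for n in decode_query_lens)  # 1+1+1+2 = 5

# Offsetting the prefill launch by num_decodes (4) instead of
# num_decode_q_blocks (5) would leave one decode q-block uncovered.
print(num_decodes, num_decode_q_blocks)
```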


Reply from @jvlunteren (author):

That is correct. Speculative decoding is not yet covered by the optimizations introduced in this pull request. To handle this properly, reorder_batch_threshold has been set to 1 token, which ensures speculative decode steps are processed by the 2D kernel in the same way as prefills. Furthermore, FULL CUDA Graph mode in combination with speculative decoding is not supported by this pull request and will be blocked with an assert statement. A follow-up pull request will extend these optimizations to also support speculative decoding.

Reply from @jvlunteren (author):

The above-mentioned assert statement has been replaced by an override of the get_cudagraph_support() class method in the TritonAttentionMetadataBuilder class. This override adjusts the _cudagraph_support value from ALWAYS to UNIFORM_SINGLE_TOKEN_DECODE when the combination of FULL CUDA Graph mode and speculative decoding is detected.
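Schematically, such an override could look like the sketch below. This is a simplified stand-in, not the actual vLLM API: the enum values and the method signature are assumed for illustration.

```python
from enum import Enum


class AttentionCGSupport(Enum):
    # simplified stand-in for vLLM's cudagraph-support levels
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    ALWAYS = 2


class TritonAttentionMetadataBuilderSketch:
    _cudagraph_support = AttentionCGSupport.ALWAYS

    @classmethod
    def get_cudagraph_support(cls, uses_spec_decode: bool) -> AttentionCGSupport:
        # With speculative decoding, decode requests may span multiple query
        # tokens, which the split decode launch does not handle yet; restrict
        # FULL CUDA Graphs to uniform single-token decode batches in that case.
        if uses_spec_decode:
            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
        return cls._cudagraph_support


print(TritonAttentionMetadataBuilderSketch.get_cudagraph_support(uses_spec_decode=True))
```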

@LucasWilkinson (Collaborator) left a comment:

> The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.

IMO it's best practice to hold off until we can see the benefits; we can land this as a chain once we know the final PR is delivering perf; this gives reviewers full visibility into the planned changes and the ultimate benefit.

@jvlunteren (author) replied:

> The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.
>
> IMO it's best practice to hold off until we can see the benefits; we can land this as a chain once we know the final PR is delivering perf; this gives reviewers full visibility into the planned changes and the ultimate benefit.

That is a fair point. I will see if I can include some experimental results that demonstrate performance improvements without requiring additional PRs.

@jvlunteren (author) replied:

> The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.
>
> IMO it's best practice to hold off until we can see the benefits; we can land this as a chain once we know the final PR is delivering perf; this gives reviewers full visibility into the planned changes and the ultimate benefit.

@LucasWilkinson

An example where this pull request provides a clear advantage is in handling mixed prefill/decode batches that include a few sequences with long decode phases. In the current upstream implementation, the entire batch is processed using the 2D kernel. In contrast, this PR (when running in FULL CUDA Graph mode) optimizes the workflow by processing the prefill operations with the 2D kernel and the decode operations with the 3D kernel.

To illustrate this, I created the following script that generates one sequence with a long prompt along with multiple short sequences. By setting max_num_seqs to 2 and configuring the number of output tokens for the long sequence to match the number of short sequences, most batches scheduled by vLLM will include one decode operating on a large context in the KV cache combined with one prefill. The script then measures the total latency.

import asyncio
import time
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.config import CompilationConfig

async def generate_request(engine, prompt, max_tokens, request_id):
    sampling_params = SamplingParams(max_tokens=max_tokens)
    async for output in engine.generate(prompt, sampling_params, request_id):
        if output.finished:
            return

async def main():
    warmup_iters = 3
    measure_iters = 5
    long_prompt_length = 65536
    num_short_seqs = 100

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-3.1-8B-Instruct",
        max_num_seqs=2,
        compilation_config={"cudagraph_mode": "FULL"}
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
       
    # Create long prompt comprised of valid non-special tokens from the vocabulary
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    vocab_size = tokenizer.vocab_size
    special_ids = set(tokenizer.all_special_ids)
    valid_ids = [i for i in range(vocab_size) if i not in special_ids]
    long_prompt = tokenizer.decode(valid_ids[:long_prompt_length])

    # Warmup
    for i in range(warmup_iters):
        warmup_tasks = [generate_request(engine, long_prompt, num_short_seqs, f"warmup_long_{i}")]
        for j in range(num_short_seqs):
            warmup_tasks.append(generate_request(engine, f"Short sequence {j}", 1, f"warmup_short_{i}_{j}"))
        await asyncio.gather(*warmup_tasks)

    # Measurement
    latencies = []
    for i in range(measure_iters):
        tasks = [generate_request(engine, long_prompt, num_short_seqs, f"long_{i}")]
        for j in range(num_short_seqs):
            tasks.append(generate_request(engine, f"Short sequence {j}", 1, f"short_{i}_{j}"))
        start_time = time.perf_counter()
        results = await asyncio.gather(*tasks)
        end_time = time.perf_counter()
        latencies.append(end_time - start_time)
    print(f"latencies={latencies}")

asyncio.run(main())

The results are:

| prompt size [tokens] | current (FULL_AND_PIECEWISE) | current (FULL) | this PR (FULL) |
|---|---:|---:|---:|
| 16K | 3.1 sec | 3.1 sec | 1.5 sec |
| 32K | 5.4 sec | 5.3 sec | 1.6 sec |
| 64K | 10.1 sec | 10.1 sec | 2.1 sec |

These results demonstrate that this PR clearly outperforms the current upstream version in this scenario, achieving nearly a 5x improvement for a 64K prompt size. The performance gain in practical workloads naturally depends on how frequently such situations occur.
