
Conversation


@PatchouliTIS PatchouliTIS commented Nov 21, 2025

Purpose

This PR builds on PR #24799; it implements a GPU version of ngram speculative decoding and makes it compatible with the async scheduler.

Test Plan

  • Async Scheduler + NGram + Qwen3-1.7B
    Test config:
# dataset is CMU-DoG, which is an input-grounded dataset.
python3.12 -u -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8000 \
--max-num-seqs 128 \
--max-model-len 2048 \
--model Qwen/Qwen3-1.7B \
--tensor-parallel-size 1 \
--trust-remote-code \
--dtype bfloat16  \
--enable-chunked-prefill \
--disable-log-requests \
--async-scheduling \
--speculative_config '{"method": "ngram_gpu", "num_speculative_tokens": 3, "prompt_lookup_max": 2,"prompt_lookup_min": 2}'

Test Device: NVIDIA H20

Test Result

Performance

| num_prompts | async_ngram (tok/s) | sync_ngram (tok/s) | speedup |
| --- | --- | --- | --- |
| 2 | 466 | 357 | 30.5% |
| 8 | 1378 | 988 | 39.4% |
| 16 | 2082 | 1726 | 20.6% |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

hl475 and others added 4 commits November 24, 2025 10:58
…rs (vllm-project#29111)

Signed-off-by: Huamin Li <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: PatchouliTaisa <[email protected]>

@ZJY0516 ZJY0516 (Contributor) commented Nov 27, 2025

cc @njhill

# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config)

# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
Collaborator

Why? Can this be fixed?

Author

When I enabled torch.compile for the ngram GPU kernel, the computational graph of the ngram operator would hit a precompiled graph cached for the main model, leading to mismatched compilation results. Therefore I disabled the compile cache here. I tested this locally, and disabling the cache had no impact on performance.
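For readers following the thread, here is a minimal, self-contained sketch of the workaround being described. The helper name and arguments are illustrative; in the PR the same decision feeds the existing `disable_cache` computation shown in the code anchor above.

```python
def should_disable_compile_cache(cache_enabled: bool, using_ngram_gpu: bool) -> bool:
    # Honor the existing opt-outs (CompilationMode.NONE, VLLM_DISABLE_COMPILE_CACHE).
    disable_cache = not cache_enabled
    # Workaround discussed above: the ngram GPU operator's compiled graph can
    # collide with a graph cached for the main model, so force the cache off.
    if using_ngram_gpu:
        disable_cache = True
    return disable_cache
```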

Collaborator

I assume disabling the compile cache would lead to longer startup time? I'm not an expert here but maybe it's possible to add an identifier to the compile cache to avoid extraneous cache hits?

Author

Yes, the startup time increases a little. I tried adding extra input parameters and other member variables to the nn.Module's forward method decorated with @support_torch_compile to achieve cache isolation, but none of them worked; I suspect this is related to the internal implementation of @support_torch_compile in vLLM. As things stand, though, disabling the torch.compile cache only affects performance during the initial startup phase of the inference service.

PatchouliTaisa added 6 commits December 2, 2025 15:49
Signed-off-by: PatchouliTaisa <[email protected]>
@PatchouliTIS PatchouliTIS left a comment

Changed the code according to the review comments.

for i, num_tokens in enumerate(num_accepted_tokens):
    self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None:
Author

This separate processing path for the ngram GPU input avoids a full copy on every step. It performs incremental updates to the GPU buffer based on the previous prev_req_id_to_index and the current self.input_batch.req_id_to_index, which avoids extensive copy operations.
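A rough, self-contained sketch of that incremental-update idea (the tensor, mapping, and argument names below are illustrative, not the PR's actual code):

```python
import torch

def update_ngram_gpu_buffer(
    token_ids_gpu: torch.Tensor,               # (max_num_seqs, max_model_len) persistent GPU buffer
    prev_req_id_to_index: dict[str, int],      # request -> row mapping from the previous step
    req_id_to_index: dict[str, int],           # request -> row mapping for the current step
    new_request_tokens: dict[str, list[int]],  # token ids only for requests not seen before
) -> None:
    """Only rows for new or moved requests are written; requests that kept
    their row are left untouched, so there is no full copy every step."""
    # Snapshot so rows that move do not overwrite each other mid-loop; a real
    # implementation could avoid this full device-side copy (e.g. index_select).
    old_rows = token_ids_gpu.clone()
    for req_id, new_idx in req_id_to_index.items():
        old_idx = prev_req_id_to_index.get(req_id)
        if old_idx is None:
            # New request: one host-to-device copy of its token ids.
            tokens = torch.tensor(new_request_tokens[req_id], dtype=token_ids_gpu.dtype)
            token_ids_gpu[new_idx, : tokens.numel()] = tokens.to(
                token_ids_gpu.device, non_blocking=True
            )
        elif old_idx != new_idx:
            # Existing request that changed rows: device-side move, no H2D copy.
            token_ids_gpu[new_idx] = old_rows[old_idx]
```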


pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.token_ids_gpu_tensor = torch.zeros(
Author

Storing token IDs per request with more sophisticated logic would require significant changes, including to how ngram_gpu consumes its token ID input. For this PR I propose moving this buffer allocation into gpu_model_runner, so the buffer is only allocated when both the async scheduler and ngram_gpu are enabled. Users can also reduce the buffer's VRAM footprint by lowering max_num_seqs and max_model_len. Further optimization can be addressed in a separate PR.
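A minimal sketch of the proposed conditional allocation (function and flag names are assumed for illustration, and the int32 dtype is an assumption, not taken from the diff):

```python
import torch

def maybe_allocate_ngram_gpu_buffer(
    async_scheduling: bool,
    ngram_gpu_enabled: bool,
    max_num_seqs: int,
    max_model_len: int,
    device: torch.device,
) -> torch.Tensor | None:
    """Allocate the (max_num_seqs, max_model_len) token-id buffer only when
    both the async scheduler and the ngram GPU drafter are in use, so other
    configurations pay no extra VRAM cost."""
    if not (async_scheduling and ngram_gpu_enabled):
        return None
    return torch.zeros((max_num_seqs, max_model_len), dtype=torch.int32, device=device)
```

With the launch configuration from the test plan above (--max-num-seqs 128, --max-model-len 2048) and the int32 dtype assumed here, the buffer is 128 × 2048 × 4 bytes ≈ 1 MiB of VRAM, which is why lowering those two settings directly shrinks the footprint.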

all_token_ids = prompt_token_ids + req_state.output_token_ids
num_tokens = len(all_token_ids)
# Copy to GPU tensor
self.input_batch.token_ids_gpu_tensor[idx, :num_tokens].copy_(
Author

I see, I'll fix that in the actual async implementation.

),
non_blocking=True,
)
self.input_batch.num_tokens_no_spec_gpu[idx] = num_tokens
Author

Moved the update of the token_ids_gpu_tensor used for ngram GPU into _update_states, reusing the num_tokens_no_spec values already maintained in input_batch.

@support_torch_compile(
    dynamic_arg_dims={
        "num_tokens_no_spec": 0,
        "token_ids_gpu": [0, 1],
Author

I have removed the second-dim flag for token_ids_gpu; now only the batch_size dimension of the inputs is marked as dynamic. The second dimension is seq_len, and since token_ids_gpu_tensor is a fixed-size buffer, the input always has a fixed size in that dimension.
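For illustration, the shape of the revised decorator described above. The class below is a stand-in with an assumed name and a minimal constructor, not the PR's actual proposer class; the real forward signature and init arguments may differ.

```python
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig

@support_torch_compile(
    dynamic_arg_dims={
        "num_tokens_no_spec": 0,  # batch dimension stays dynamic
        "token_ids_gpu": 0,       # dim 1 (seq_len) is fixed by the preallocated buffer
    }
)
class NgramGpuProposerStub(nn.Module):
    """Illustrative stand-in for the decorated ngram GPU module."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

    def forward(self, num_tokens_no_spec, token_ids_gpu):
        ...
```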

@Neo9061 Neo9061 commented Dec 2, 2025

> Thanks for enabling this feature! Question: will the gain be different at higher batch sizes? And which benchmark datasets did you use?

> Maybe you mean the performance at higher batch sizes? For now there is performance degradation at higher batch sizes, because the scheduler pre-allocates num_spec_tokens draft token positions for the next step even when the draft token IDs are invalid, so there is a lot of redundant model forward computation and sampling at higher batch sizes. This can be solved, and I'm still working on it. The dataset is here: https://github.com/festvox/datasets-CMU_DoG

Yes! I meant higher batch sizes. We observed something similar in #27379 (we also used the Blazedit dataset to show n-gram effectiveness).

Do you by chance have a timeline for when you plan to resolve the higher-batch-size issue for n-gram? For context, I want to use n-gram in the context of hybrid decoding (#24344), and since EAGLE now supports async scheduling, your PR would be very useful for making n-gram compatible with hybrid decoding.
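For a sense of the scale of the pre-allocation overhead described in the quoted reply above, a back-of-the-envelope illustration; the batch size and draft length below are assumed example values, not measurements from this PR:

```python
# If the scheduler reserves num_spec_tokens draft positions per request even
# when the drafts are invalid, the target model processes batch * (1 + k)
# token positions per verification step instead of batch positions.
batch_size, num_spec_tokens = 128, 3
with_spec = batch_size * (1 + num_spec_tokens)  # 512 positions
without_spec = batch_size                       # 128 positions
print(with_spec / without_spec)                 # 4.0x positions, largely wasted when drafts are empty
```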

@PatchouliTIS
Author

> Do you by chance have a timeline for when you plan to resolve the higher-batch-size issue for n-gram? […]

I came up with a new implementation and will push it this week. I ran some benchmarks with it on the Blazedit dataset with Qwen3-8B; here are the results:

async + ngram gpu (bs24): [benchmark screenshot]
origin ngram cpu (bs24): [benchmark screenshot]
baseline (bs24): [benchmark screenshot]
async + ngram gpu (bs96): [benchmark screenshot]
origin ngram cpu (bs96): [benchmark screenshot]
baseline (bs96): [benchmark screenshot]

It appears that the new implementation yields some benefits at large batch sizes.

@PatchouliTIS
Author

During local benchmarking I also noticed that the mean draft-token acceptance rate is approximately 37%–41% on the Blazedit dataset, which may not be enough for a clear performance improvement.
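A rough way to see why ~40% might be marginal, assuming, purely for illustration, an independent per-position acceptance probability (the reported mean rate only approximates this):

```python
# Expected tokens emitted per target-model step with k draft tokens and an
# (assumed i.i.d.) per-token acceptance probability a: 1 + a + a^2 + ... + a^k.
# This ignores drafting and verification overhead, so it is an upper bound.
a, k = 0.40, 3
expected_per_step = 1 + sum(a**i for i in range(1, k + 1))
print(round(expected_per_step, 2))  # ~1.62 tokens/step vs 1.0 without speculation
```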

@ArmageddonKnight

@benchislett @njhill We have addressed the feedback. Could you please kindly review this PR again? Thank you.
