
Conversation

@gnovack
Contributor

@gnovack gnovack commented Nov 23, 2025

Purpose

This PR includes the following changes to the MoE LoRA Align kernel:

  • Adds a global memory variant of moe_lora_align_sum_kernel (following the moe_align_block_size_kernel implementation). Previously only the shared memory variant was implemented, which led to worse performance for models with a larger num_experts.
  • Uses two CUDA streams to execute moe_lora_align_sum_kernel and moe_align_block_size_kernel in parallel (see the sketch after this list).
  • Refactors moe_lora_align_sum_kernel and moe_align_block_size_kernel to reduce duplicated logic between the LoRA and non-LoRA cases (e.g. moe_align_block_size_kernel and moe_lora_align_block_size_kernel now call a common _moe_align_block_size function that supports both cases).
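
As a rough sketch of the two-stream idea (a standalone illustration only; kernel_a and kernel_b below are placeholders for moe_align_block_size_kernel and moe_lora_align_sum_kernel, and the stream and event names are made up rather than taken from this PR):

  #include <cuda_runtime.h>
  #include <cstdio>

  __global__ void kernel_a(int* out, int n) {   // stand-in for moe_align_block_size_kernel
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = i;
  }

  __global__ void kernel_b(int* out, int n) {   // stand-in for moe_lora_align_sum_kernel
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 2 * i;
  }

  int main() {
    const int n = 1 << 20;
    int *a, *b;
    cudaMalloc(&a, n * sizeof(int));
    cudaMalloc(&b, n * sizeof(int));

    // A second stream lets the two alignment kernels run concurrently.
    cudaStream_t main_stream, lora_stream;
    cudaStreamCreate(&main_stream);
    cudaStreamCreate(&lora_stream);

    kernel_a<<<(n + 255) / 256, 256, 0, main_stream>>>(a, n);
    kernel_b<<<(n + 255) / 256, 256, 0, lora_stream>>>(b, n);

    // Downstream work that consumes both outputs must be ordered after the side stream.
    cudaEvent_t lora_done;
    cudaEventCreate(&lora_done);
    cudaEventRecord(lora_done, lora_stream);
    cudaStreamWaitEvent(main_stream, lora_done, 0);

    cudaStreamSynchronize(main_stream);
    printf("both alignment kernels finished\n");

    cudaEventDestroy(lora_done);
    cudaStreamDestroy(main_stream);
    cudaStreamDestroy(lora_stream);
    cudaFree(a);
    cudaFree(b);
    return 0;
  }

The key detail is the event-based synchronization: whatever later consumes both alignment results must wait for the side stream, otherwise the overlap would introduce a race.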

FIX #30026

Test Plan

  • Ran existing LoRA and moe_align_sum test cases

Test Result

Benchmark results with gpt-oss-120b before vs. after this change:

Before:

============ Serving Benchmark Result ============
Successful requests:                     400       
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  500.09    
Total input tokens:                      641602    
Total generated tokens:                  240183    
Request throughput (req/s):              0.80      
Output token throughput (tok/s):         480.28    
Peak output token throughput (tok/s):    576.00    
Peak concurrent requests:                13.00     
Total Token throughput (tok/s):          1763.27   
---------------Time to First Token----------------
Mean TTFT (ms):                          155.95    
Median TTFT (ms):                        142.50    
P99 TTFT (ms):                           582.98    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          16.38     
Median TPOT (ms):                        16.42     
P99 TPOT (ms):                           17.02     
---------------Inter-token Latency----------------
Mean ITL (ms):                           16.38     
Median ITL (ms):                         15.15     
P99 ITL (ms):                            121.57    
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     400       
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  480.35    
Total input tokens:                      641602    
Total generated tokens:                  240183    
Request throughput (req/s):              0.83      
Output token throughput (tok/s):         500.01    
Peak output token throughput (tok/s):    648.00    
Peak concurrent requests:                13.00     
Total Token throughput (tok/s):          1835.70   
---------------Time to First Token----------------
Mean TTFT (ms):                          166.84    
Median TTFT (ms):                        149.03    
P99 TTFT (ms):                           631.58    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          15.71     
Median TPOT (ms):                        15.73     
P99 TPOT (ms):                           16.71     
---------------Inter-token Latency----------------
Mean ITL (ms):                           15.71     
Median ITL (ms):                         14.38     
P99 ITL (ms):                            129.20    
==================================================


@mergify

mergify bot commented Nov 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gnovack.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 23, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several significant improvements to the MoE LoRA alignment kernels. The refactoring to unify LoRA and non-LoRA logic by using common __device__ functions is a great step towards reducing code duplication and improving maintainability. The introduction of a global memory variant for moe_lora_align_sum_kernel and the use of a separate CUDA stream for parallel execution are solid performance optimizations. However, I've found a critical issue in csrc/moe/moe_align_sum_kernels.cu where a data type mismatch for token_lora_mapping could lead to incorrect memory access and undefined behavior. This needs to be addressed.
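
As a standalone illustration of why such a mismatch is dangerous (this is not the PR's code; it only shows a 64-bit buffer being read through a 32-bit pointer):

  #include <cstdint>
  #include <cstdio>

  int main() {
    // Suppose a mapping tensor holds 64-bit values...
    const int64_t mapping64[4] = {0, 1, 2, 3};
    // ...but is read through a 32-bit pointer, as a mismatched kernel argument would be.
    const int32_t* as32 = reinterpret_cast<const int32_t*>(mapping64);
    // On a little-endian machine as32[1] is the upper half of mapping64[0], so every
    // element after the first is read from the wrong offset (and the aliasing itself
    // is undefined behavior).
    printf("mapping64[1] = %lld, as32[1] = %d\n", (long long)mapping64[1], as32[1]);
    return 0;
  }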


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@ApostaC
Collaborator

ApostaC commented Nov 24, 2025

cc @tylertitsworth

@jeejeelee
Collaborator

Most tests in test_moe_lora_align_sum.py are now failing.

@gnovack gnovack force-pushed the lora-align-refactor branch from e0cce68 to ddc47bf Compare December 3, 2025 18:38
@mergify mergify bot removed the needs-rebase label Dec 3, 2025
@gnovack
Contributor Author

gnovack commented Dec 3, 2025

Most tests in test_moe_lora_align_sum.py are now failing.

These should be fixed now

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
Collaborator


This check should verify that num_experts is less than or equal to 1024.
We also need to add larger num_experts values in https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py#L35
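
For illustration, the suggested check could look roughly like the following (the error-message wording is a placeholder, not proposed text):

  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_experts <= 1024,
              "moe_lora_align_block_size supports at most 1024 experts, got ",
              num_experts);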

Contributor Author


Good catch. I just updated this check and added test cases for larger num_experts.

@jeejeelee jeejeelee added the ready label (ONLY add when PR is ready to merge/full CI is needed) Dec 4, 2025
@jeejeelee
Collaborator

LGTM, @yewentao256 could you please take a look?

@jeejeelee
Collaborator

@gnovack All LoRA tests are failing

Member

@yewentao256 yewentao256 left a comment


Thanks for the work! Please fix the unit tests; I think the failures are all related.

@gnovack
Contributor Author

gnovack commented Dec 5, 2025

Thanks for the work! Please fix the unit tests; I think the failures are all related.

No problem! I think I've found the issue that is causing these test failures, so I should have a fix out in the next few hours.

@mergify

mergify bot commented Dec 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gnovack.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the new-model, performance, qwen, gpt-oss, and nvidia labels Dec 8, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Dec 8, 2025
@mergify mergify bot added the kv-connector label Dec 8, 2025
@gnovack gnovack force-pushed the lora-align-refactor branch from 0b5e995 to 8d847ab Compare December 8, 2025 23:03
@mergify mergify bot removed the tpu and needs-rebase labels Dec 8, 2025
Collaborator

@jeejeelee jeejeelee left a comment


Thank you for your contribution and patience.

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Dec 9, 2025
@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 9, 2025
@jeejeelee jeejeelee merged commit ea657f2 into vllm-project:main Dec 9, 2025
95 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 9, 2025
mayoohee pushed a commit to mayoohee/vllm that referenced this pull request Dec 9, 2025
