-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[Optimization] Add Fused Triton Kernel for GPT-OSS Router #29237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ijpq
wants to merge
8
commits into
vllm-project:main
Choose a base branch
from
ijpq:ijpq/fused_router_gptoss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+791
−0
Open
Changes from 1 commit
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
fca484b
[Optimization] Add Fused Triton Kernel for topk+softmax
ijpq 66e6711
[Optimization]: Optimize Fused Triton Kernel for topk+softmax
ijpq 76d472d
[Fix]: Tweak a few little things in triton kernel
ijpq 0c7698e
[Optimization]: add specialization for small topk
ijpq f0301c9
[Optimization]: unroll for each program along M
ijpq eacf1cd
[Optimization]: add autotune and fix little things
ijpq 0f96a01
[Optimization]: further specialize for M,N,topk
ijpq d2fe146
[Optimization]: add bitonic sort within warp
ijpq File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("M", [1, 32, 128, 2048]) | ||
| @pytest.mark.parametrize("N", [32, 65, 128]) | ||
| @pytest.mark.parametrize("topk", [1, 2, 3, 4, 5]) | ||
| def test_fused_router(M, N, topk): | ||
| device = "cuda" | ||
| torch.manual_seed(0) | ||
|
|
||
| logits = torch.randn((M, N), device=device, dtype=torch.float32) | ||
|
|
||
| ref_vals, ref_indices = torch.topk(logits, topk, dim=-1) | ||
| ref_probs = torch.softmax(ref_vals, dim=-1) | ||
|
|
||
| tri_probs, tri_indices = fused_topk_softmax(logits, topk, renormalize=True) | ||
|
|
||
| torch.testing.assert_close(tri_indices.long(), ref_indices) | ||
| torch.testing.assert_close(tri_probs, ref_probs, atol=1e-4, rtol=1e-4) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_fused_router(128, 32, 2) | ||
| print("Test Passed!") | ||
ijpq marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import ( | ||
| gpt_oss_custom_routing_function, | ||
| ) | ||
| from vllm.model_executor.layers.fused_moe.layer import FusedMoE | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens", [10, 128, 1024]) | ||
| @pytest.mark.parametrize("num_experts", [32, 65, 128]) | ||
| @pytest.mark.parametrize("topk", [1, 2, 3, 4, 5]) | ||
| def test_routing_consistency(num_tokens, num_experts, topk): | ||
ijpq marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| torch.manual_seed(42) | ||
| device = torch.device("cuda") | ||
|
|
||
| hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16) | ||
| router_logits = torch.randn( | ||
| num_tokens, num_experts, device=device, dtype=torch.float32 | ||
| ) | ||
|
|
||
| ref_weights, ref_ids, _ = FusedMoE.select_experts( | ||
| hidden_states=hidden_states, | ||
| router_logits=router_logits, | ||
| top_k=topk, | ||
| use_grouped_topk=False, | ||
| renormalize=True, | ||
| custom_routing_function=None, | ||
| ) | ||
|
|
||
| triton_weights, triton_ids, _ = FusedMoE.select_experts( | ||
| hidden_states=hidden_states, | ||
| router_logits=router_logits, | ||
| top_k=topk, | ||
| use_grouped_topk=False, | ||
| renormalize=True, | ||
| custom_routing_function=gpt_oss_custom_routing_function, | ||
| ) | ||
|
|
||
| print(f"\nTesting M={num_tokens}, E={num_experts}, K={topk}") | ||
|
|
||
| torch.testing.assert_close( | ||
| triton_ids, | ||
| ref_ids, | ||
| msg="Expert indices mismatch between Native and Triton implementation", | ||
| ) | ||
|
|
||
| torch.testing.assert_close( | ||
| triton_weights, | ||
| ref_weights, | ||
| atol=1e-3, | ||
| rtol=1e-3, | ||
| msg="Expert weights mismatch between Native and Triton implementation", | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_routing_consistency(128, 32, 2) | ||
| print("Consistency Test Passed!") | ||
112 changes: 112 additions & 0 deletions
112
vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """GPT-OSS MoE router with Triton topk kernel.""" | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.triton_utils import tl, triton | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _topk_softmax_kernel( | ||
| logits_ptr, | ||
| weights_ptr, | ||
| indices_ptr, | ||
| M, | ||
| N, | ||
| topk: tl.constexpr, | ||
| topk_padded: tl.constexpr, | ||
| stride_lm, | ||
| stride_ln, | ||
| stride_wm, | ||
| stride_wk, | ||
| stride_im, | ||
| stride_ik, | ||
| BLOCK_N: tl.constexpr, | ||
| RENORM: tl.constexpr, | ||
| ): | ||
| token_idx = tl.program_id(0) | ||
|
|
||
| offs = tl.arange(0, BLOCK_N) | ||
ijpq marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| mask = offs < N | ||
| logit_offs = logits_ptr + token_idx * stride_lm + offs * stride_ln | ||
| logits = tl.load(logit_offs, mask=mask, other=float("-inf")) | ||
|
|
||
| topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf") | ||
| topk_idxs = tl.zeros([topk_padded], dtype=tl.int32) | ||
|
|
||
| working_logits = logits | ||
|
|
||
| for k in range(topk): | ||
| cur_max = tl.max(working_logits, axis=0) | ||
| cur_idx = tl.argmax(working_logits, axis=0) | ||
|
|
||
| k_mask = tl.arange(0, topk_padded) == k | ||
| topk_vals = tl.where(k_mask, cur_max, topk_vals) | ||
| topk_idxs = tl.where(k_mask, cur_idx, topk_idxs) | ||
|
|
||
| mask_selected = offs == cur_idx | ||
| working_logits = tl.where(mask_selected, float("-inf"), working_logits) | ||
|
|
||
| if RENORM: | ||
| max_val = tl.max(topk_vals, axis=0) | ||
| exp_vals = tl.exp(topk_vals - max_val) | ||
| sum_exp = tl.sum(exp_vals, axis=0) | ||
| topk_vals = exp_vals / sum_exp | ||
|
|
||
| offs_k = tl.arange(0, topk_padded) | ||
|
|
||
| store_mask = offs_k < topk | ||
|
|
||
| weight_ptrs = weights_ptr + token_idx * stride_wm + offs_k * stride_wk | ||
| tl.store(weight_ptrs, topk_vals, mask=store_mask) | ||
|
|
||
| index_ptrs = indices_ptr + token_idx * stride_im + offs_k * stride_ik | ||
| tl.store(index_ptrs, topk_idxs, mask=store_mask) | ||
|
|
||
|
|
||
| def fused_topk_softmax( | ||
| router_logits: torch.Tensor, | ||
| topk: int, | ||
| renormalize: bool = True, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| M, N = router_logits.shape | ||
|
|
||
| weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32) | ||
| indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32) | ||
|
|
||
| BLOCK_N = triton.next_power_of_2(N) | ||
|
|
||
| topk_padded = triton.next_power_of_2(topk) | ||
|
|
||
| grid = (M,) | ||
ijpq marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| _topk_softmax_kernel[grid]( | ||
| logits_ptr=router_logits, | ||
| weights_ptr=weights, | ||
| indices_ptr=indices, | ||
| M=M, | ||
| N=N, | ||
| topk=topk, | ||
| topk_padded=topk_padded, | ||
| stride_lm=router_logits.stride(0), | ||
| stride_ln=router_logits.stride(1), | ||
| stride_wm=weights.stride(0), | ||
| stride_wk=weights.stride(1), | ||
| stride_im=indices.stride(0), | ||
| stride_ik=indices.stride(1), | ||
| BLOCK_N=BLOCK_N, | ||
| RENORM=renormalize, | ||
| ) | ||
|
|
||
| return weights, indices | ||
|
|
||
|
|
||
| def gpt_oss_custom_routing_function( | ||
| hidden_states: torch.Tensor, | ||
| gating_output: torch.Tensor, | ||
| topk: int, | ||
| renormalize: bool, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| # only use gating_output to avoid padding issues | ||
| return fused_topk_softmax(gating_output, topk, renormalize) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.