
Commit fca484b

[Optimization] Add Fused Triton Kernel for topk+softmax

Signed-off-by: ijpq <[email protected]>
1 parent d64429b · commit fca484b

File tree: 4 files changed (+210, -0 lines)
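
In total the commit adds a unit test that checks the fused kernel against torch.topk followed by torch.softmax, a consistency test that routes through FusedMoE.select_experts with and without the new hook, the Triton kernel module vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py, and an eight-line change to vllm/model_executor/models/gpt_oss.py that enables the kernel on non-ROCm platforms.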
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax


@pytest.mark.parametrize("M", [1, 32, 128, 2048])
@pytest.mark.parametrize("N", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_fused_router(M, N, topk):
    device = "cuda"
    torch.manual_seed(0)

    logits = torch.randn((M, N), device=device, dtype=torch.float32)

    ref_vals, ref_indices = torch.topk(logits, topk, dim=-1)
    ref_probs = torch.softmax(ref_vals, dim=-1)

    tri_probs, tri_indices = fused_topk_softmax(logits, topk, renormalize=True)

    torch.testing.assert_close(tri_indices.long(), ref_indices)
    torch.testing.assert_close(tri_probs, ref_probs, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
    test_fused_router(128, 32, 2)
    print("Test Passed!")
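
For a quick smoke check outside pytest, fused_topk_softmax can be called directly. A minimal sketch, assuming a CUDA device; the shapes here are illustrative rather than taken from the diff:

import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax

logits = torch.randn(4, 32, device="cuda", dtype=torch.float32)
probs, ids = fused_topk_softmax(logits, topk=2, renormalize=True)
# probs: (4, 2) float32, each row sums to ~1.0 after the top-k softmax
# ids:   (4, 2) int32 expert indices, ordered from largest logit down
print(probs.sum(dim=-1))
print(ids)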
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_routing_consistency(num_tokens, num_experts, topk):
    torch.manual_seed(42)
    device = torch.device("cuda")

    hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
    router_logits = torch.randn(
        num_tokens, num_experts, device=device, dtype=torch.float32
    )

    ref_weights, ref_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=None,
    )

    triton_weights, triton_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=gpt_oss_custom_routing_function,
    )

    print(f"\nTesting M={num_tokens}, E={num_experts}, K={topk}")

    torch.testing.assert_close(
        triton_ids,
        ref_ids,
        msg="Expert indices mismatch between Native and Triton implementation",
    )

    torch.testing.assert_close(
        triton_weights,
        ref_weights,
        atol=1e-3,
        rtol=1e-3,
        msg="Expert weights mismatch between Native and Triton implementation",
    )


if __name__ == "__main__":
    test_routing_consistency(128, 32, 2)
    print("Consistency Test Passed!")
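
gpt_oss_custom_routing_function matches the custom_routing_function signature that FusedMoE.select_experts accepts: it takes the hidden states and the gating logits plus topk and renormalize, and returns a (weights, indices) pair. A minimal sketch of invoking the hook directly, assuming a CUDA device (shapes are illustrative):

import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)

hidden = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
gating = torch.randn(8, 32, device="cuda", dtype=torch.float32)
# hidden is accepted only for signature compatibility; the router reads gating alone
weights, ids = gpt_oss_custom_routing_function(hidden, gating, topk=2, renormalize=True)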
vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GPT-OSS MoE router with Triton topk kernel."""

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _topk_softmax_kernel(
    logits_ptr,
    weights_ptr,
    indices_ptr,
    M,
    N,
    topk: tl.constexpr,
    topk_padded: tl.constexpr,
    stride_lm,
    stride_ln,
    stride_wm,
    stride_wk,
    stride_im,
    stride_ik,
    BLOCK_N: tl.constexpr,
    RENORM: tl.constexpr,
):
    # One program per token: load that token's row of router logits.
    token_idx = tl.program_id(0)

    offs = tl.arange(0, BLOCK_N)
    mask = offs < N
    logit_offs = logits_ptr + token_idx * stride_lm + offs * stride_ln
    logits = tl.load(logit_offs, mask=mask, other=float("-inf"))

    topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
    topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)

    working_logits = logits

    # Iterative selection: take the row max topk times, masking out each
    # selected entry with -inf so the next iteration finds the runner-up.
    for k in range(topk):
        cur_max = tl.max(working_logits, axis=0)
        cur_idx = tl.argmax(working_logits, axis=0)

        k_mask = tl.arange(0, topk_padded) == k
        topk_vals = tl.where(k_mask, cur_max, topk_vals)
        topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)

        mask_selected = offs == cur_idx
        working_logits = tl.where(mask_selected, float("-inf"), working_logits)

    if RENORM:
        # Numerically stable softmax over the selected top-k values only;
        # the -inf padding lanes contribute exp(-inf) = 0 to the sum.
        max_val = tl.max(topk_vals, axis=0)
        exp_vals = tl.exp(topk_vals - max_val)
        sum_exp = tl.sum(exp_vals, axis=0)
        topk_vals = exp_vals / sum_exp

    offs_k = tl.arange(0, topk_padded)
    store_mask = offs_k < topk

    weight_ptrs = weights_ptr + token_idx * stride_wm + offs_k * stride_wk
    tl.store(weight_ptrs, topk_vals, mask=store_mask)

    index_ptrs = indices_ptr + token_idx * stride_im + offs_k * stride_ik
    tl.store(index_ptrs, topk_idxs, mask=store_mask)


def fused_topk_softmax(
    router_logits: torch.Tensor,
    topk: int,
    renormalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = router_logits.shape

    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
    indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)

    # Triton block sizes must be powers of two, so pad N and topk up.
    BLOCK_N = triton.next_power_of_2(N)
    topk_padded = triton.next_power_of_2(topk)

    grid = (M,)

    _topk_softmax_kernel[grid](
        logits_ptr=router_logits,
        weights_ptr=weights,
        indices_ptr=indices,
        M=M,
        N=N,
        topk=topk,
        topk_padded=topk_padded,
        stride_lm=router_logits.stride(0),
        stride_ln=router_logits.stride(1),
        stride_wm=weights.stride(0),
        stride_wk=weights.stride(1),
        stride_im=indices.stride(0),
        stride_ik=indices.stride(1),
        BLOCK_N=BLOCK_N,
        RENORM=renormalize,
    )

    return weights, indices


def gpt_oss_custom_routing_function(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # only use gating_output to avoid padding issues
    return fused_topk_softmax(gating_output, topk, renormalize)
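
Since the commit is labeled an optimization, a quick way to gauge the win is to time the fused kernel against the eager torch.topk plus torch.softmax pair it replaces. The benchmark below is an illustrative sketch, not part of the commit; it assumes a CUDA device, and the problem sizes are arbitrary:

import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax

M, N, K = 4096, 128, 4
logits = torch.randn(M, N, device="cuda", dtype=torch.float32)


def eager(x):
    vals, ids = torch.topk(x, K, dim=-1)
    return torch.softmax(vals, dim=-1), ids


for name, fn in [("eager", eager), ("fused", lambda x: fused_topk_softmax(x, K))]:
    for _ in range(10):  # warmup
        fn(logits)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        fn(logits)
    end.record()
    torch.cuda.synchronize()
    print(f"{name}: {start.elapsed_time(end) / 100:.3f} ms per call")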

vllm/model_executor/models/gpt_oss.py

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,9 @@
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
+    gpt_oss_custom_routing_function,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -173,6 +176,11 @@ def __init__(
             has_bias=True,
             activation="swigluoai",
             is_sequence_parallel=self.is_sequence_parallel,
+            custom_routing_function=(
+                gpt_oss_custom_routing_function
+                if not current_platform.is_rocm()
+                else None
+            ),
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
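
Note the platform guard: on ROCm the custom_routing_function argument stays None, so FusedMoE.select_experts falls back to its native routing path, which is the same reference path the consistency test above compares against; the Triton kernel is only active elsewhere (e.g. CUDA).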
