Commit dc3f820

[Optimization] Add Fused Triton Kernel for topk+softmax
Signed-off-by: ijpq <[email protected]>
1 parent d64429b commit dc3f820

4 files changed, +207 -0 lines changed
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax


@pytest.mark.parametrize("M", [1, 32, 128, 2048])
@pytest.mark.parametrize("N", [32, 128])
@pytest.mark.parametrize("K", [1, 2])
def test_fused_router(M, N, K):
    device = "cuda"
    torch.manual_seed(0)

    logits = torch.randn((M, N), device=device, dtype=torch.float32)

    ref_vals, ref_indices = torch.topk(logits, K, dim=-1)
    ref_probs = torch.softmax(ref_vals, dim=-1)

    tri_probs, tri_indices = fused_topk_softmax(logits, K, renormalize=True)

    torch.testing.assert_close(tri_indices.long(), ref_indices)
    torch.testing.assert_close(tri_probs, ref_probs, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
    test_fused_router(128, 32, 2)
    print("Test Passed!")
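For intuition, here is a hand-checked example of the reference path this test compares against: torch.topk keeps the K largest logits per row, and the softmax is computed over just those winners rather than the full row.

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
vals, idx = torch.topk(logits, 2, dim=-1)  # vals = [[3., 2.]], idx = [[1, 2]]
probs = torch.softmax(vals, dim=-1)        # [[0.7311, 0.2689]], since e^3 / (e^3 + e^2) ≈ 0.7311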
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("top_k", [1, 2])
def test_routing_consistency(num_tokens, num_experts, top_k):
    torch.manual_seed(42)
    device = torch.device("cuda")

    hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
    router_logits = torch.randn(
        num_tokens, num_experts, device=device, dtype=torch.float32
    )

    ref_weights, ref_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=None,
    )

    triton_weights, triton_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=gpt_oss_custom_routing_function,
    )

    print(f"\nTesting M={num_tokens}, E={num_experts}, K={top_k}")

    torch.testing.assert_close(
        triton_ids,
        ref_ids,
        msg="Expert indices mismatch between native and Triton implementations",
    )

    torch.testing.assert_close(
        triton_weights,
        ref_weights,
        atol=1e-3,
        rtol=1e-3,
        msg="Expert weights mismatch between native and Triton implementations",
    )


if __name__ == "__main__":
    test_routing_consistency(128, 32, 2)
    print("Consistency Test Passed!")
vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _topk_softmax_kernel(
    logits_ptr,
    weights_ptr,
    indices_ptr,
    M,
    N,
    K: tl.constexpr,
    stride_lm,
    stride_ln,
    stride_wm,
    stride_wk,
    stride_im,
    stride_ik,
    BLOCK_N: tl.constexpr,
    RENORM: tl.constexpr,
):
    # One program per row: load the full row of logits, padded to BLOCK_N.
    pid = tl.program_id(0)

    offs = tl.arange(0, BLOCK_N)
    logits_offs = logits_ptr + pid * stride_lm + offs * stride_ln
    mask = offs < N
    logits = tl.load(logits_offs, mask=mask, other=float("-inf"))

    if K == 1:
        max_val = tl.max(logits, axis=0)
        max_idx = tl.argmax(logits, axis=0)

        # Softmax over a single element is 1.0; otherwise return the raw logit.
        weight = 1.0 if RENORM else max_val

        tl.store(weights_ptr + pid * stride_wm, weight)
        tl.store(indices_ptr + pid * stride_im, max_idx)

    elif K == 2:
        # first max
        v1 = tl.max(logits, axis=0)
        i1 = tl.argmax(logits, axis=0)

        # second max: mask out the winner's column and take the row max again
        masked = tl.where(offs != i1, logits, float("-inf"))
        v2 = tl.max(masked, axis=0)
        i2 = tl.argmax(masked, axis=0)

        if RENORM:
            # Numerically stable 2-way softmax over the two winning logits.
            vmax = tl.maximum(v1, v2)
            e1 = tl.exp(v1 - vmax)
            e2 = tl.exp(v2 - vmax)
            s = e1 + e2
            w1, w2 = e1 / s, e2 / s
        else:
            w1, w2 = v1, v2

        tl.store(weights_ptr + pid * stride_wm, w1)
        tl.store(weights_ptr + pid * stride_wm + stride_wk, w2)
        tl.store(indices_ptr + pid * stride_im, i1)
        tl.store(indices_ptr + pid * stride_im + stride_ik, i2)


def fused_topk_softmax(
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    if top_k not in [1, 2]:
        raise NotImplementedError(f"Only K=1,2 supported, got {top_k}")

    M, N = router_logits.shape

    weights = torch.empty((M, top_k), device=router_logits.device, dtype=torch.float32)
    indices = torch.empty((M, top_k), device=router_logits.device, dtype=torch.int32)

    BLOCK_N = triton.next_power_of_2(N)
    grid = (M,)

    _topk_softmax_kernel[grid](
        router_logits,
        weights,
        indices,
        M,
        N,
        K=top_k,
        stride_lm=router_logits.stride(0),
        stride_ln=router_logits.stride(1),
        stride_wm=weights.stride(0),
        stride_wk=weights.stride(1),
        stride_im=indices.stride(0),
        stride_ik=indices.stride(1),
        BLOCK_N=BLOCK_N,
        RENORM=renormalize,
    )

    return weights, indices


def gpt_oss_custom_routing_function(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # only use gating_output to avoid padding issues with hidden_states
    return fused_topk_softmax(gating_output, topk, renormalize)
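The K == 2 branch finds the runner-up by masking the winner's column to -inf and taking the row max a second time, then applies a numerically stable two-way softmax over just the two winning logits. A minimal PyTorch sketch of the same strategy (function name hypothetical, for illustration only):

import torch

def top2_by_remasking(logits: torch.Tensor):
    # Mirrors the kernel's K == 2 path: take the row max, mask that
    # column to -inf, then take the max again to get the runner-up.
    v1, i1 = logits.max(dim=-1)
    masked = logits.scatter(-1, i1.unsqueeze(-1), float("-inf"))
    v2, i2 = masked.max(dim=-1)
    vals = torch.stack([v1, v2], dim=-1)
    return torch.softmax(vals, dim=-1), torch.stack([i1, i2], dim=-1)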

vllm/model_executor/models/gpt_oss.py

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,9 @@
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
+    gpt_oss_custom_routing_function,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -173,6 +176,11 @@ def __init__(
             has_bias=True,
             activation="swigluoai",
             is_sequence_parallel=self.is_sequence_parallel,
+            custom_routing_function=(
+                gpt_oss_custom_routing_function
+                if not current_platform.is_rocm()
+                else None
+            ),
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
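As a quick smoke check outside the model, the routing function can also be driven directly; a sketch assuming a CUDA device, with shapes borrowed from the consistency test above (hidden_states is accepted only for interface compatibility and is not read):

import torch
from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)

hidden_states = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
gating_output = torch.randn(8, 32, device="cuda", dtype=torch.float32)

weights, indices = gpt_oss_custom_routing_function(
    hidden_states, gating_output, topk=2, renormalize=True
)
print(weights.shape, indices.shape)  # torch.Size([8, 2]) torch.Size([8, 2])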
