
Commit 66e6711

[Optimization]: Optimize Fused Triton Kernel for topk+softmax

- split into two kernels, one for the renormalized case and one for the non-renormalized case
- add online softmax
- unroll along M

Signed-off-by: ijpq <[email protected]>
1 parent fca484b commit 66e6711
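The "online softmax" mentioned in the commit message is the single-pass recurrence used by the new renorm kernel: keep a running maximum m and running normalizer s, and whenever a new value x is folded in, rescale s by exp(m - new_max) before adding exp(x - new_max). A minimal pure-Python sketch of that recurrence applied while peeling off the top-k logits (the helper name online_softmax_topk is hypothetical and only illustrative, not part of this commit):

import math


def online_softmax_topk(row: list[float], k: int) -> tuple[list[float], list[int]]:
    """Illustrative reference: pick the k largest logits and softmax them in
    the same pass using a running max and running sum."""
    logits = list(row)
    running_max = float("-inf")
    running_sum = 0.0
    vals: list[float] = []
    idxs: list[int] = []
    for _ in range(k):
        cur_idx = max(range(len(logits)), key=lambda i: logits[i])
        cur_max = logits[cur_idx]
        new_max = max(running_max, cur_max)
        # Rescale the accumulated sum to the new max, then fold in the new term.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(cur_max - new_max)
        running_max = new_max
        vals.append(cur_max)
        idxs.append(cur_idx)
        logits[cur_idx] = float("-inf")  # exclude the selected expert from later rounds
    weights = [math.exp(v - running_max) / running_sum for v in vals]
    return weights, idxs

For example, online_softmax_topk([1.0, 3.0, 2.0], 2) selects experts [1, 2] with weights of roughly [0.73, 0.27], i.e. the softmax of the two selected logits.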

File tree

2 files changed: +197 -62 lines

tests/kernels/moe/test_gpt_oss_routing_consistency.py

Lines changed: 48 additions & 12 deletions
@@ -11,8 +11,9 @@
 
 @pytest.mark.parametrize("num_tokens", [10, 128, 1024])
 @pytest.mark.parametrize("num_experts", [32, 65, 128])
-@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
-def test_routing_consistency(num_tokens, num_experts, topk):
+@pytest.mark.parametrize("topk", [2, 4])
+@pytest.mark.parametrize("renorm", [True, False])
+def test_routing_consistency(num_tokens, num_experts, topk, renorm):
     torch.manual_seed(42)
     device = torch.device("cuda")
 
@@ -21,12 +22,24 @@ def test_routing_consistency(num_tokens, num_experts, topk):
         num_tokens, num_experts, device=device, dtype=torch.float32
     )
 
+    def native_impl(logits, topk, renorm):
+        if renorm:
+            ref_vals, ref_indices = torch.topk(logits, topk, dim=1)
+            ref_vals = torch.softmax(ref_vals, dim=1)
+        else:
+            ref_vals = torch.softmax(logits, dim=1)
+            ref_vals, ref_indices = torch.topk(ref_vals, topk, dim=1)
+
+        return ref_vals, ref_indices
+
+    native_weights, native_ids = native_impl(router_logits, topk, renorm)
+
     ref_weights, ref_ids, _ = FusedMoE.select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
         use_grouped_topk=False,
-        renormalize=True,
+        renormalize=renorm,
         custom_routing_function=None,
     )
 
@@ -35,27 +48,50 @@ def test_routing_consistency(num_tokens, num_experts, topk):
         router_logits=router_logits,
         top_k=topk,
         use_grouped_topk=False,
-        renormalize=True,
+        renormalize=renorm,
         custom_routing_function=gpt_oss_custom_routing_function,
     )
 
     print(f"\nTesting M={num_tokens}, E={num_experts}, K={topk}")
 
+    # compare triton with torch
+    torch.testing.assert_close(
+        triton_ids.to(native_ids.dtype),
+        native_ids,
+        msg="Expert indices mismatch between native and triton implementation",
+    )
+
+    torch.testing.assert_close(
+        triton_weights,
+        native_weights,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Expert weights mismatch between native and triton implementation",
+    )
+
+    # compare triton with origin
     torch.testing.assert_close(
         triton_ids,
         ref_ids,
-        msg="Expert indices mismatch between Native and Triton implementation",
+        msg="Expert indices mismatch between origin and triton implementation",
    )
-
     torch.testing.assert_close(
         triton_weights,
         ref_weights,
         atol=1e-3,
         rtol=1e-3,
-        msg="Expert weights mismatch between Native and Triton implementation",
+        msg="Expert weights mismatch between origin and triton implementation",
+    )
+    # compare origin with torch
+    torch.testing.assert_close(
+        native_ids,
+        ref_ids.to(native_ids.dtype),
+        msg="Expert indices mismatch between origin and native implementation",
+    )
+    torch.testing.assert_close(
+        native_weights,
+        ref_weights,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Expert weights mismatch between origin and native implementation",
     )
-
-
-if __name__ == "__main__":
-    test_routing_consistency(128, 32, 2)
-    print("Consistency Test Passed!")
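The two branches of native_impl above differ only in whether the softmax runs before or after the top-k selection. A tiny standalone PyTorch check, illustrative only and not part of this commit (CPU is fine), makes the difference concrete:

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])

# renorm=True: take the top-k logits first, then softmax over just those k
# values, so the returned weights sum to 1 per row.
vals, idx = torch.topk(logits, k=2, dim=1)
renorm_weights = torch.softmax(vals, dim=1)

# renorm=False: softmax over all experts first, then keep the k largest
# probabilities, so the weights sum to less than 1 per row.
probs = torch.softmax(logits, dim=1)
plain_weights, plain_idx = torch.topk(probs, k=2, dim=1)

assert torch.equal(idx, plain_idx)  # same experts are selected either way
print(renorm_weights.sum(dim=1))    # approximately 1
print(plain_weights.sum(dim=1))     # less than 1

Both orderings pick the same experts because softmax is monotone; only the weights differ, and only the renormalized ones sum to 1 per token.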

vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 149 additions & 50 deletions
@@ -7,11 +7,24 @@
 from vllm.triton_utils import tl, triton
 
 
+def torch_dtype_to_tl(dtype: torch.dtype):
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
 @triton.jit
 def _topk_softmax_kernel(
-    logits_ptr,
-    weights_ptr,
-    indices_ptr,
+    logits_ptr: torch.Tensor,
+    weights_ptr: torch.Tensor,
+    indices_ptr: torch.Tensor,
     M,
     N,
     topk: tl.constexpr,
@@ -23,81 +36,167 @@ def _topk_softmax_kernel(
     stride_im,
     stride_ik,
     BLOCK_N: tl.constexpr,
-    RENORM: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
-    token_idx = tl.program_id(0)
+    pid = tl.program_id(0)
+    num_programs = tl.num_programs(0)
 
-    offs = tl.arange(0, BLOCK_N)
-    mask = offs < N
-    logit_offs = logits_ptr + token_idx * stride_lm + offs * stride_ln
-    logits = tl.load(logit_offs, mask=mask, other=float("-inf"))
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, topk_padded)
+    mask_n = offs_n < N
 
     topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
     topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
 
-    working_logits = logits
+    for row_idx in tl.range(pid, M, num_programs, num_stages):
+        logits = tl.load(
+            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+            mask=mask_n,
+            other=float("-inf"),
+        )
+        row_sub_max = logits - tl.max(logits, axis=0)
+        numerator = tl.exp(row_sub_max)
+        denominator = tl.sum(numerator, axis=0)
+        logits = numerator / denominator
+
+        for k in tl.static_range(topk):
+            cur_max = tl.max(logits, axis=0)
+            cur_idx = tl.argmax(logits, axis=0)
+
+            k_mask = offs_k == k
+            topk_vals = tl.where(k_mask, cur_max, topk_vals)
+            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
+        store_mask = offs_k < topk
+        tl.store(
+            weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
+            topk_vals,
+            mask=store_mask,
+        )
+        tl.store(
+            indices_ptr + row_idx * stride_im + offs_k * stride_ik,
+            topk_idxs,
+            mask=store_mask,
+        )
 
-    for k in range(topk):
-        cur_max = tl.max(working_logits, axis=0)
-        cur_idx = tl.argmax(working_logits, axis=0)
 
-        k_mask = tl.arange(0, topk_padded) == k
-        topk_vals = tl.where(k_mask, cur_max, topk_vals)
-        topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+@triton.jit
+def _topk_softmax_renorm_kernel(
+    logits_ptr,
+    weights_ptr,
+    indices_ptr,
+    M,
+    N,
+    topk: tl.constexpr,
+    topk_padded: tl.constexpr,
+    stride_lm,
+    stride_ln,
+    stride_wm,
+    stride_wk,
+    stride_im,
+    stride_ik,
+    BLOCK_N: tl.constexpr,
+    num_stages: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_programs = tl.num_programs(0)
 
-        mask_selected = offs == cur_idx
-        working_logits = tl.where(mask_selected, float("-inf"), working_logits)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, topk_padded)
+    mask_n = offs_n < N
 
-    if RENORM:
-        max_val = tl.max(topk_vals, axis=0)
-        exp_vals = tl.exp(topk_vals - max_val)
-        sum_exp = tl.sum(exp_vals, axis=0)
-        topk_vals = exp_vals / sum_exp
+    for row_idx in tl.range(pid, M, num_programs, num_stages):
+        logits = tl.load(
+            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+            mask=mask_n,
+            other=float("-inf"),
+        )
 
-    offs_k = tl.arange(0, topk_padded)
+        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
+        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
 
-    store_mask = offs_k < topk
+        running_max = float("-inf")
+        running_sum = 0.0
 
-    weight_ptrs = weights_ptr + token_idx * stride_wm + offs_k * stride_wk
-    tl.store(weight_ptrs, topk_vals, mask=store_mask)
+        for k in tl.static_range(topk):
+            cur_max = tl.max(logits, axis=0)
+            cur_idx = tl.argmax(logits, axis=0)
 
-    index_ptrs = indices_ptr + token_idx * stride_im + offs_k * stride_ik
-    tl.store(index_ptrs, topk_idxs, mask=store_mask)
+            new_max = tl.maximum(running_max, cur_max)
+            running_sum = running_sum * tl.exp(running_max - new_max) + tl.exp(
+                cur_max - new_max
+            )
+            running_max = new_max
+
+            k_mask = offs_k == k
+            topk_vals = tl.where(k_mask, cur_max, topk_vals)
+            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
+        topk_vals = tl.exp(topk_vals - running_max) / running_sum
+
+        tl.store(weights_ptr + row_idx * stride_wm + offs_k * stride_wk, topk_vals)
+        tl.store(indices_ptr + row_idx * stride_im + offs_k * stride_ik, topk_idxs)
 
 
 def fused_topk_softmax(
     router_logits: torch.Tensor,
     topk: int,
     renormalize: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    M, N = router_logits.shape
+    M, N = router_logits.shape  # num_tokens, num_experts
 
-    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
+    weights = torch.empty(
+        (M, topk), device=router_logits.device, dtype=router_logits.dtype
+    )
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)
 
-    BLOCK_N = triton.next_power_of_2(N)
+    BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
 
     topk_padded = triton.next_power_of_2(topk)
 
     grid = (M,)
-
-    _topk_softmax_kernel[grid](
-        logits_ptr=router_logits,
-        weights_ptr=weights,
-        indices_ptr=indices,
-        M=M,
-        N=N,
-        topk=topk,
-        topk_padded=topk_padded,
-        stride_lm=router_logits.stride(0),
-        stride_ln=router_logits.stride(1),
-        stride_wm=weights.stride(0),
-        stride_wk=weights.stride(1),
-        stride_im=indices.stride(0),
-        stride_ik=indices.stride(1),
-        BLOCK_N=BLOCK_N,
-        RENORM=renormalize,
-    )
+    num_stages = 2
+
+    if renormalize:
+        _topk_softmax_renorm_kernel[grid](
+            logits_ptr=router_logits,
+            weights_ptr=weights,
+            indices_ptr=indices,
+            M=M,
+            N=N,
+            topk=topk,
+            topk_padded=topk_padded,
+            stride_lm=router_logits.stride(0),
+            stride_ln=router_logits.stride(1),
+            stride_wm=weights.stride(0),
+            stride_wk=weights.stride(1),
+            stride_im=indices.stride(0),
+            stride_ik=indices.stride(1),
+            BLOCK_N=BLOCK_N,
+            num_stages=num_stages,
+        )
+    else:
+        _topk_softmax_kernel[grid](
+            logits_ptr=router_logits,
+            weights_ptr=weights,
+            indices_ptr=indices,
+            M=M,
+            N=N,
+            topk=topk,
+            topk_padded=topk_padded,
+            stride_lm=router_logits.stride(0),
+            stride_ln=router_logits.stride(1),
+            stride_wm=weights.stride(0),
+            stride_wk=weights.stride(1),
+            stride_im=indices.stride(0),
+            stride_ik=indices.stride(1),
+            BLOCK_N=BLOCK_N,
+            num_stages=num_stages,
+        )
 
     return weights, indices
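For reference, the renormalize flag in fused_topk_softmax now selects between the two kernels: _topk_softmax_renorm_kernel when True and _topk_softmax_kernel when False. A short illustrative call, assuming a CUDA device and the module path from this commit (the shapes and values are made up):

import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    fused_topk_softmax,
)

# 4 tokens routed over 8 experts, keeping the top 2 experts per token.
router_logits = torch.randn(4, 8, device="cuda", dtype=torch.float32)

weights, indices = fused_topk_softmax(router_logits, topk=2, renormalize=True)

# weights: (4, 2) per-row routing weights (sum to 1 on the renorm path)
# indices: (4, 2) int32 expert ids for each token
print(weights.sum(dim=1))
print(indices)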
