Commit 76d472d
[Fix]: Tweak a few little things in the Triton kernel

- delete the unnecessary split kernel
- add skipif in the unit test
- ACK reviews

Signed-off-by: ijpq <[email protected]>
1 parent 66e6711 commit 76d472d
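The substance of the change: the separate `_topk_softmax_renorm_kernel` is deleted and its renormalization behavior is folded into `_topk_softmax_kernel` behind a new `RENORM: tl.constexpr` parameter. Since Triton compiles one specialized binary per `tl.constexpr` value, the `if RENORM:` branch is resolved at compile time and the Python-side `if renormalize:` dispatch becomes unnecessary. A minimal toy sketch of that pattern (a standalone illustration with hypothetical names, not the vLLM kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def _toy_kernel(x_ptr, out_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    # One @triton.jit kernel, two compile-time specializations.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DOUBLE:  # constexpr: the untaken branch is eliminated during compilation
        x = x * 2.0
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_toy_kernel[grid](x, out, x.numel(), DOUBLE=True, BLOCK=256)   # specialization 1
_toy_kernel[grid](x, out, x.numel(), DOUBLE=False, BLOCK=256)  # specialization 2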

File tree

4 files changed: +69 −167 lines changed

tests/kernels/moe/test_gpt_oss_fused_router.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/kernels/moe/test_gpt_oss_routing_consistency.py

Lines changed: 32 additions & 28 deletions
@@ -7,33 +7,22 @@
     gpt_oss_custom_routing_function,
 )
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.platforms import current_platform
 
 
 @pytest.mark.parametrize("num_tokens", [10, 128, 1024])
 @pytest.mark.parametrize("num_experts", [32, 65, 128])
 @pytest.mark.parametrize("topk", [2, 4])
 @pytest.mark.parametrize("renorm", [True, False])
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="only available on CUDA")
 def test_routing_consistency(num_tokens, num_experts, topk, renorm):
     torch.manual_seed(42)
     device = torch.device("cuda")
-
     hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
     router_logits = torch.randn(
         num_tokens, num_experts, device=device, dtype=torch.float32
     )
 
-    def native_impl(logits, topk, renorm):
-        if renorm:
-            ref_vals, ref_indices = torch.topk(logits, topk, dim=1)
-            ref_vals = torch.softmax(ref_vals, dim=1)
-        else:
-            ref_vals = torch.softmax(logits, dim=1)
-            ref_vals, ref_indices = torch.topk(ref_vals, topk, dim=1)
-
-        return ref_vals, ref_indices
-
-    native_weights, native_ids = native_impl(router_logits, topk, renorm)
-
     ref_weights, ref_ids, _ = FusedMoE.select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
@@ -52,40 +41,53 @@ def native_impl(logits, topk, renorm):
         custom_routing_function=gpt_oss_custom_routing_function,
     )
 
-    print(f"\nTesting M={num_tokens}, E={num_experts}, K={topk}")
-
-    # compare triton with torch
+    # compare triton with origin
     torch.testing.assert_close(
-        triton_ids.to(native_ids.dtype),
-        native_ids,
-        msg="Expert indices mismatch between native and triton implementation",
+        triton_ids,
+        ref_ids,
+        msg="Expert indices mismatch between origin and triton implementation",
     )
-
     torch.testing.assert_close(
         triton_weights,
-        native_weights,
+        ref_weights,
         atol=1e-3,
         rtol=1e-3,
-        msg="Expert weights mismatch between native and triton implementation",
+        msg="Expert weights mismatch between origin and triton implementation",
    )
+    expected_indices_dtype = ref_ids.dtype
+    expected_weight_dtype = ref_weights.dtype
 
-    # compare triton with origin
+    def native_impl(logits, topk, renorm):
+        if renorm:
+            ref_vals, ref_indices = torch.topk(logits, topk, dim=1)
+            ref_vals = torch.softmax(ref_vals, dim=1)
+        else:
+            ref_vals = torch.softmax(logits, dim=1)
+            ref_vals, ref_indices = torch.topk(ref_vals, topk, dim=1)
+        return ref_vals.to(expected_weight_dtype), ref_indices.to(
+            expected_indices_dtype
+        )
+
+    native_weights, native_ids = native_impl(router_logits, topk, renorm)
+
+    # compare triton with torch
     torch.testing.assert_close(
         triton_ids,
-        ref_ids,
-        msg="Expert indices mismatch between origin and triton implementation",
+        native_ids,
+        msg="Expert indices mismatch between native and triton implementation",
     )
     torch.testing.assert_close(
         triton_weights,
-        ref_weights,
+        native_weights,
         atol=1e-3,
         rtol=1e-3,
-        msg="Expert weights mismatch between origin and triton implementation",
+        msg="Expert weights mismatch between native and triton implementation",
     )
+
     # compare origin with torch
     torch.testing.assert_close(
         native_ids,
-        ref_ids.to(native_ids.dtype),
+        ref_ids,
         msg="Expert indices mismatch between origin and native implementation",
     )
     torch.testing.assert_close(
@@ -95,3 +97,5 @@ def native_impl(logits, topk, renorm):
         rtol=1e-3,
         msg="Expert weights mismatch between origin and native implementation",
     )
+
+    print(f"\nTesting TOKENS={num_tokens}, EXPERTS={num_experts}, TOPK={topk}")

vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 36 additions & 107 deletions
@@ -22,11 +22,11 @@ def torch_dtype_to_tl(dtype: torch.dtype):
 
 @triton.jit
 def _topk_softmax_kernel(
-    logits_ptr: torch.Tensor,
-    weights_ptr: torch.Tensor,
-    indices_ptr: torch.Tensor,
+    logits_ptr,
+    weights_ptr,
+    indices_ptr,
     M,
-    N,
+    N: tl.constexpr,
     topk: tl.constexpr,
     topk_padded: tl.constexpr,
     stride_lm,
@@ -36,6 +36,7 @@ def _topk_softmax_kernel(
     stride_im,
     stride_ik,
     BLOCK_N: tl.constexpr,
+    RENORM: tl.constexpr,
     num_stages: tl.constexpr,
 ):
     pid = tl.program_id(0)
@@ -44,6 +45,7 @@
     offs_n = tl.arange(0, BLOCK_N)
     offs_k = tl.arange(0, topk_padded)
     mask_n = offs_n < N
+    store_mask = offs_k < topk
 
     topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
     topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
@@ -54,10 +56,12 @@
             mask=mask_n,
             other=float("-inf"),
         )
-        row_sub_max = logits - tl.max(logits, axis=0)
-        numerator = tl.exp(row_sub_max)
-        denominator = tl.sum(numerator, axis=0)
-        logits = numerator / denominator
+
+        if not RENORM:
+            row_sub_max = logits - tl.max(logits, axis=0)
+            numerator = tl.exp(row_sub_max)
+            denominator = tl.sum(numerator, axis=0)
+            logits = numerator / denominator
 
         for k in tl.static_range(topk):
             cur_max = tl.max(logits, axis=0)
@@ -69,7 +73,12 @@
 
         logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
 
-        store_mask = offs_k < topk
+        if RENORM:
+            topk_vals = topk_vals - tl.max(topk_vals, axis=0)
+            numerator = tl.exp(topk_vals)
+            denominator = tl.sum(numerator, axis=0)
+            topk_vals = numerator / denominator
+
         tl.store(
             weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
             topk_vals,
@@ -82,66 +91,6 @@
         )
 
 
-@triton.jit
-def _topk_softmax_renorm_kernel(
-    logits_ptr,
-    weights_ptr,
-    indices_ptr,
-    M,
-    N,
-    topk: tl.constexpr,
-    topk_padded: tl.constexpr,
-    stride_lm,
-    stride_ln,
-    stride_wm,
-    stride_wk,
-    stride_im,
-    stride_ik,
-    BLOCK_N: tl.constexpr,
-    num_stages: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    num_programs = tl.num_programs(0)
-
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_k = tl.arange(0, topk_padded)
-    mask_n = offs_n < N
-
-    for row_idx in tl.range(pid, M, num_programs, num_stages):
-        logits = tl.load(
-            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
-            mask=mask_n,
-            other=float("-inf"),
-        )
-
-        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
-        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
-
-        running_max = float("-inf")
-        running_sum = 0.0
-
-        for k in tl.static_range(topk):
-            cur_max = tl.max(logits, axis=0)
-            cur_idx = tl.argmax(logits, axis=0)
-
-            new_max = tl.maximum(running_max, cur_max)
-            running_sum = running_sum * tl.exp(running_max - new_max) + tl.exp(
-                cur_max - new_max
-            )
-            running_max = new_max
-
-            k_mask = offs_k == k
-            topk_vals = tl.where(k_mask, cur_max, topk_vals)
-            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
-
-            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
-
-        topk_vals = tl.exp(topk_vals - running_max) / running_sum
-
-        tl.store(weights_ptr + row_idx * stride_wm + offs_k * stride_wk, topk_vals)
-        tl.store(indices_ptr + row_idx * stride_im + offs_k * stride_ik, topk_idxs)
-
-
 def fused_topk_softmax(
     router_logits: torch.Tensor,
     topk: int,
@@ -155,48 +104,28 @@ def fused_topk_softmax(
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)
 
     BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
-
     topk_padded = triton.next_power_of_2(topk)
-
     grid = (M,)
     num_stages = 2
 
-    if renormalize:
-        _topk_softmax_renorm_kernel[grid](
-            logits_ptr=router_logits,
-            weights_ptr=weights,
-            indices_ptr=indices,
-            M=M,
-            N=N,
-            topk=topk,
-            topk_padded=topk_padded,
-            stride_lm=router_logits.stride(0),
-            stride_ln=router_logits.stride(1),
-            stride_wm=weights.stride(0),
-            stride_wk=weights.stride(1),
-            stride_im=indices.stride(0),
-            stride_ik=indices.stride(1),
-            BLOCK_N=BLOCK_N,
-            num_stages=num_stages,
-        )
-    else:
-        _topk_softmax_kernel[grid](
-            logits_ptr=router_logits,
-            weights_ptr=weights,
-            indices_ptr=indices,
-            M=M,
-            N=N,
-            topk=topk,
-            topk_padded=topk_padded,
-            stride_lm=router_logits.stride(0),
-            stride_ln=router_logits.stride(1),
-            stride_wm=weights.stride(0),
-            stride_wk=weights.stride(1),
-            stride_im=indices.stride(0),
-            stride_ik=indices.stride(1),
-            BLOCK_N=BLOCK_N,
-            num_stages=num_stages,
-        )
+    _topk_softmax_kernel[grid](
+        logits_ptr=router_logits,
+        weights_ptr=weights,
+        indices_ptr=indices,
+        M=M,
+        N=N,
+        topk=topk,
+        topk_padded=topk_padded,
+        stride_lm=router_logits.stride(0),
+        stride_ln=router_logits.stride(1),
+        stride_wm=weights.stride(0),
+        stride_wk=weights.stride(1),
+        stride_im=indices.stride(0),
+        stride_ik=indices.stride(1),
+        BLOCK_N=BLOCK_N,
+        RENORM=renormalize,
+        num_stages=num_stages,
+    )
 
     return weights, indices
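For orientation, a hedged usage sketch of the merged entry point. The hunks above show only the first two parameters of `fused_topk_softmax` plus the `renormalize` variable used inside it, so the exact signature is an assumption; the output shape and int32 index dtype come from the `indices = torch.empty(...)` line in the diff.

import torch
from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    fused_topk_softmax,
)

router_logits = torch.randn(128, 65, device="cuda", dtype=torch.float32)
# renormalize=True  -> softmax over the selected top-k logits (RENORM branch);
# renormalize=False -> softmax over all experts, then top-k.
weights, indices = fused_topk_softmax(router_logits, 4, renormalize=True)  # signature assumed
assert indices.shape == (128, 4) and indices.dtype == torch.int32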

vllm/model_executor/models/gpt_oss.py

Lines changed: 1 addition & 3 deletions
@@ -177,9 +177,7 @@ def __init__(
             activation="swigluoai",
             is_sequence_parallel=self.is_sequence_parallel,
             custom_routing_function=(
-                gpt_oss_custom_routing_function
-                if not current_platform.is_rocm()
-                else None
+                gpt_oss_custom_routing_function if current_platform.is_cuda() else None
             ),
         )

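One behavioral note on this gate: with `is_cuda()` the Triton routing function is enabled only on CUDA, whereas the old `not is_rocm()` check also enabled it on any other non-ROCm platform. A minimal sketch of the resulting selection (pattern only, with a hypothetical helper; the real expression sits inline in the GPT-OSS MoE constructor above):

from vllm.platforms import current_platform

def routing_fn_for_platform(fn):
    # CUDA -> fused Triton router; ROCm/CPU/other -> None (FusedMoE default routing)
    return fn if current_platform.is_cuda() else None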