Commit 0c7698e

[Optimization]: add specialization for small topk

Signed-off-by: ijpq <[email protected]>
1 parent 76d472d
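The specialization hinges on `topk` and `RENORM` being `tl.constexpr` kernel parameters (the patch's own comment says as much, and `tl.static_range(topk)` requires it): Triton's JIT evaluates Python `if` statements over constexpr values while compiling, so each distinct `topk` gets its own kernel with the untaken branches eliminated, much like `constexpr if` in C++17. A minimal standalone sketch of that mechanism, not taken from the patched kernel (the kernel and names below are illustrative):

import torch
import triton
import triton.language as tl


@triton.jit
def _poly_kernel(x_ptr, out_ptr, n, DEGREE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # DEGREE is a compile-time constant, so only one branch survives in the
    # generated code; the other is never emitted.
    if DEGREE == 1:
        y = x
    else:
        y = x * x
    tl.store(out_ptr + offs, y, mask=mask)


# Requires a CUDA device; this launch compiles the DEGREE == 2 variant only.
x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_poly_kernel[grid](x, out, x.numel(), DEGREE=2, BLOCK=256)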

2 files changed: +98 -40 lines

tests/kernels/moe/test_gpt_oss_routing_consistency.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@

 @pytest.mark.parametrize("num_tokens", [10, 128, 1024])
 @pytest.mark.parametrize("num_experts", [32, 65, 128])
-@pytest.mark.parametrize("topk", [2, 4])
+@pytest.mark.parametrize("topk", [1, 2, 3, 4])
 @pytest.mark.parametrize("renorm", [True, False])
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="only available on CUDA")
 def test_routing_consistency(num_tokens, num_experts, topk, renorm):
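The new `topk` values of 1 and 2 exercise both specialized branches added below. For context, the `renorm` flag controls whether the softmax runs before or after the top-k selection; a minimal eager PyTorch reference for those two modes might look like the sketch below (`reference_topk_softmax` is a hypothetical helper for illustration, not the actual test body or a vLLM API):

import torch


def reference_topk_softmax(logits: torch.Tensor, topk: int, renorm: bool):
    # Hypothetical eager reference mirroring the kernel's two modes.
    if renorm:
        # Pick the top-k logits first, then softmax over just those k values.
        vals, idxs = torch.topk(logits, topk, dim=-1)
        weights = torch.softmax(vals, dim=-1)
    else:
        # Softmax over all experts first, then keep the k largest probabilities.
        probs = torch.softmax(logits, dim=-1)
        weights, idxs = torch.topk(probs, topk, dim=-1)
    return weights, idxs.to(torch.int32)

A consistency test of this shape would then compare the fused kernel's weights and indices against such a reference, up to tie-breaking between exactly equal logits.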

vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py (97 additions, 39 deletions)
@@ -47,48 +47,106 @@ def _topk_softmax_kernel(
     mask_n = offs_n < N
     store_mask = offs_k < topk

-    topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
-    topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
-
-    for row_idx in tl.range(pid, M, num_programs, num_stages):
-        logits = tl.load(
-            logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
-            mask=mask_n,
-            other=float("-inf"),
-        )
-
-        if not RENORM:
-            row_sub_max = logits - tl.max(logits, axis=0)
-            numerator = tl.exp(row_sub_max)
-            denominator = tl.sum(numerator, axis=0)
-            logits = numerator / denominator
-
-        for k in tl.static_range(topk):
+    # Specialize topk <= 2 and RENORM via tl.constexpr; each branch below is
+    # resolved at compile time, similar to `constexpr if` in C++17.
+    if topk == 1:
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
             cur_max = tl.max(logits, axis=0)
             cur_idx = tl.argmax(logits, axis=0)

-            k_mask = offs_k == k
-            topk_vals = tl.where(k_mask, cur_max, topk_vals)
-            topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
-
-            logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
-
-        if RENORM:
-            topk_vals = topk_vals - tl.max(topk_vals, axis=0)
-            numerator = tl.exp(topk_vals)
-            denominator = tl.sum(numerator, axis=0)
-            topk_vals = numerator / denominator
-
-        tl.store(
-            weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
-            topk_vals,
-            mask=store_mask,
-        )
-        tl.store(
-            indices_ptr + row_idx * stride_im + offs_k * stride_ik,
-            topk_idxs,
-            mask=store_mask,
-        )
+            if RENORM:
+                cur_max = 1  # softmax over a single value is always 1
+
+            tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
+            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)
+
+    elif topk == 2:
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            val0 = tl.max(logits, axis=0)
+            idx0 = tl.argmax(logits, axis=0)
+            logits = tl.where(offs_n == idx0, float("-inf"), logits)
+            val1 = tl.max(logits, axis=0)
+            idx1 = tl.argmax(logits, axis=0)
+
+            if RENORM:
+                max_val = tl.maximum(val0, val1)
+                exp0 = tl.exp(val0 - max_val)
+                exp1 = tl.exp(val1 - max_val)
+                val0 = exp0 / (exp0 + exp1)
+                val1 = exp1 / (exp0 + exp1)
+
+            tl.store(weights_ptr + row_idx * stride_wm, val0)
+            tl.store(indices_ptr + row_idx * stride_im, idx0)
+            tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
+            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)
+
+    else:
+        topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
+        topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
+
+        for row_idx in tl.range(pid, M, num_programs, num_stages):
+            logits = tl.load(
+                logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                mask=mask_n,
+                other=float("-inf"),
+            )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            for k in tl.static_range(topk):
+                cur_max = tl.max(logits, axis=0)
+                cur_idx = tl.argmax(logits, axis=0)
+
+                k_mask = offs_k == k
+                topk_vals = tl.where(k_mask, cur_max, topk_vals)
+                topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+                logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
+
+            if RENORM:
+                topk_vals = topk_vals - tl.max(topk_vals, axis=0)
+                numerator = tl.exp(topk_vals)
+                denominator = tl.sum(numerator, axis=0)
+                topk_vals = numerator / denominator
+
+            tl.store(
+                weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
+                topk_vals,
+                mask=store_mask,
+            )
+            tl.store(
+                indices_ptr + row_idx * stride_im + offs_k * stride_ik,
+                topk_idxs,
+                mask=store_mask,
+            )


 def fused_topk_softmax(
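Two details of the specialized branches are easy to miss. In the `topk == 1` renorm path, the renormalized weight of a single selected expert is exp(v) / exp(v) = 1, so the kernel stores the constant 1 instead of computing a softmax. In the `topk == 2` renorm path, the kernel hand-rolls a numerically stable, max-shifted two-element softmax; a quick eager check of that identity (illustrative, not part of the patch):

import torch

# The specialized top-2 renorm math: a max-shifted two-element softmax.
val0, val1 = torch.tensor(2.5), torch.tensor(0.75)
max_val = torch.maximum(val0, val1)
exp0 = torch.exp(val0 - max_val)
exp1 = torch.exp(val1 - max_val)
w0, w1 = exp0 / (exp0 + exp1), exp1 / (exp0 + exp1)

# Must agree with a plain softmax over the selected pair.
ref = torch.softmax(torch.stack([val0, val1]), dim=0)
assert torch.allclose(torch.stack([w0, w1]), ref)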
