
Commit eacf1cd

[Optimization]: add autotune and fix little things
Signed-off-by: ijpq <[email protected]>
1 parent: f0301c9


vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 25 additions & 18 deletions
@@ -20,6 +20,16 @@ def torch_dtype_to_tl(dtype: torch.dtype):
     raise ValueError(f"Unsupported dtype: {dtype}")


+@triton.autotune(
+    configs=[
+        triton.Config({"ROWS_PER_PID": r}, num_warps=num_warps, num_stages=num_stages)
+        for r in [1, 2, 4, 8, 16, 32, 64, 128]
+        for num_warps in [1, 2, 4, 8, 16, 32]
+        for num_stages in [1, 2, 3]
+    ],
+    key=["N", "topk"],
+    cache_results=True,
+)
 @triton.jit
 def _topk_softmax_kernel(
     logits_ptr,
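For context, here is a minimal sketch of how triton.autotune drives a row-blocked kernel of this shape. The toy kernel _row_scale_kernel, its parameter names, and the config values below are illustrative assumptions, not part of this commit:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # One Config per candidate; Triton benchmarks them on the first call
        # and reuses the fastest one for each distinct value of `key`.
        triton.Config({"ROWS_PER_PID": r}, num_warps=w, num_stages=s)
        for r in [1, 4, 16]
        for w in [2, 4]
        for s in [1, 2]
    ],
    key=["N"],  # re-tune whenever N changes
)
@triton.jit
def _row_scale_kernel(x_ptr, out_ptr, M, N, scale,
                      BLOCK_N: tl.constexpr, ROWS_PER_PID: tl.constexpr):
    pid = tl.program_id(0)
    offs_n = tl.arange(0, BLOCK_N)
    mask_n = offs_n < N
    # Each program handles ROWS_PER_PID consecutive rows; the tuner picks how many.
    for i in range(ROWS_PER_PID):
        row = pid * ROWS_PER_PID + i
        mask = mask_n & (row < M)
        x = tl.load(x_ptr + row * N + offs_n, mask=mask, other=0.0)
        tl.store(out_ptr + row * N + offs_n, x * scale, mask=mask)

def row_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    M, N = x.shape
    out = torch.empty_like(x)
    # The grid callable reads the tuned ROWS_PER_PID from the winning config.
    grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
    _row_scale_kernel[grid](x, out, M, N, scale, BLOCK_N=triton.next_power_of_2(N))
    return out

Tuning runs once per key value, so steady-state calls pay no extra launch cost. The commit's search space is 8 × 6 × 3 = 144 configs, benchmarked once per (N, topk) pair; cache_results=True additionally asks Triton to cache the tuning outcome (per the parameter's name).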
@@ -37,8 +47,8 @@ def _topk_softmax_kernel(
     stride_ik,
     BLOCK_N: tl.constexpr,
     RENORM: tl.constexpr,
-    num_stages: tl.constexpr,
     ROWS_PER_PID: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
     pid = tl.program_id(0)
     num_programs = tl.num_programs(0)
@@ -48,8 +58,8 @@ def _topk_softmax_kernel(
     mask_n = offs_n < N
     store_mask = offs_k < topk

-    # specify topk<=2 and RENORM specialization by tl.constexpr,
-    # similar as `constexpr if` in C++17
+    # implement the topk<=2 and RENORM specializations via tl.constexpr,
+    # analogous to `if constexpr` in C++17
     if topk == 1:
         for row_idx in tl.range(pid, M, num_programs, num_stages):
             logits = tl.load(
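A side note on the comment above: because these flags are tl.constexpr values, Triton resolves the branches during compilation and drops the dead sides, much like `if constexpr` in C++17. A minimal, hypothetical illustration (this kernel is not from the file):

import triton
import triton.language as tl

@triton.jit
def _renorm_demo(x_ptr, out_ptr, N, RENORM: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if RENORM:  # evaluated at compile time; only one branch is code-generated
        x = x / tl.sum(x, axis=0)
    tl.store(out_ptr + offs, x, mask=mask)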
@@ -64,14 +74,11 @@ def _topk_softmax_kernel(
             denominator = tl.sum(numerator, axis=0)
             logits = numerator / denominator

-            cur_max = tl.max(logits, axis=0)
+            cur_max = 1 if RENORM else tl.max(logits, axis=0)
             cur_idx = tl.argmax(logits, axis=0)

-            if RENORM:
-                cur_max = 1
-
             tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
-            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_wk, cur_idx)
+            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)

     elif topk == 2:
         for row_idx in tl.range(pid, M, num_programs, num_stages):
@@ -103,7 +110,7 @@ def _topk_softmax_kernel(
             tl.store(weights_ptr + row_idx * stride_wm, val0)
             tl.store(indices_ptr + row_idx * stride_im, idx0)
             tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
-            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_wk, idx1)
+            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)

     else:
         topk_vals = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.float32) + float(
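The two stride_wk → stride_ik fixes above matter because weights and indices are distinct tensors with their own strides; the old code only happened to work while both were contiguous (M, topk) tensors with the same column stride. A small sketch of the failure mode the fix guards against (the non-contiguous layout is contrived, for illustration only):

import torch

M, topk = 4, 2
weights = torch.empty((M, topk), dtype=torch.float32)
# A hypothetical non-contiguous indices buffer: every other column of a wider allocation.
indices = torch.empty((M, 2 * topk), dtype=torch.int32)[:, ::2]

print(weights.stride(1))  # 1
print(indices.stride(1))  # 2 -- offsetting indices by the weights' column stride would hit the wrong column

With the contiguous buffers that fused_topk_softmax allocates, both column strides are 1, so the bug was latent rather than user-visible; passing each tensor its own stride keeps the kernel correct if the layouts ever diverge.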
@@ -113,7 +120,10 @@ def _topk_softmax_kernel(

         rows = tl.arange(0, ROWS_PER_PID)
         for row_idx in tl.range(
-            pid * ROWS_PER_PID, M, num_programs * ROWS_PER_PID, num_stages
+            pid * ROWS_PER_PID,
+            M,
+            num_programs * ROWS_PER_PID,
+            num_stages,
         ):
             row_indices = row_idx + rows  # [ROWS_PER_PID,]
             row_mask = row_indices < M
@@ -183,16 +193,15 @@ def fused_topk_softmax(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     M, N = router_logits.shape  # num_tokens, num_experts

-    weights = torch.empty(
-        (M, topk), device=router_logits.device, dtype=router_logits.dtype
-    )
+    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
     indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)

     BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
     topk_padded = triton.next_power_of_2(topk)
-    grid = (M,)
-    num_stages = 2
-    ROWS_PER_PID = 4
+
+    # let autotune pick ROWS_PER_PID, and hence the number of thread blocks; see
+    # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+    grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)

     _topk_softmax_kernel[grid](
         logits_ptr=router_logits,
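Because the grid is now a callable, Triton evaluates it against the meta-parameters of whichever autotune config wins, so the number of programs tracks the tuned ROWS_PER_PID instead of the previous hard-coded grid = (M,) with ROWS_PER_PID = 4. A quick illustrative check (the token count is chosen arbitrarily):

import triton

M = 4096  # illustrative number of tokens
grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
print(grid({"ROWS_PER_PID": 16}))  # (256,) -> 256 programs covering 16 rows each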
@@ -210,8 +219,6 @@ def fused_topk_softmax(
         stride_ik=indices.stride(1),
         BLOCK_N=BLOCK_N,
         RENORM=renormalize,
-        num_stages=num_stages,
-        ROWS_PER_PID=ROWS_PER_PID,
     )

     return weights, indices
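For reference, a hedged usage sketch of the updated entry point. It assumes the (not shown) parameter names router_logits, topk, and renormalize seen in the hunks above, a CUDA device, and the usual top-k softmax routing semantics for the reference check:

import torch
from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax

M, N, topk = 128, 32, 2  # tokens, experts, experts chosen per token
logits = torch.randn(M, N, device="cuda", dtype=torch.float16)

# weights now come back as float32 regardless of the logits dtype.
weights, indices = fused_topk_softmax(router_logits=logits, topk=topk, renormalize=True)

# Reference: softmax over experts, keep the top-k, renormalize the kept weights.
ref_w, ref_i = torch.topk(torch.softmax(logits.float(), dim=-1), topk, dim=-1)
ref_w = ref_w / ref_w.sum(dim=-1, keepdim=True)
torch.testing.assert_close(weights, ref_w, atol=1e-3, rtol=1e-3)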
