Skip to content

Commit a404e55

Browse files
committed
[Optimization]: add autotune and fix little things
Signed-off-by: ijpq <[email protected]>
1 parent f0301c9 commit a404e55

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ def torch_dtype_to_tl(dtype: torch.dtype):
2020
raise ValueError(f"Unsupported dtype: {dtype}")
2121

2222

23+
@triton.autotune(
24+
configs=[
25+
triton.Config({"ROWS_PER_PID": 1, "num_stages": 2, "num_warps": 1}),
26+
triton.Config({"ROWS_PER_PID": 1, "num_stages": 2, "num_warps": 2}),
27+
triton.Config({"ROWS_PER_PID": 2, "num_stages": 2, "num_warps": 2}),
28+
triton.Config({"ROWS_PER_PID": 4, "num_stages": 2, "num_warps": 2}),
29+
triton.Config({"ROWS_PER_PID": 16, "num_stages": 2, "num_warps": 4}),
30+
triton.Config({"ROWS_PER_PID": 16, "num_stages": 3, "num_warps": 8}),
31+
triton.Config({"ROWS_PER_PID": 32, "num_stages": 2, "num_warps": 4}),
32+
triton.Config({"ROWS_PER_PID": 32, "num_stages": 3, "num_warps": 8}),
33+
triton.Config({"ROWS_PER_PID": 64, "num_stages": 3, "num_warps": 8}),
34+
triton.Config({"ROWS_PER_PID": 128, "num_stages": 3, "num_warps": 8}),
35+
],
36+
key=["N", "topk", "RENORM"],
37+
)
2338
@triton.jit
2439
def _topk_softmax_kernel(
2540
logits_ptr,
@@ -37,8 +52,8 @@ def _topk_softmax_kernel(
3752
stride_ik,
3853
BLOCK_N: tl.constexpr,
3954
RENORM: tl.constexpr,
40-
num_stages: tl.constexpr,
4155
ROWS_PER_PID: tl.constexpr,
56+
num_stages: tl.constexpr,
4257
):
4358
pid = tl.program_id(0)
4459
num_programs = tl.num_programs(0)
@@ -48,10 +63,10 @@ def _topk_softmax_kernel(
4863
mask_n = offs_n < N
4964
store_mask = offs_k < topk
5065

51-
# specify topk<=2 and RENORM specialization by tl.constexpr,
52-
# similar as `constexpr if` in C++17
66+
    # implement the topk<=2 and RENORM specializations via tl.constexpr,
67+
    # analogous to `constexpr if` in C++17
5368
if topk == 1:
54-
for row_idx in tl.range(pid, M, num_programs, num_stages):
69+
for row_idx in tl.range(pid, M, num_programs, num_stages, warp_specialize=True):
5570
logits = tl.load(
5671
logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
5772
mask=mask_n,
@@ -71,10 +86,10 @@ def _topk_softmax_kernel(
7186
cur_max = 1
7287

7388
tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
74-
tl.store(indices_ptr + row_idx * stride_im + 0 * stride_wk, cur_idx)
89+
tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)
7590

7691
elif topk == 2:
77-
for row_idx in tl.range(pid, M, num_programs, num_stages):
92+
for row_idx in tl.range(pid, M, num_programs, num_stages, warp_specialize=True):
7893
logits = tl.load(
7994
logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
8095
mask=mask_n,
@@ -103,7 +118,7 @@ def _topk_softmax_kernel(
103118
tl.store(weights_ptr + row_idx * stride_wm, val0)
104119
tl.store(indices_ptr + row_idx * stride_im, idx0)
105120
tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
106-
tl.store(indices_ptr + row_idx * stride_im + 1 * stride_wk, idx1)
121+
tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)
107122

108123
else:
109124
topk_vals = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.float32) + float(
@@ -113,7 +128,11 @@ def _topk_softmax_kernel(
113128

114129
rows = tl.arange(0, ROWS_PER_PID)
115130
for row_idx in tl.range(
116-
pid * ROWS_PER_PID, M, num_programs * ROWS_PER_PID, num_stages
131+
pid * ROWS_PER_PID,
132+
M,
133+
num_programs * ROWS_PER_PID,
134+
num_stages,
135+
warp_specialize=True,
117136
):
118137
row_indices = row_idx + rows # [ROWS_PER_PID,]
119138
row_mask = row_indices < M
@@ -183,16 +202,15 @@ def fused_topk_softmax(
183202
) -> tuple[torch.Tensor, torch.Tensor]:
184203
M, N = router_logits.shape # num_tokens, num_experts
185204

186-
weights = torch.empty(
187-
(M, topk), device=router_logits.device, dtype=router_logits.dtype
188-
)
205+
weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
189206
indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)
190207

191208
BLOCK_N = triton.next_power_of_2(N) # num_padded_experts
192209
topk_padded = triton.next_power_of_2(topk)
193-
grid = (M,)
194-
num_stages = 2
195-
ROWS_PER_PID = 4
210+
211+
# enable autotune to find the correct number of thread blocks,
212+
# refer to https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
213+
grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
196214

197215
_topk_softmax_kernel[grid](
198216
logits_ptr=router_logits,
@@ -210,8 +228,6 @@ def fused_topk_softmax(
210228
stride_ik=indices.stride(1),
211229
BLOCK_N=BLOCK_N,
212230
RENORM=renormalize,
213-
num_stages=num_stages,
214-
ROWS_PER_PID=ROWS_PER_PID,
215231
)
216232

217233
return weights, indices

0 commit comments

Comments
 (0)