Skip to content

Commit f0301c9

Browse files
committed
[Optimization]: unroll for each program along M
Signed-off-by: ijpq <[email protected]>
1 parent 0c7698e commit f0301c9

File tree

1 file changed

+48
-19
lines changed

1 file changed

+48
-19
lines changed

vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _topk_softmax_kernel(
3838
BLOCK_N: tl.constexpr,
3939
RENORM: tl.constexpr,
4040
num_stages: tl.constexpr,
41+
ROWS_PER_PID: tl.constexpr,
4142
):
4243
pid = tl.program_id(0)
4344
num_programs = tl.num_programs(0)
@@ -105,47 +106,73 @@ def _topk_softmax_kernel(
105106
tl.store(indices_ptr + row_idx * stride_im + 1 * stride_wk, idx1)
106107

107108
else:
108-
topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
109-
topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)
110-
111-
for row_idx in tl.range(pid, M, num_programs, num_stages):
109+
topk_vals = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.float32) + float(
110+
"-inf"
111+
)
112+
topk_idxs = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.int32)
113+
114+
rows = tl.arange(0, ROWS_PER_PID)
115+
for row_idx in tl.range(
116+
pid * ROWS_PER_PID, M, num_programs * ROWS_PER_PID, num_stages
117+
):
118+
row_indices = row_idx + rows # [ROWS_PER_POD,]
119+
row_mask = row_indices < M
120+
# broadcast to [ROWS_PER_PID, BLOCKN]
112121
logits = tl.load(
113-
logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
114-
mask=mask_n,
122+
logits_ptr
123+
+ row_indices[:, None] * stride_lm # [ROWS_PER_PID, 1]
124+
+ offs_n[None, :] * stride_ln, # [1, BLOCK_N]
125+
mask=row_mask[:, None] # [ROWS_PER_PID,1]
126+
& mask_n[None, :], # [1, BLOCKN]
115127
other=float("-inf"),
116128
)
117129

118130
if not RENORM:
119-
row_sub_max = logits - tl.max(logits, axis=0)
131+
row_sub_max = logits - tl.max(
132+
logits, axis=1, keep_dims=True
133+
) # [ROWS_PER_PID, BLOCK_N] - [ROWS_PER_PID,1]
120134
numerator = tl.exp(row_sub_max)
121-
denominator = tl.sum(numerator, axis=0)
135+
denominator = tl.sum(
136+
numerator, axis=1, keep_dims=True
137+
) # [ROWS_PER_PID, BLOCKN]
122138
logits = numerator / denominator
123139

124140
for k in tl.static_range(topk):
125-
cur_max = tl.max(logits, axis=0)
126-
cur_idx = tl.argmax(logits, axis=0)
141+
cur_max = tl.max(logits, axis=1, keep_dims=True) # [ROWS_PER_PID, 1]
142+
cur_idx = tl.argmax(logits, axis=1, keep_dims=True)
127143

128144
k_mask = offs_k == k
129-
topk_vals = tl.where(k_mask, cur_max, topk_vals)
145+
topk_vals = tl.where(
146+
k_mask, cur_max, topk_vals
147+
) # [ROWS_PER PID, 1], [ROWS_PER PID, topkpadded]
130148
topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
131149

132-
logits = tl.where(offs_n == cur_idx, float("-inf"), logits)
150+
mask_selected = cur_idx == offs_n[None, :] # [ROWSPERPID,1] [1,BLOCKN]
151+
logits = tl.where(mask_selected, float("-inf"), logits)
133152

134153
if RENORM:
135-
topk_vals = topk_vals - tl.max(topk_vals, axis=0)
154+
topk_vals = topk_vals - tl.max(
155+
topk_vals, axis=1, keep_dims=True
156+
) # [ROWSPERPID, topkpadded] - [ROWSPERPID,1]
136157
numerator = tl.exp(topk_vals)
137-
denominator = tl.sum(numerator, axis=0)
138-
topk_vals = numerator / denominator
158+
denominator = tl.sum(
159+
numerator, axis=1, keep_dims=True
160+
) # [ROWSPERPID,1]
161+
topk_vals = numerator / denominator # [ROWSPERPID,topkpadded]
139162

140163
tl.store(
141-
weights_ptr + row_idx * stride_wm + offs_k * stride_wk,
164+
weights_ptr
165+
+ row_indices[:, None] * stride_wm # [ROWSPERPID,1]
166+
+ offs_k[None, :] * stride_wk, # [1, topkpadded]
142167
topk_vals,
143-
mask=store_mask,
168+
mask=row_mask[:, None] & store_mask[None, :], # [1, topkpadded]
144169
)
145170
tl.store(
146-
indices_ptr + row_idx * stride_im + offs_k * stride_ik,
171+
indices_ptr
172+
+ row_indices[:, None] * stride_im
173+
+ offs_k[None, :] * stride_ik,
147174
topk_idxs,
148-
mask=store_mask,
175+
mask=row_mask[:, None] & store_mask[None, :],
149176
)
150177

151178

@@ -165,6 +192,7 @@ def fused_topk_softmax(
165192
topk_padded = triton.next_power_of_2(topk)
166193
grid = (M,)
167194
num_stages = 2
195+
ROWS_PER_PID = 4
168196

169197
_topk_softmax_kernel[grid](
170198
logits_ptr=router_logits,
@@ -183,6 +211,7 @@ def fused_topk_softmax(
183211
BLOCK_N=BLOCK_N,
184212
RENORM=renormalize,
185213
num_stages=num_stages,
214+
ROWS_PER_PID=ROWS_PER_PID,
186215
)
187216

188217
return weights, indices

0 commit comments

Comments
 (0)