
Commit 4ef122b

[Optimization] Add Fused Triton Kernel for GPT-OSS Router
Signed-off-by: ijpq <[email protected]>
Parent: d64429b

2 files changed: +218 −2 lines
vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py (new file)

Lines changed: 207 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fused Router Kernel for GPT-OSS MoE.
Fuses the router linear layer (GEMM) and Top-K selection + Softmax.
"""

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def fused_moe_router_kernel(
    # Pointers
    x_ptr,  # Input [M, K]
    w_ptr,  # Weight [N, K]
    out_w_ptr,  # Output Weights [M, TopK]
    out_i_ptr,  # Output Indices [M, TopK]
    # Dimensions
    M,
    K,
    N,
    TopK: tl.constexpr,
    # Strides
    stride_xm,
    stride_xk,
    stride_wn,
    stride_wk,
    stride_wm,  # output weights row stride
    stride_wk_out,  # output weights column stride
    stride_im,  # output indices row stride
    stride_ik_out,  # output indices column stride
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_N: tl.constexpr,  # Must be >= N (number of experts)
):
    # 1. Program ID
    pid = tl.program_id(axis=0)

    # 2. Create offsets
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # 3. Initialize accumulator for GEMM (Logits)
    # Accumulator shape: [BLOCK_M, BLOCK_N]
    # We perform the computation in float32 for numerical stability
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # 4. GEMM loop over the K dimension
    for k in range(0, K, BLOCK_K):
        # Load Input X [BLOCK_M, BLOCK_K]
        offs_k = k + tl.arange(0, BLOCK_K)
        x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
        # Mask the M dimension boundary and the K dimension boundary
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # Load Weight W as a [BLOCK_K, BLOCK_N] tile
        # PyTorch Linear weights are [Out_Features, In_Features] -> W is [N, K],
        # and the router computes logits = X @ W^T. Indexing offs_k along rows
        # and offs_n along columns reads a [BLOCK_K, BLOCK_N] tile of W^T
        # directly, so tl.dot(x, w) below yields the [BLOCK_M, BLOCK_N] block
        # of logits.
        w_ptrs = w_ptr + (offs_n[None, :] * stride_wn + offs_k[:, None] * stride_wk)
        w_mask = (offs_n[None, :] < N) & (offs_k[:, None] < K)
        w = tl.load(w_ptrs, mask=w_mask, other=0.0)

        # Matrix Multiply: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] -> [BLOCK_M, BLOCK_N]
        acc += tl.dot(x, w)

    # 5. Top-K Selection in SRAM
    # acc now contains the logits [BLOCK_M, BLOCK_N].
    # We only care about valid experts (column index < N), so mask any padding
    # columns with -inf (BLOCK_N is the next power of 2 >= N).
    logits = acc
    if BLOCK_N > N:
        logits = tl.where(tl.arange(0, BLOCK_N)[None, :] < N, logits, float("-inf"))

    # TopK is a constexpr and small (e.g. 4), so we find the top-k entries by
    # repeatedly taking the row-wise max and masking out the selected expert.
    # Triton block tensors are immutable (no slice assignment), so each result
    # is scattered into column i of the storage tensors with tl.where.
    col_idx = tl.arange(0, TopK)
    topk_val_storage = tl.zeros([BLOCK_M, TopK], dtype=tl.float32)
    topk_idx_storage = tl.zeros([BLOCK_M, TopK], dtype=tl.int32)

    for i in range(TopK):
        # Find max along the expert dimension
        val_max, idx_max = tl.max(logits, axis=1, return_indices=True)

        # Store the current max / argmax into column i
        col_mask = col_idx[None, :] == i
        topk_val_storage = tl.where(col_mask, val_max[:, None], topk_val_storage)
        topk_idx_storage = tl.where(col_mask, idx_max[:, None], topk_idx_storage)

        # Mask out the selected expert to find the next max in the next iteration
        mask = tl.arange(0, BLOCK_N)[None, :] == idx_max[:, None]
        logits = tl.where(mask, float("-inf"), logits)

    # 6. Softmax Renormalization
    # Now we have the TopK logits in topk_val_storage [BLOCK_M, TopK].
    # We perform softmax on these TopK values.

    # Subtract max for numerical stability
    val_max_for_softmax = tl.max(topk_val_storage, axis=1)
    numerator = tl.exp(topk_val_storage - val_max_for_softmax[:, None])
    denominator = tl.sum(numerator, axis=1)
    softmax_res = numerator / denominator[:, None]

    # 7. Write Output
    # We only write valid rows (M boundary)
    output_mask = offs_m[:, None] < M

    # Pointers for output
    # out_w_ptr shape: [M, TopK]
    # out_i_ptr shape: [M, TopK]
    offs_topk = tl.arange(0, TopK)
    out_w_ptrs = out_w_ptr + (
        offs_m[:, None] * stride_wm + offs_topk[None, :] * stride_wk_out
    )
    out_i_ptrs = out_i_ptr + (
        offs_m[:, None] * stride_im + offs_topk[None, :] * stride_ik_out
    )

    tl.store(out_w_ptrs, softmax_res, mask=output_mask)
    tl.store(out_i_ptrs, topk_idx_storage, mask=output_mask)


def fused_router(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        hidden_states: [num_tokens, hidden_size]
        router_weights: [num_experts, hidden_size]
        top_k: int
    Returns:
        topk_weights: [num_tokens, top_k] (after softmax)
        topk_indices: [num_tokens, top_k]
    """
    assert hidden_states.ndim == 2
    assert router_weights.ndim == 2
    # Triton block shapes and tl.arange require power-of-2 sizes, so top_k
    # must be a power of 2 (e.g. 4 for GPT-OSS).
    assert top_k > 0 and (top_k & (top_k - 1)) == 0

    M, K = hidden_states.shape
    N, _ = router_weights.shape

    # Outputs
    topk_weights = torch.empty(
        (M, top_k), device=hidden_states.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (M, top_k), device=hidden_states.device, dtype=torch.int32
    )

    # Heuristics for block sizes
    BLOCK_M = 32
    BLOCK_K = 128
    # BLOCK_N must be a power of 2 and >= N
    BLOCK_N = triton.next_power_of_2(N)

    grid = (triton.cdiv(M, BLOCK_M), 1, 1)

    fused_moe_router_kernel[grid](
        hidden_states,
        router_weights,
        topk_weights,
        topk_indices,
        M,
        K,
        N,
        TopK=top_k,
        stride_xm=hidden_states.stride(0),
        stride_xk=hidden_states.stride(1),
        stride_wn=router_weights.stride(0),
        stride_wk=router_weights.stride(1),
        stride_wm=topk_weights.stride(0),
        stride_wk_out=topk_weights.stride(1),
        stride_im=topk_indices.stride(0),
        stride_ik_out=topk_indices.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        BLOCK_N=BLOCK_N,
    )

    return topk_weights, topk_indices
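
The fused path is meant to be numerically equivalent to the unfused routing sequence (router GEMM, then top-k, then softmax over the selected logits). A minimal PyTorch sketch of that unfused reference, handy for spot-checking fused_router, is below; reference_router and the shapes in the usage comment are illustrative assumptions, not part of this commit.

# Reference (unfused) routing path for verification -- illustrative only,
# not part of this commit. Assumes the same [num_tokens, hidden_size] input
# and [num_experts, hidden_size] router weight layout as fused_router().
import torch


def reference_router(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Router GEMM in float32, matching the kernel's accumulator precision
    logits = hidden_states.float() @ router_weights.float().t()
    # Top-K selection followed by softmax over the selected logits only
    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
    topk_weights = torch.softmax(topk_vals, dim=-1)
    # Note: when logits tie, index order may differ from the kernel's
    # iterative argmax, so compare weights rather than raw indices on ties.
    return topk_weights, topk_idx.to(torch.int32)


# Example spot check (hypothetical shapes):
#   x = torch.randn(64, 2880, device="cuda", dtype=torch.bfloat16)
#   w = torch.randn(128, 2880, device="cuda", dtype=torch.bfloat16)
#   ref_w, ref_i = reference_router(x, w, top_k=4)
#   ker_w, ker_i = fused_router(x, w, top_k=4)
#   torch.testing.assert_close(ker_w, ref_w, atol=1e-3, rtol=1e-3)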

vllm/model_executor/models/gpt_oss.py

Lines changed: 11 additions & 2 deletions
@@ -20,6 +20,7 @@
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_router
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -184,9 +185,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             g = rocm_unquantized_gemm(
                 self, x[:, : self.hidden_size], self.router.weight, self.router.bias
             )
+            x = self.experts(hidden_states=x, router_logits=g)
         else:
-            g = self.router(x)
-        x = self.experts(hidden_states=x, router_logits=g)
+            topk_weights, topk_indices = fused_router(
+                hidden_states=x,
+                router_weights=self.router.weight,
+                top_k=self.experts_per_token,
+            )
+
+            x = self.experts(
+                hidden_states=x, topk_weights=topk_weights, topk_ids=topk_indices
+            )
 
         if self.is_sequence_parallel:
             x = tensor_model_parallel_all_gather(x.contiguous(), 0)
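
To gauge the win from fusing the router, the kernel can be timed in isolation against the unfused path. A rough micro-benchmark sketch follows, assuming a CUDA device, the fused_router function from the new file, and the reference_router helper sketched above; shapes and iteration counts are illustrative.

# Router-only micro-benchmark -- illustrative, not part of this commit.
# Assumes fused_router and reference_router are importable/defined as above.
import torch


def bench(fn, *args, iters: int = 100) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warmup
        fn(*args)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call


if __name__ == "__main__":
    x = torch.randn(4096, 2880, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(128, 2880, device="cuda", dtype=torch.bfloat16)
    print("fused  :", bench(fused_router, x, w, 4), "ms")
    print("unfused:", bench(reference_router, x, w, 4), "ms")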
