# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fused router kernel for GPT-OSS MoE.

Fuses the router linear layer (GEMM) with top-k selection and softmax
renormalization over the selected logits, so the intermediate logits never
have to be written back to global memory between the three steps.
"""

import torch

from vllm.triton_utils import tl, triton


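# For reference, the unfused computation this kernel replaces is roughly the
# following (a sketch, assuming the plain "top-k then softmax over the
# selected logits" routing used by GPT-OSS):
#
#     logits = hidden_states @ router_weights.t()               # [M, N]
#     topk_vals, topk_idx = logits.topk(top_k, dim=-1)          # [M, top_k]
#     topk_weights = torch.softmax(topk_vals.float(), dim=-1)   # renormalize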
@triton.jit
def fused_moe_router_kernel(
    # Pointers
    x_ptr,  # Input [M, K]
    w_ptr,  # Weight [N, K]
    out_w_ptr,  # Output weights [M, TopK]
    out_i_ptr,  # Output indices [M, TopK]
    # Dimensions
    M,
    K,
    N,
    TopK: tl.constexpr,
    # Strides
    stride_xm,
    stride_xk,
    stride_wn,
    stride_wk,
    stride_owm,  # output weights: token stride
    stride_owk,  # output weights: top-k stride
    stride_oim,  # output indices: token stride
    stride_oik,  # output indices: top-k stride
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_N: tl.constexpr,  # Must be >= N (number of experts)
):
    # 1. Program ID: each program handles one block of BLOCK_M tokens.
    pid = tl.program_id(axis=0)

    # 2. Offsets for the token (M) and expert (N) dimensions.
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # 3. Accumulator for the router logits, shape [BLOCK_M, BLOCK_N].
    # Accumulate in float32 for numerical stability.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # 4. GEMM loop over the K dimension: logits = X @ W^T.
    for k in range(0, K, BLOCK_K):
        # Load an input tile X [BLOCK_M, BLOCK_K], masking the M and K edges.
        offs_k = k + tl.arange(0, BLOCK_K)
        x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # Load a weight tile of shape [BLOCK_K, BLOCK_N].
        # PyTorch Linear weights are [out_features, in_features], so W is
        # [N, K] and the math is X @ W^T. Indexing with offs_k along the rows
        # and offs_n along the columns reads a transposed [BLOCK_K, BLOCK_N]
        # tile of W directly, so no explicit transpose is needed before
        # tl.dot (a concrete addressing example follows the loop).
        w_ptrs = w_ptr + (offs_n[None, :] * stride_wn + offs_k[:, None] * stride_wk)
        w_mask = (offs_n[None, :] < N) & (offs_k[:, None] < K)
        w = tl.load(w_ptrs, mask=w_mask, other=0.0)

        # [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] -> [BLOCK_M, BLOCK_N]
        acc += tl.dot(x, w)

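    # Addressing example for the weight tile above (illustrative): with a
    # row-major W of shape [N, K], stride_wn == K and stride_wk == 1, so the
    # tile element at (row kk, col nn) reads W[nn, k + kk] -- exactly the
    # W^T block that tl.dot consumes as its right-hand operand.
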
    # 5. Top-k selection in registers/SRAM.
    # acc now holds the logits [BLOCK_M, BLOCK_N]. BLOCK_N is a power of 2
    # and may exceed N, so mask the padding columns with -inf before top-k.
    logits = tl.where(offs_n[None, :] < N, acc, float("-inf"))

    # TopK is a small constexpr (e.g. 4 for GPT-OSS), so we find the top-k
    # entries by iteratively taking the row max and masking it out. Triton
    # tensors are immutable (no topk_val_storage[:, i] = ... assignment), so
    # each iteration writes its result into column i via a one-hot tl.where.
    offs_topk = tl.arange(0, TopK)
    topk_val_storage = tl.full([BLOCK_M, TopK], float("-inf"), dtype=tl.float32)
    topk_idx_storage = tl.zeros([BLOCK_M, TopK], dtype=tl.int32)

    for i in tl.static_range(TopK):
        # Find the current max along the expert dimension.
        val_max, idx_max = tl.max(logits, axis=1, return_indices=True)

        # Write the result into column i of the storage tensors.
        sel = offs_topk[None, :] == i
        topk_val_storage = tl.where(sel, val_max[:, None], topk_val_storage)
        topk_idx_storage = tl.where(sel, idx_max[:, None], topk_idx_storage)

        # Mask out the selected expert so the next iteration finds the
        # next-largest logit.
        mask = offs_n[None, :] == idx_max[:, None]
        logits = tl.where(mask, float("-inf"), logits)

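    # Worked example (illustrative): with TopK=2 and a logits row of
    # [0.1, 2.0, 1.5], iteration 0 selects (val 2.0, idx 1) into column 0 and
    # masks that slot to -inf; iteration 1 then selects (val 1.5, idx 2) into
    # column 1, leaving that row as topk_val_storage = [2.0, 1.5] and
    # topk_idx_storage = [1, 2].
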
    # 6. Softmax renormalization over the TopK selected logits.

    # Subtract the row max for numerical stability (column 0 already holds
    # it, but recomputing with tl.max keeps the code order-agnostic).
    val_max_for_softmax = tl.max(topk_val_storage, axis=1)
    numerator = tl.exp(topk_val_storage - val_max_for_softmax[:, None])
    denominator = tl.sum(numerator, axis=1)
    softmax_res = numerator / denominator[:, None]

    # 7. Write the outputs ([M, TopK] each), masking rows past the M
    # boundary. The [BLOCK_M, 1] mask broadcasts across the TopK columns.
    output_mask = offs_m[:, None] < M

    out_w_ptrs = out_w_ptr + (
        offs_m[:, None] * stride_owm + offs_topk[None, :] * stride_owk
    )
    out_i_ptrs = out_i_ptr + (
        offs_m[:, None] * stride_oim + offs_topk[None, :] * stride_oik
    )

    tl.store(out_w_ptrs, softmax_res, mask=output_mask)
    tl.store(out_i_ptrs, topk_idx_storage, mask=output_mask)


def fused_router(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        hidden_states: [num_tokens, hidden_size]
        router_weights: [num_experts, hidden_size]
        top_k: number of experts selected per token
    Returns:
        topk_weights: [num_tokens, top_k] routing weights (after softmax)
        topk_indices: [num_tokens, top_k] expert indices
    """
    assert hidden_states.ndim == 2
    assert router_weights.ndim == 2
    # tl.dot requires matching operand dtypes.
    assert hidden_states.dtype == router_weights.dtype
    # tl.arange and Triton block shapes require power-of-2 sizes.
    assert top_k > 0 and (top_k & (top_k - 1)) == 0, "top_k must be a power of 2"

    M, K = hidden_states.shape
    N, _ = router_weights.shape

    # Outputs
    topk_weights = torch.empty(
        (M, top_k), device=hidden_states.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (M, top_k), device=hidden_states.device, dtype=torch.int32
    )

    # Block-size heuristics. BLOCK_N must be a power of 2 that is >= N, so a
    # full row of logits fits in one block, and at least 16 so tl.dot can
    # lower to MMA instructions.
    BLOCK_M = 32
    BLOCK_K = 128
    BLOCK_N = max(triton.next_power_of_2(N), 16)

    # One program per block of BLOCK_M tokens.
    grid = (triton.cdiv(M, BLOCK_M), 1, 1)

    fused_moe_router_kernel[grid](
        hidden_states,
        router_weights,
        topk_weights,
        topk_indices,
        M,
        K,
        N,
        TopK=top_k,
        stride_xm=hidden_states.stride(0),
        stride_xk=hidden_states.stride(1),
        stride_wn=router_weights.stride(0),
        stride_wk=router_weights.stride(1),
        stride_owm=topk_weights.stride(0),
        stride_owk=topk_weights.stride(1),
        stride_oim=topk_indices.stride(0),
        stride_oik=topk_indices.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        BLOCK_N=BLOCK_N,
    )

    return topk_weights, topk_indices
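

if __name__ == "__main__":
    # Minimal sanity check against an unfused PyTorch reference (a sketch:
    # the sizes below are illustrative GPT-OSS-like values, and a CUDA
    # device with Triton support is assumed).
    torch.manual_seed(0)
    x = torch.randn(128, 2880, device="cuda", dtype=torch.float16)
    w = torch.randn(32, 2880, device="cuda", dtype=torch.float16)

    weights, indices = fused_router(x, w, top_k=4)

    ref_logits = x.float() @ w.float().t()
    ref_vals, ref_idx = ref_logits.topk(4, dim=-1)
    ref_weights = torch.softmax(ref_vals, dim=-1)

    # fp16 loads with fp32 accumulation vs. a full fp32 reference, so allow
    # a loose tolerance; index ties are vanishingly unlikely on random data.
    torch.testing.assert_close(weights, ref_weights, rtol=2e-2, atol=2e-2)
    assert (indices.long() == ref_idx).all()
    print("fused_router matches the PyTorch reference")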