29 changes: 29 additions & 0 deletions tests/kernels/moe/test_gpt_oss_fused_router.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import fused_topk_softmax


@pytest.mark.parametrize("M", [1, 32, 128, 2048])
@pytest.mark.parametrize("N", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_fused_router(M, N, topk):
    device = "cuda"
    torch.manual_seed(0)

    logits = torch.randn((M, N), device=device, dtype=torch.float32)

    # Reference: plain PyTorch top-k followed by softmax over the selected logits.
    ref_vals, ref_indices = torch.topk(logits, topk, dim=-1)
    ref_probs = torch.softmax(ref_vals, dim=-1)

    tri_probs, tri_indices = fused_topk_softmax(logits, topk, renormalize=True)

    torch.testing.assert_close(tri_indices.long(), ref_indices)
    torch.testing.assert_close(tri_probs, ref_probs, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
    test_fused_router(128, 32, 2)
    print("Test Passed!")
61 changes: 61 additions & 0 deletions tests/kernels/moe/test_gpt_oss_routing_consistency.py
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_routing_consistency(num_tokens, num_experts, topk):
    torch.manual_seed(42)
    device = torch.device("cuda")

    hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
    router_logits = torch.randn(
        num_tokens, num_experts, device=device, dtype=torch.float32
    )

    # Reference: vLLM's default routing path (no custom routing function).
    ref_weights, ref_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=None,
    )

    # Same call, routed through the Triton kernel.
    triton_weights, triton_ids, _ = FusedMoE.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=True,
        custom_routing_function=gpt_oss_custom_routing_function,
    )

    print(f"\nTesting M={num_tokens}, E={num_experts}, K={topk}")

    torch.testing.assert_close(
        triton_ids,
        ref_ids,
        msg="Expert indices mismatch between native and Triton implementations",
    )

    torch.testing.assert_close(
        triton_weights,
        ref_weights,
        atol=1e-3,
        rtol=1e-3,
        msg="Expert weights mismatch between native and Triton implementations",
    )


if __name__ == "__main__":
    test_routing_consistency(128, 32, 2)
    print("Consistency Test Passed!")
112 changes: 112 additions & 0 deletions vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GPT-OSS MoE router with Triton topk kernel."""

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _topk_softmax_kernel(
    logits_ptr,
    weights_ptr,
    indices_ptr,
    M,
    N,
    topk: tl.constexpr,
    topk_padded: tl.constexpr,
    stride_lm,
    stride_ln,
    stride_wm,
    stride_wk,
    stride_im,
    stride_ik,
    BLOCK_N: tl.constexpr,
    RENORM: tl.constexpr,
):
    # One program per token: load that token's row of router logits.
    token_idx = tl.program_id(0)

    offs = tl.arange(0, BLOCK_N)
    mask = offs < N
    logit_offs = logits_ptr + token_idx * stride_lm + offs * stride_ln
    logits = tl.load(logit_offs, mask=mask, other=float("-inf"))

    topk_vals = tl.zeros([topk_padded], dtype=tl.float32) + float("-inf")
    topk_idxs = tl.zeros([topk_padded], dtype=tl.int32)

    working_logits = logits

    # Iterative selection: take the current max, record it, then mask it
    # out with -inf so the next iteration finds the next-largest logit.
    for k in range(topk):
        cur_max = tl.max(working_logits, axis=0)
        cur_idx = tl.argmax(working_logits, axis=0)

        k_mask = tl.arange(0, topk_padded) == k
        topk_vals = tl.where(k_mask, cur_max, topk_vals)
        topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)

        mask_selected = offs == cur_idx
        working_logits = tl.where(mask_selected, float("-inf"), working_logits)

    if RENORM:
        # Numerically stable softmax over the selected top-k logits only.
        # Padded slots hold -inf and therefore contribute exp(-inf) = 0.
        max_val = tl.max(topk_vals, axis=0)
        exp_vals = tl.exp(topk_vals - max_val)
        sum_exp = tl.sum(exp_vals, axis=0)
        topk_vals = exp_vals / sum_exp

    offs_k = tl.arange(0, topk_padded)

    store_mask = offs_k < topk

    weight_ptrs = weights_ptr + token_idx * stride_wm + offs_k * stride_wk
    tl.store(weight_ptrs, topk_vals, mask=store_mask)

    index_ptrs = indices_ptr + token_idx * stride_im + offs_k * stride_ik
    tl.store(index_ptrs, topk_idxs, mask=store_mask)


def fused_topk_softmax(
    router_logits: torch.Tensor,
    topk: int,
    renormalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = router_logits.shape

    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
    indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)

    # tl.arange requires power-of-two extents, so pad the expert dimension
    # and the per-token top-k buffers up to the next power of two.
    BLOCK_N = triton.next_power_of_2(N)

    topk_padded = triton.next_power_of_2(topk)

    grid = (M,)

    _topk_softmax_kernel[grid](
        logits_ptr=router_logits,
        weights_ptr=weights,
        indices_ptr=indices,
        M=M,
        N=N,
        topk=topk,
        topk_padded=topk_padded,
        stride_lm=router_logits.stride(0),
        stride_ln=router_logits.stride(1),
        stride_wm=weights.stride(0),
        stride_wk=weights.stride(1),
        stride_im=indices.stride(0),
        stride_ik=indices.stride(1),
        BLOCK_N=BLOCK_N,
        RENORM=renormalize,
    )

    return weights, indices


def gpt_oss_custom_routing_function(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Only the gating output is used; hidden_states is accepted to match the
    # custom_routing_function interface but ignored, avoiding padding issues.
    return fused_topk_softmax(gating_output, topk, renormalize)
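
For reference, a minimal eager-PyTorch sketch of the semantics the kernel implements (this mirrors the reference computation in the unit test above; it is an illustration, not part of the PR):

import torch

def topk_softmax_reference(
    router_logits: torch.Tensor, topk: int, renormalize: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    # Top-k per token in descending order, like the kernel's argmax loop.
    vals, idx = torch.topk(router_logits, topk, dim=-1)
    # With renormalize=True, softmax over only the selected logits;
    # otherwise the raw selected logits are returned as weights.
    weights = torch.softmax(vals, dim=-1) if renormalize else vals
    return weights, idx.to(torch.int32)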
8 changes: 8 additions & 0 deletions vllm/model_executor/models/gpt_oss.py
@@ -20,6 +20,9 @@
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
    gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -173,6 +176,11 @@ def __init__(
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
            custom_routing_function=(
                gpt_oss_custom_routing_function
                if not current_platform.is_rocm()
                else None
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor: