74 changes: 74 additions & 0 deletions tests/kernels/moe/benchmark_gpt_oss.py
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.cuda.profiler as profiler

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


def profile_run():
torch.manual_seed(0)
device = "cuda"

test_cases = [
{
"name": "GPTOSS20B",
"desc": "gpt oss 20b prefill",
"M": 4096,
"N": 32,
"topk": 4,
},
]

def run_origin(hidden_states, router_logits, topk):
_ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=False,
renormalize=True,
custom_routing_function=None,
)

def run_triton(hidden_states, router_logits, topk):
_ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=False,
renormalize=True,
custom_routing_function=gpt_oss_custom_routing_function,
)

for case in test_cases:
M, N, topk = case["M"], case["N"], case["topk"]
hidden_states = torch.randn(M, 4096, device=device, dtype=torch.float16)
router_logits = torch.randn(M, N, device=device, dtype=torch.float16)

for i in range(20):
print(f"Starting Global Warmups, Iter {i}")
run_origin(hidden_states, router_logits, topk)
run_triton(hidden_states, router_logits, topk)

torch.cuda.synchronize()
print("Warmup Completed. All kernels are compiled.")

profiler.start()

for case in test_cases:
M, N, topk = case["M"], case["N"], case["topk"]
hidden_states = torch.randn(M, 4096, device=device, dtype=torch.float16)
router_logits = torch.randn(M, N, device=device, dtype=torch.float16)
run_origin(hidden_states, router_logits, topk)
run_triton(hidden_states, router_logits, topk)
torch.cuda.synchronize()

profiler.stop()
print("Benchmark finished.")


if __name__ == "__main__":
profile_run()
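
Note that torch.cuda.profiler.start()/stop() only delimit a capture range for an external profiler (for example Nsight Systems launched with --capture-range=cudaProfilerApi); they do not produce timings on their own. A minimal sketch of timing the two routing paths directly with CUDA events, assuming the run_origin/run_triton helpers above and an arbitrary iteration count:

def time_routing(fn, hidden_states, router_logits, topk, iters=100):
    # Illustrative sketch (not part of the PR): average milliseconds per call
    # for a routing function such as run_origin or run_triton above.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn(hidden_states, router_logits, topk)  # extra warmup so compilation is excluded
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(hidden_states, router_logits, topk)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
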
125 changes: 125 additions & 0 deletions tests/kernels/moe/test_bitonic.py
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.model_executor.layers.fused_moe.triton_bitonic_sort import (
bitonic_ce_descending_wrapper,
bitonic_sort32_descending,
bitonic_sort32_descending_wrapper,
)
from vllm.triton_utils import tl, triton


def test_bitonic_descending():
val = torch.arange(32, dtype=torch.float32, device="cuda")
seq = torch.arange(32, dtype=torch.int32, device="cuda")
new_val = torch.zeros(32, dtype=torch.float32, device="cuda")
new_seq = torch.zeros(32, dtype=torch.int32, device="cuda")
ref_1_seq = torch.tensor(
[
1,
0,
2,
3,
5,
4,
6,
7,
9,
8,
10,
11,
13,
12,
14,
15,
17,
16,
18,
19,
21,
20,
22,
23,
25,
24,
26,
27,
29,
28,
30,
31,
],
dtype=torch.int32,
device="cuda",
)

    # check the stride-1 compare-exchange stage that builds the bitonic sequence:
    # adjacent pairs are ordered in alternating directions, so the ascending input
    # becomes [1, 0, 2, 3, 5, 4, 6, 7, ...]
bitonic_ce_descending_wrapper[(1,)](val, seq, new_val, new_seq, 1)
torch.testing.assert_close(new_seq, ref_1_seq)

# assert final sort result
bitonic_sort32_descending_wrapper[(1,)](val, seq, new_val, new_seq)
seq = seq.flip(0)
torch.testing.assert_close(new_seq, seq)


@triton.jit
def test_bitonic_2d_kernel(
in_ptr,
out_val_ptr,
out_idx_ptr,
ROWS: tl.constexpr,
):
offs_row = tl.arange(0, ROWS)
offs_col = tl.arange(0, 32)

vals = tl.load(in_ptr + offs_row[:, None] * 32 + offs_col[None, :]) # [ROWS, 32]

idxs = tl.broadcast_to(offs_col[None, :], (ROWS, 32)).to(tl.int32) # [ROWS, 32]

sorted_vals, sorted_idxs = bitonic_sort32_descending(vals, idxs)

tl.store(out_val_ptr + offs_row[:, None] * 32 + offs_col[None, :], sorted_vals)
tl.store(out_idx_ptr + offs_row[:, None] * 32 + offs_col[None, :], sorted_idxs)


def test_bitonic_multirow():
for ROWS in [1, 2, 4, 8]:
torch.manual_seed(42)
x = torch.randn(ROWS, 32, device="cuda", dtype=torch.float32)
out_vals = torch.empty_like(x)
out_idxs = torch.empty(ROWS, 32, device="cuda", dtype=torch.int32)

        # the launch assumes num_warps >= ROWS
test_bitonic_2d_kernel[(1,)](
x,
out_vals,
out_idxs,
ROWS=ROWS,
num_warps=max(ROWS, 4),
)

expected_vals, expected_idxs = x.sort(dim=1, descending=True)

vals_match = torch.allclose(out_vals, expected_vals)
idxs_match = torch.equal(out_idxs, expected_idxs.to(torch.int32))

print(f"values match: {vals_match}")
print(f"indices match: {idxs_match}")

if not vals_match or not idxs_match:
print("input:")
print(x)
print("result vals:")
print(out_vals)
print("expected vals:")
print(expected_vals)
print("result idxs:")
print(out_idxs)
print("expected idxs:")
            print(expected_idxs)

        assert vals_match and idxs_match, f"bitonic sort mismatch for ROWS={ROWS}"


if __name__ == "__main__":
test_bitonic_multirow()
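
For reference, a plain-Python bitonic network with the direction convention inferred from ref_1_seq above (an assumption about the kernel's staging, not its actual code) reproduces both checks: its first pass (k=2, j=1) turns an ascending input into [1, 0, 2, 3, 5, 4, ...], and the full network sorts descending.

def bitonic_sort32_descending_ref(vals, idxs):
    # Illustrative reference: sort 32 values descending, carrying indices along.
    # Subsequence direction is descending when (i & k) == 0, matching ref_1_seq.
    vals, idxs = list(vals), list(idxs)
    k = 2
    while k <= 32:
        j = k // 2
        while j >= 1:
            for i in range(32):
                partner = i ^ j
                if partner > i:
                    descending = (i & k) == 0
                    if (vals[i] < vals[partner]) == descending:
                        vals[i], vals[partner] = vals[partner], vals[i]
                        idxs[i], idxs[partner] = idxs[partner], idxs[i]
            j //= 2
        k *= 2
    return vals, idxs

On vals = idxs = list(range(32)) it returns indices 31..0, the same flipped sequence asserted in test_bitonic_descending.
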
101 changes: 101 additions & 0 deletions tests/kernels/moe/test_gpt_oss_routing_consistency.py
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
gpt_oss_custom_routing_function,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4])
@pytest.mark.parametrize("renorm", [True, False])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="only available on CUDA")
def test_routing_consistency(num_tokens, num_experts, topk, renorm):
torch.manual_seed(42)
device = torch.device("cuda")
hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
router_logits = torch.randn(
num_tokens, num_experts, device=device, dtype=torch.float32
)

ref_weights, ref_ids, _ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=False,
renormalize=renorm,
custom_routing_function=None,
)

triton_weights, triton_ids, _ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=False,
renormalize=renorm,
custom_routing_function=gpt_oss_custom_routing_function,
)

# compare triton with origin
torch.testing.assert_close(
triton_ids,
ref_ids,
msg="Expert indices mismatch between origin and triton implementation",
)
torch.testing.assert_close(
triton_weights,
ref_weights,
atol=1e-3,
rtol=1e-3,
msg="Expert weights mismatch between origin and triton implementation",
)
expected_indices_dtype = ref_ids.dtype
    expected_weight_dtype = ref_weights.dtype

def native_impl(logits, topk, renorm):
if renorm:
ref_vals, ref_indices = torch.topk(logits, topk, dim=1)
ref_vals = torch.softmax(ref_vals, dim=1)
else:
ref_vals = torch.softmax(logits, dim=1)
ref_vals, ref_indices = torch.topk(ref_vals, topk, dim=1)
        return ref_vals.to(expected_weight_dtype), ref_indices.to(
            expected_indices_dtype
        )

native_weights, native_ids = native_impl(router_logits, topk, renorm)

# compare triton with torch
torch.testing.assert_close(
triton_ids,
native_ids,
msg="Expert indices mismatch between native and triton implementation",
)
torch.testing.assert_close(
triton_weights,
native_weights,
atol=1e-3,
rtol=1e-3,
msg="Expert weights mismatch between native and triton implementation",
)

# compare origin with torch
torch.testing.assert_close(
native_ids,
ref_ids,
msg="Expert indices mismatch between origin and native implementation",
)
torch.testing.assert_close(
native_weights,
ref_weights,
atol=1e-3,
rtol=1e-3,
msg="Expert weights mismatch between origin and native implementation",
)

print(f"\nTesting TOKENS={num_tokens}, EXPERTS={num_experts}, TOPK={topk}")