diff --git a/tests/kernels/moe/benchmark_gpt_oss.py b/tests/kernels/moe/benchmark_gpt_oss.py
new file mode 100644
index 000000000000..9793de4a976e
--- /dev/null
+++ b/tests/kernels/moe/benchmark_gpt_oss.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.cuda.profiler as profiler
+
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
+    gpt_oss_custom_routing_function,
+)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+
+def profile_run():
+    torch.manual_seed(0)
+    device = "cuda"
+
+    test_cases = [
+        {
+            "name": "GPTOSS20B",
+            "desc": "gpt oss 20b prefill",
+            "M": 4096,
+            "N": 32,
+            "topk": 4,
+        },
+    ]
+
+    def run_origin(hidden_states, router_logits, topk):
+        _ = FusedMoE.select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=topk,
+            use_grouped_topk=False,
+            renormalize=True,
+            custom_routing_function=None,
+        )
+
+    def run_triton(hidden_states, router_logits, topk):
+        _ = FusedMoE.select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=topk,
+            use_grouped_topk=False,
+            renormalize=True,
+            custom_routing_function=gpt_oss_custom_routing_function,
+        )
+
+    for case in test_cases:
+        M, N, topk = case["M"], case["N"], case["topk"]
+        hidden_states = torch.randn(M, 4096, device=device, dtype=torch.float16)
+        router_logits = torch.randn(M, N, device=device, dtype=torch.float16)
+
+        for i in range(20):
+            print(f"Starting global warmup, iter {i}")
+            run_origin(hidden_states, router_logits, topk)
+            run_triton(hidden_states, router_logits, topk)
+
+    torch.cuda.synchronize()
+    print("Warmup completed; all kernels are compiled.")
+
+    profiler.start()
+
+    for case in test_cases:
+        M, N, topk = case["M"], case["N"], case["topk"]
+        hidden_states = torch.randn(M, 4096, device=device, dtype=torch.float16)
+        router_logits = torch.randn(M, N, device=device, dtype=torch.float16)
+        run_origin(hidden_states, router_logits, topk)
+        run_triton(hidden_states, router_logits, topk)
+        torch.cuda.synchronize()
+
+    profiler.stop()
+    print("Benchmark finished.")
+
+
+if __name__ == "__main__":
+    profile_run()
diff --git a/tests/kernels/moe/test_bitonic.py b/tests/kernels/moe/test_bitonic.py
new file mode 100644
index 000000000000..6ac5e2fdb9ef
--- /dev/null
+++ b/tests/kernels/moe/test_bitonic.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.model_executor.layers.fused_moe.triton_bitonic_sort import (
+    bitonic_ce_descending_wrapper,
+    bitonic_sort32_descending,
+    bitonic_sort32_descending_wrapper,
+)
+from vllm.triton_utils import tl, triton
+
+
+def test_bitonic_descending():
+    val = torch.arange(32, dtype=torch.float32, device="cuda")
+    seq = torch.arange(32, dtype=torch.int32, device="cuda")
+    new_val = torch.zeros(32, dtype=torch.float32, device="cuda")
+    new_seq = torch.zeros(32, dtype=torch.int32, device="cuda")
+    ref_1_seq = torch.tensor(
+        [
+            1,
+            0,
+            2,
+            3,
+            5,
+            4,
+            6,
+            7,
+            9,
+            8,
+            10,
+            11,
+            13,
+            12,
+            14,
+            15,
+            17,
+            16,
+            18,
+            19,
+            21,
+            20,
+            22,
+            23,
+            25,
+            24,
+            26,
+            27,
+            29,
+            28,
+            30,
+            31,
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )
+
+    # assert the stride-1 step is correct when constructing the bitonic sequence
+    bitonic_ce_descending_wrapper[(1,)](val, seq, new_val, new_seq, 1)
+    torch.testing.assert_close(new_seq, ref_1_seq)
+
+    # assert the final sort result
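+    # a descending sort of an ascending arange must reverse the order, so the
+    # flipped input sequence below serves as the reference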
+    bitonic_sort32_descending_wrapper[(1,)](val, seq, new_val, new_seq)
+    seq = seq.flip(0)
+    torch.testing.assert_close(new_seq, seq)
+
+
+@triton.jit
+def test_bitonic_2d_kernel(
+    in_ptr,
+    out_val_ptr,
+    out_idx_ptr,
+    ROWS: tl.constexpr,
+):
+    offs_row = tl.arange(0, ROWS)
+    offs_col = tl.arange(0, 32)
+
+    vals = tl.load(in_ptr + offs_row[:, None] * 32 + offs_col[None, :])  # [ROWS, 32]
+
+    idxs = tl.broadcast_to(offs_col[None, :], (ROWS, 32)).to(tl.int32)  # [ROWS, 32]
+
+    sorted_vals, sorted_idxs = bitonic_sort32_descending(vals, idxs)
+
+    tl.store(out_val_ptr + offs_row[:, None] * 32 + offs_col[None, :], sorted_vals)
+    tl.store(out_idx_ptr + offs_row[:, None] * 32 + offs_col[None, :], sorted_idxs)
+
+
+def test_bitonic_multirow():
+    for ROWS in [1, 2, 4, 8]:
+        torch.manual_seed(42)
+        x = torch.randn(ROWS, 32, device="cuda", dtype=torch.float32)
+        out_vals = torch.empty_like(x)
+        out_idxs = torch.empty(ROWS, 32, device="cuda", dtype=torch.int32)
+
+        # assumes num_warps >= ROWS so each row is handled by a full warp
+        test_bitonic_2d_kernel[(1,)](
+            x,
+            out_vals,
+            out_idxs,
+            ROWS=ROWS,
+            num_warps=max(ROWS, 4),
+        )
+
+        expected_vals, expected_idxs = x.sort(dim=1, descending=True)
+
+        vals_match = torch.allclose(out_vals, expected_vals)
+        idxs_match = torch.equal(out_idxs, expected_idxs.to(torch.int32))
+
+        print(f"values match: {vals_match}")
+        print(f"indices match: {idxs_match}")
+
+        if not vals_match or not idxs_match:
+            print("input:")
+            print(x)
+            print("result vals:")
+            print(out_vals)
+            print("expected vals:")
+            print(expected_vals)
+            print("result idxs:")
+            print(out_idxs)
+            print("expected idxs:")
+            print(expected_idxs)
+
+        # fail the test instead of only printing the mismatch
+        assert vals_match and idxs_match
+
+
+if __name__ == "__main__":
+    test_bitonic_multirow()
diff --git a/tests/kernels/moe/test_gpt_oss_routing_consistency.py b/tests/kernels/moe/test_gpt_oss_routing_consistency.py
new file mode 100644
index 000000000000..a781c481205b
--- /dev/null
+++ b/tests/kernels/moe/test_gpt_oss_routing_consistency.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
+    gpt_oss_custom_routing_function,
+)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.platforms import current_platform
+
+
+@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
+@pytest.mark.parametrize("num_experts", [32, 65, 128])
+@pytest.mark.parametrize("topk", [1, 2, 3, 4])
+@pytest.mark.parametrize("renorm", [True, False])
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="only available on CUDA")
+def test_routing_consistency(num_tokens, num_experts, topk, renorm):
+    print(f"\nTesting TOKENS={num_tokens}, EXPERTS={num_experts}, TOPK={topk}")
+    torch.manual_seed(42)
+    device = torch.device("cuda")
+    hidden_states = torch.randn(num_tokens, 4096, device=device, dtype=torch.float16)
+    router_logits = torch.randn(
+        num_tokens, num_experts, device=device, dtype=torch.float32
+    )
+
+    ref_weights, ref_ids, _ = FusedMoE.select_experts(
+        hidden_states=hidden_states,
+        router_logits=router_logits,
+        top_k=topk,
+        use_grouped_topk=False,
+        renormalize=renorm,
+        custom_routing_function=None,
+    )
+
+    triton_weights, triton_ids, _ = FusedMoE.select_experts(
+        hidden_states=hidden_states,
+        router_logits=router_logits,
+        top_k=topk,
+        use_grouped_topk=False,
+        renormalize=renorm,
+        custom_routing_function=gpt_oss_custom_routing_function,
+    )
+
+    # compare triton with origin
+    torch.testing.assert_close(
+        triton_ids,
+        ref_ids,
+        msg="Expert indices mismatch between origin and triton implementation",
+    )
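+    # the kernel computes softmax in fp32, so compare weights with a small
+    # floating-point tolerance rather than exact equality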
+    torch.testing.assert_close(
+        triton_weights,
+        ref_weights,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Expert weights mismatch between origin and triton implementation",
+    )
+    expected_indices_dtype = ref_ids.dtype
+    expected_weight_dtype = ref_weights.dtype
+
+    def native_impl(logits, topk, renorm):
+        if renorm:
+            ref_vals, ref_indices = torch.topk(logits, topk, dim=1)
+            ref_vals = torch.softmax(ref_vals, dim=1)
+        else:
+            ref_vals = torch.softmax(logits, dim=1)
+            ref_vals, ref_indices = torch.topk(ref_vals, topk, dim=1)
+        return ref_vals.to(expected_weight_dtype), ref_indices.to(
+            expected_indices_dtype
+        )
+
+    native_weights, native_ids = native_impl(router_logits, topk, renorm)
+
+    # compare triton with torch
+    torch.testing.assert_close(
+        triton_ids,
+        native_ids,
+        msg="Expert indices mismatch between native and triton implementation",
+    )
+    torch.testing.assert_close(
+        triton_weights,
+        native_weights,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Expert weights mismatch between native and triton implementation",
+    )
+
+    # compare origin with torch
+    torch.testing.assert_close(
+        native_ids,
+        ref_ids,
+        msg="Expert indices mismatch between origin and native implementation",
+    )
+    torch.testing.assert_close(
+        native_weights,
+        ref_weights,
+        atol=1e-3,
+        rtol=1e-3,
+        msg="Expert weights mismatch between origin and native implementation",
+    )
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py b/vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py
new file mode 100644
index 000000000000..3a64e648ab88
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_fused_router.py
@@ -0,0 +1,339 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.triton_bitonic_sort import (
+    bitonic_sort32_descending,
+)
+from vllm.triton_utils import tl, triton
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"ROWS_PER_PID": r}, num_warps=num_warps, num_stages=num_stages)
+        for r in [1, 2, 4, 8, 16, 32]
+        for num_warps in [1, 2, 4, 8, 16]
+        for num_stages in [1, 2, 3]
+    ],
+    key=["N", "topk"],
+    cache_results=True,
+)
+@triton.jit
+def _topk_softmax_kernel(
+    logits_ptr,
+    weights_ptr,
+    indices_ptr,
+    M: tl.constexpr,
+    N: tl.constexpr,
+    topk: tl.constexpr,
+    stride_lm,
+    stride_ln,
+    stride_wm,
+    stride_wk,
+    stride_im,
+    stride_ik,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    topk_padded: tl.constexpr,
+    RENORM: tl.constexpr,
+    ROWS_PER_PID: tl.constexpr,
+    num_stages: tl.constexpr,
+    USE_BITONIC: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_programs = tl.num_programs(0)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, topk_padded)
+    mask_n = offs_n < N
+    store_mask = offs_k < topk
+    warp_size: tl.constexpr = 32
+
+    # specialize the topk <= 2 and RENORM paths via tl.constexpr,
+    # analogous to `if constexpr` in C++17
+    if topk == 1:
+        for row_idx in tl.range(
+            pid, M, num_programs, num_stages, warp_specialize=True
+        ):
+            if BLOCK_N != N:
+                logits = tl.load(
+                    logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                    mask=mask_n,
+                    other=float("-inf"),
+                )
+            else:
+                logits = tl.load(
+                    logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            cur_max = 1 if RENORM else tl.max(logits, axis=0)
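+            # when RENORM, softmax over the single selected logit is
+            # identically 1, hence the constant weight above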
+            cur_idx = tl.argmax(logits, axis=0)
+
+            tl.store(weights_ptr + row_idx * stride_wm + 0 * stride_wk, cur_max)
+            tl.store(indices_ptr + row_idx * stride_im + 0 * stride_ik, cur_idx)
+
+    elif topk == 2:
+        for row_idx in tl.range(
+            pid, M, num_programs, num_stages, warp_specialize=True
+        ):
+            if BLOCK_N != N:
+                logits = tl.load(
+                    logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                    mask=mask_n,
+                    other=float("-inf"),
+                )
+            else:
+                logits = tl.load(
+                    logits_ptr + row_idx * stride_lm + offs_n * stride_ln,
+                )
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(logits, axis=0)
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(numerator, axis=0)
+                logits = numerator / denominator
+
+            val0 = tl.max(logits, axis=0)
+            idx0 = tl.argmax(logits, axis=0)
+            logits = tl.where(offs_n == idx0, float("-inf"), logits)
+            val1 = tl.max(logits, axis=0)
+            idx1 = tl.argmax(logits, axis=0)
+
+            if RENORM:
+                max_val = tl.maximum(val0, val1)
+                exp0 = tl.exp(val0 - max_val)
+                exp1 = tl.exp(val1 - max_val)
+                val0 = exp0 / (exp0 + exp1)
+                val1 = exp1 / (exp0 + exp1)
+
+            tl.store(weights_ptr + row_idx * stride_wm, val0)
+            tl.store(indices_ptr + row_idx * stride_im, idx0)
+            tl.store(weights_ptr + row_idx * stride_wm + 1 * stride_wk, val1)
+            tl.store(indices_ptr + row_idx * stride_im + 1 * stride_ik, idx1)
+
+    else:
+        rows = tl.arange(0, ROWS_PER_PID)
+        for row_idx in tl.range(
+            pid * ROWS_PER_PID,
+            M,
+            num_programs * ROWS_PER_PID,
+            num_stages,
+            warp_specialize=True,
+        ):
+            topk_vals = tl.full(
+                [ROWS_PER_PID, topk_padded], float("-inf"), dtype=tl.float32
+            )
+            topk_idxs = tl.zeros([ROWS_PER_PID, topk_padded], dtype=tl.int32)
+            row_indices = row_idx + rows  # [ROWS_PER_PID,]
+            row_mask = row_indices < M
+
+            # broadcast to [ROWS_PER_PID, BLOCK_N]
+            ptr_off = (
+                logits_ptr
+                + row_indices[:, None] * stride_lm
+                + offs_n[None, :] * stride_ln
+            )
+            if BLOCK_N == N and BLOCK_M == M:
+                logits = tl.load(ptr_off)
+            elif BLOCK_N != N and BLOCK_M != M:
+                logits = tl.load(
+                    ptr_off,
+                    mask=row_mask[:, None] & mask_n[None, :],
+                    other=float("-inf"),
+                )
+            elif BLOCK_N != N:
+                logits = tl.load(ptr_off, mask=mask_n[None, :], other=float("-inf"))
+            elif BLOCK_M != M:
+                logits = tl.load(ptr_off, mask=row_mask[:, None], other=float("-inf"))
+
+            if not RENORM:
+                row_sub_max = logits - tl.max(
+                    logits, axis=1, keep_dims=True
+                )  # [ROWS_PER_PID, BLOCK_N] - [ROWS_PER_PID, 1]
+                numerator = tl.exp(row_sub_max)
+                denominator = tl.sum(
+                    numerator, axis=1, keep_dims=True
+                )  # [ROWS_PER_PID, 1]
+                logits = numerator / denominator
+
+            if warp_size == N:
+                # this branch corresponds to USE_BITONIC == True
+                idx = tl.arange(0, warp_size)[None, :]
+                idxs = tl.broadcast_to(idx, (ROWS_PER_PID, warp_size))
+                sorted_val, sorted_idx = bitonic_sort32_descending(
+                    val=logits, idx=idxs
+                )  # [ROWS_PER_PID, 32]
+                tl.static_assert(sorted_val.shape == (ROWS_PER_PID, warp_size))
+            else:
+                # this branch corresponds to USE_BITONIC == False
+                for k in tl.static_range(topk):
+                    cur_max = tl.max(
+                        logits, axis=1, keep_dims=True
+                    )  # [ROWS_PER_PID, 1]
+                    cur_idx = tl.argmax(logits, axis=1, keep_dims=True)
+
+                    k_mask = offs_k == k
+                    topk_vals = tl.where(
+                        k_mask, cur_max, topk_vals
+                    )  # [ROWS_PER_PID, 1] vs [ROWS_PER_PID, topk_padded]
+                    topk_idxs = tl.where(k_mask, cur_idx, topk_idxs)
+
+                    mask_selected = (
+                        cur_idx == offs_n[None, :]
+                    )  # [ROWS_PER_PID, 1] vs [1, BLOCK_N]
+                    logits = tl.where(mask_selected, float("-inf"), logits)
+
+            if RENORM:
+                if USE_BITONIC:
+                    topk_col_mask = (
+                        tl.arange(0, warp_size)[None, :] < topk
+                    )  # [1, warp_size]
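+                    # renormalize via a softmax restricted to the first topk
+                    # sorted columns; columns beyond topk keep their original
+                    # values and are masked out at store time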
+                    masked_val = tl.where(topk_col_mask, sorted_val, float("-inf"))
+                    masked_val = masked_val - tl.max(
+                        masked_val, axis=1, keep_dims=True
+                    )
+                    numerator = tl.exp(masked_val)
+                    numerator = tl.where(topk_col_mask, numerator, 0.0)
+                    denominator = tl.sum(numerator, axis=1, keep_dims=True)
+                    sorted_val = tl.where(
+                        topk_col_mask, numerator / denominator, sorted_val
+                    )
+                else:
+                    topk_vals = topk_vals - tl.max(
+                        topk_vals, axis=1, keep_dims=True
+                    )  # [ROWS_PER_PID, topk_padded] - [ROWS_PER_PID, 1]
+                    numerator = tl.exp(topk_vals)
+                    denominator = tl.sum(
+                        numerator, axis=1, keep_dims=True
+                    )  # [ROWS_PER_PID, 1]
+                    topk_vals = numerator / denominator  # [ROWS_PER_PID, topk_padded]
+
+            if USE_BITONIC:
+                offs_warp_size = tl.arange(0, warp_size)
+                store_col_mask = offs_warp_size < topk
+                tl.store(
+                    weights_ptr
+                    + row_indices[:, None] * stride_wm
+                    + offs_warp_size[None, :] * stride_wk,
+                    sorted_val,
+                    mask=row_mask[:, None] & store_col_mask[None, :],
+                )
+                tl.store(
+                    indices_ptr
+                    + row_indices[:, None] * stride_im
+                    + offs_warp_size[None, :] * stride_ik,
+                    sorted_idx,
+                    mask=row_mask[:, None] & store_col_mask[None, :],
+                )
+            else:
+                if topk == topk_padded and BLOCK_M == M:
+                    tl.store(
+                        weights_ptr
+                        + row_indices[:, None] * stride_wm  # [ROWS_PER_PID, 1]
+                        + offs_k[None, :] * stride_wk,  # [1, topk_padded]
+                        topk_vals,
+                    )
+                    tl.store(
+                        indices_ptr
+                        + row_indices[:, None] * stride_im
+                        + offs_k[None, :] * stride_ik,
+                        topk_idxs,
+                    )
+                elif topk != topk_padded and BLOCK_M != M:
+                    tl.store(
+                        weights_ptr
+                        + row_indices[:, None] * stride_wm  # [ROWS_PER_PID, 1]
+                        + offs_k[None, :] * stride_wk,  # [1, topk_padded]
+                        topk_vals,
+                        mask=row_mask[:, None] & store_mask[None, :],
+                    )
+                    tl.store(
+                        indices_ptr
+                        + row_indices[:, None] * stride_im
+                        + offs_k[None, :] * stride_ik,
+                        topk_idxs,
+                        mask=row_mask[:, None] & store_mask[None, :],
+                    )
+                elif topk != topk_padded:
+                    tl.store(
+                        weights_ptr
+                        + row_indices[:, None] * stride_wm  # [ROWS_PER_PID, 1]
+                        + offs_k[None, :] * stride_wk,  # [1, topk_padded]
+                        topk_vals,
+                        mask=store_mask[None, :],  # [1, topk_padded]
+                    )
+                    tl.store(
+                        indices_ptr
+                        + row_indices[:, None] * stride_im
+                        + offs_k[None, :] * stride_ik,
+                        topk_idxs,
+                        mask=store_mask[None, :],
+                    )
+                elif BLOCK_M != M:
+                    tl.store(
+                        weights_ptr
+                        + row_indices[:, None] * stride_wm  # [ROWS_PER_PID, 1]
+                        + offs_k[None, :] * stride_wk,  # [1, topk_padded]
+                        topk_vals,
+                        mask=row_mask[:, None],
+                    )
+                    tl.store(
+                        indices_ptr
+                        + row_indices[:, None] * stride_im
+                        + offs_k[None, :] * stride_ik,
+                        topk_idxs,
+                        mask=row_mask[:, None],
+                    )
+
+
+def fused_topk_softmax(
+    router_logits: torch.Tensor,
+    topk: int,
+    renormalize: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    M, N = router_logits.shape  # num_tokens, num_experts
+    weights = torch.empty((M, topk), device=router_logits.device, dtype=torch.float32)
+    indices = torch.empty((M, topk), device=router_logits.device, dtype=torch.int32)
+
+    BLOCK_N = triton.next_power_of_2(N)  # num_padded_experts
+    topk_padded = triton.next_power_of_2(topk)
+    BLOCK_M = triton.next_power_of_2(M)
+    warp_size = 32
+
+    # let autotune pick the number of thread blocks; see
+    # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+    grid = lambda META: (triton.cdiv(M, META["ROWS_PER_PID"]),)
+
+    _topk_softmax_kernel[grid](
+        logits_ptr=router_logits,
+        weights_ptr=weights,
+        indices_ptr=indices,
+        M=M,
+        N=N,
+        topk=topk,
+        stride_lm=router_logits.stride(0),
+        stride_ln=router_logits.stride(1),
+        stride_wm=weights.stride(0),
+        stride_wk=weights.stride(1),
+        stride_im=indices.stride(0),
+        stride_ik=indices.stride(1),
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        topk_padded=topk_padded,
+        RENORM=renormalize,
+        USE_BITONIC=topk > 2 and warp_size == N,
+    )
+
+    return weights, indices
+
+
+def gpt_oss_custom_routing_function(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # use only gating_output to avoid padding issues
+    assert gating_output.is_contiguous()
+    return fused_topk_softmax(gating_output, topk, renormalize)
diff --git a/vllm/model_executor/layers/fused_moe/triton_bitonic_sort.py b/vllm/model_executor/layers/fused_moe/triton_bitonic_sort.py
new file mode 100644
index 000000000000..e90d02ef6cdc
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/triton_bitonic_sort.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Values and indices are kept in 32-bit registers for the warp shuffles below;
+from the PTX ISA docs:
+https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=shfl#restricted-use-of-sub-word-sizes
+8-bit or 16-bit values may be held directly in 32-bit or 64-bit registers
+when being loaded, stored, or converted to other types and sizes.
+"""
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def bitonic_ce_descending(
+    val, idx, stride: tl.constexpr, log2_length_pair: tl.constexpr
+):
+    new_val, new_idx = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .f32 %partner_val;
+            .reg .s32 %partner_idx;
+            .reg .u32 %lane_id;
+            .reg .u32 %is_left;
+            .reg .pred %p_left, %p_swap;
+            .reg .u32 %group_id;
+            .reg .pred %group_id_mask;
+
+            // fetch the partner lane's val/idx via butterfly shuffle
+            shfl.sync.bfly.b32 %partner_val, $2, $4, 0x1f, 0xffffffff;
+            shfl.sync.bfly.b32 %partner_idx, $3, $4, 0x1f, 0xffffffff;
+
+            mov.u32 %lane_id, %laneid;
+
+            shr.u32 %group_id, %lane_id, $5;
+            and.b32 %group_id, %group_id, 1;
+            setp.eq.u32 %group_id_mask, %group_id, 1;
+
+            and.b32 %is_left, %lane_id, $4;
+            setp.eq.u32 %p_left, %is_left, 0;
+
+            // compare: partner_val > val? if so, swap.
+            setp.gt.f32 %p_swap, %partner_val, $2;
+
+            // TODO(ijpq):
+            // this logic might be redundant and could be simplified further.
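+            // net effect of the three predicate ops below: the lower lane of
+            // each exchange pair keeps the larger value (descending), while
+            // odd groups invert the direction (ascending), producing the
+            // bitonic pattern required by the next merge step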
+            xor.pred %p_swap, %p_swap, %p_left;
+            not.pred %p_swap, %p_swap;
+            xor.pred %p_swap, %p_swap, %group_id_mask;
+
+            selp.f32 $0, %partner_val, $2, %p_swap;
+            selp.b32 $1, %partner_idx, $3, %p_swap;
+        }
+        """,
+        constraints="=f,=r,f,r,n,n",
+        args=[val, idx, stride, log2_length_pair],
+        dtype=(tl.float32, tl.int32),
+        is_pure=True,
+        pack=1,
+    )
+    return new_val, new_idx
+
+
+@triton.jit
+def bitonic_compare_across_part_descending(val, idx, stride: tl.constexpr):
+    new_val, new_idx = tl.inline_asm_elementwise(
+        asm="""{
+            .reg .f32 %partner_val;
+            .reg .s32 %partner_idx;
+            .reg .u32 %is_left;
+            .reg .pred %p_left;
+            .reg .pred %p_swap;
+            .reg .u32 %lane_id;
+            // $2 val, $3 idx, $4 stride;
+
+            shfl.sync.bfly.b32 %partner_val, $2, $4, 0x1f, 0xffffffff;
+            shfl.sync.bfly.b32 %partner_idx, $3, $4, 0x1f, 0xffffffff;
+
+            mov.u32 %lane_id, %laneid;
+            and.b32 %is_left, %lane_id, $4;
+            setp.eq.u32 %p_left, %is_left, 0;
+            setp.gt.f32 %p_swap, %partner_val, $2;
+            xor.pred %p_swap, %p_swap, %p_left;
+            not.pred %p_swap, %p_swap;
+
+            selp.f32 $0, %partner_val, $2, %p_swap;
+            selp.b32 $1, %partner_idx, $3, %p_swap;
+        }""",
+        constraints="=f,=r,f,r,n",
+        args=[val, idx, stride],
+        dtype=(tl.float32, tl.int32),
+        is_pure=True,
+        pack=1,
+    )
+    return new_val, new_idx
+
+
+@triton.jit
+def bitonic_sort32_descending(val, idx):
+    # length_pair = 2, log2(2) = 1
+    val, idx = bitonic_ce_descending(val, idx, 1, 1)
+
+    # length_pair = 4, log2(4) = 2
+    val, idx = bitonic_ce_descending(val, idx, 2, 2)
+    val, idx = bitonic_ce_descending(val, idx, 1, 2)
+
+    # length_pair = 8, log2(8) = 3
+    val, idx = bitonic_ce_descending(val, idx, 4, 3)
+    val, idx = bitonic_ce_descending(val, idx, 2, 3)
+    val, idx = bitonic_ce_descending(val, idx, 1, 3)
+
+    # length_pair = 16, log2(16) = 4
+    val, idx = bitonic_ce_descending(val, idx, 8, 4)
+    val, idx = bitonic_ce_descending(val, idx, 4, 4)
+    val, idx = bitonic_ce_descending(val, idx, 2, 4)
+    val, idx = bitonic_ce_descending(val, idx, 1, 4)
+
+    # length_pair = 32, log2(32) = 5
+    val, idx = bitonic_compare_across_part_descending(val, idx, 16)
+    val, idx = bitonic_compare_across_part_descending(val, idx, 8)
+    val, idx = bitonic_compare_across_part_descending(val, idx, 4)
+    val, idx = bitonic_compare_across_part_descending(val, idx, 2)
+    val, idx = bitonic_compare_across_part_descending(val, idx, 1)
+
+    return val, idx
+
+
+@triton.jit
+def bitonic_ce_descending_wrapper(
+    val_ptr, idx_ptr, new_val_ptr, new_idx_ptr, stride: tl.constexpr
+):
+    offs = tl.arange(0, 32)
+    val = tl.load(val_ptr + offs)
+    idx = tl.load(idx_ptr + offs)
+    new_val, new_idx = bitonic_ce_descending(val, idx, stride, 1)
+    tl.store(new_val_ptr + offs, new_val)
+    tl.store(new_idx_ptr + offs, new_idx)
+
+
+@triton.jit
+def bitonic_sort32_descending_wrapper(val_ptr, idx_ptr, new_val_ptr, new_idx_ptr):
+    offs = tl.arange(0, 32)
+    val = tl.load(val_ptr + offs)
+    idx = tl.load(idx_ptr + offs)
+    new_val, new_idx = bitonic_sort32_descending(val, idx)
+    tl.store(new_val_ptr + offs, new_val)
+    tl.store(new_idx_ptr + offs, new_idx)
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 7df3b087ccb8..7aaa63bf76e7 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -20,6 +20,9 @@
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.gpt_oss_fused_router import (
+    gpt_oss_custom_routing_function,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -173,6 +176,9 @@ def __init__(
             has_bias=True,
             activation="swigluoai",
             is_sequence_parallel=self.is_sequence_parallel,
+            custom_routing_function=(
+                gpt_oss_custom_routing_function if current_platform.is_cuda() else None
+            ),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor: