diff --git a/CMakeLists.txt b/CMakeLists.txt index e09972fe7199..df18d05e8ce1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -945,7 +945,8 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_lora_align_sum_kernels.cu" - "csrc/moe/topk_softmax_kernels.cu") + "csrc/moe/topk_softmax_kernels.cu" + "csrc/moe/moe_fused_gate.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC diff --git a/benchmarks/kernels/benchmark_moe_fused_gate.py b/benchmarks/kernels/benchmark_moe_fused_gate.py new file mode 100644 index 000000000000..0e90d73551f9 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_fused_gate.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm._custom_ops import moe_fused_gate +from vllm.model_executor.layers.fused_moe.fused_moe import ( + grouped_topk as vllm_compiled_grouped_topk, +) +from vllm.triton_utils import triton + + +def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk): + return vllm_compiled_grouped_topk( + hidden_states=scores, + gating_output=scores, + topk=topk, + renormalize=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func="sigmoid", + e_score_correction_bias=bias, + ) + + +def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk): + return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk) + + +seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000] +configs = [(sq,) for sq in seq_length_range] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["seq_length"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["original", "kernel"], + line_names=["Original", "SGL Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="moe-fused-gate-performance", + args={}, + ) +) +def benchmark(seq_length, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8 + + scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype) + bias = torch.rand(num_experts, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "original": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: biased_grouped_topk_org( + scores.clone(), bias.clone(), num_expert_group, topk_group, topk + ), + quantiles=quantiles, + ) + elif provider == "kernel": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: biased_grouped_topk_org_kernel( + scores.clone(), bias.clone(), num_expert_group, topk_group, topk + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/csrc/moe/moe_fused_gate.cu b/csrc/moe/moe_fused_gate.cu new file mode 100644 index 000000000000..bbc9ad927b3b --- /dev/null +++ b/csrc/moe/moe_fused_gate.cu @@ -0,0 +1,486 @@ +// copied from +// https://github.com/sgl-project/sglang/blob/v0.5.5/sgl-kernel/csrc/moe/moe_fused_gate.cu +#include +#include +#include +#include +#include +#include +#include + +#include +#include +template +using AlignedArray = cutlass::AlignedArray; +using bfloat16_t = cutlass::bfloat16_t; +using float16_t = cutlass::half_t; +using float32_t = float; + +// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" +// matches these operands: built-in operator "arithmetic > arithmetic" 
function +// "operator>(const __half &, const __half &)" +template +__device__ inline bool cmp_gt(const T& a, const T& b) { + if constexpr (std::is_same::value) { + // at::Half (or float16_t in our native case) causes ambiguity, so we cast + // to float. + return static_cast(a) > static_cast(b); + } else { + // For types like float, at::BFloat16, or cutlass::half_t / + // cutlass::bfloat16_t, assume operator> works as expected. + return a > b; + } +} + +template +__device__ inline bool cmp_eq(const T& a, const T& b) { + if constexpr (std::is_same::value) { + return static_cast(a) == static_cast(b); + } else { + return a == b; + } +} + +// Fixed constants common to both dynamic and static template versions: +static constexpr int WARP_SIZE = 32; +static constexpr int WARPS_PER_CTA = 6; +static constexpr int MAX_VPT = + 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group + +// Create an alias for Array using AlignedArray +template +using Array = AlignedArray; +// QQ: NOTE expression must have a constant value, this has to be > params.VPT +template +using AccessType = AlignedArray; + +template +__device__ void moe_fused_gate_impl(void* input, void* bias, float* output_ptr, + int32_t* indices_ptr, int64_t num_rows, + int64_t topk_group, int64_t topk, + int64_t num_fused_shared_experts, + double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output, + Params params) { + int tidx = threadIdx.x; + int64_t thread_row = blockIdx.x * params.ROWS_PER_CTA + + threadIdx.y * params.ROWS_PER_WARP + + tidx / params.THREADS_PER_ROW; + if (thread_row >= num_rows) { + return; + } + + // Calculate topk_excluding_share_expert_fusion from topk + int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts; + + // Cast pointers to type T: + auto* input_ptr = reinterpret_cast(input); + auto* bias_ptr = reinterpret_cast(bias); + auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS; + + int thread_group_idx = tidx % params.THREADS_PER_ROW; + int first_elt_read_by_thread = thread_group_idx * params.VPT; + + // Create local arrays for the row chunk and bias chunk and then reinterpret + // the address of row_chunk as a pointer to AccessType. 
+ T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + Array row_chunk; + AccessType const* vec_thread_read_ptr = + reinterpret_cast const*>(thread_read_ptr); + + T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread; + Array bias_chunk; + AccessType const* vec_bias_thread_read_ptr = + reinterpret_cast const*>(bias_thread_read_ptr); + +// QQ NOTE: doing the follow will be slower than loop assign and more +// importantly have misaligned address issue when params.VPT < 8 and mismatch +// with MAX_VPT AccessType* row_chunk_vec_ptr = +// reinterpret_cast*>(&row_chunk); row_chunk_vec_ptr[0] = +// vec_thread_read_ptr[0]; +#pragma unroll + for (int ii = 0; ii < params.VPT; ++ii) { + row_chunk[ii] = vec_thread_read_ptr[0][ii]; + bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii]; + } + + __syncthreads(); + +////////////////////// Sigmoid ////////////////////// +#pragma unroll + for (int ii = 0; ii < params.VPT; ++ii) { + row_chunk[ii] = static_cast(1.0f / (1.0f + expf(-float(row_chunk[ii])))); + } + __syncthreads(); + +////////////////////// Add Bias ////////////////////// +#pragma unroll + for (int ii = 0; ii < params.VPT; ++ii) { + bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii]; + } + +////////////////////// Exclude Groups ////////////////////// +#pragma unroll + for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group; + ++k_idx) { // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group + int expert = first_elt_read_by_thread; + // local argmax + T max_val = static_cast(-FLT_MAX); + T max_val_second = static_cast(-FLT_MAX); +#pragma unroll + for (int ii = 0; ii < params.VPT; ++ii) { + T val = bias_chunk[ii]; + + if (cmp_gt(val, max_val)) { + max_val_second = max_val; + max_val = val; + } else if (cmp_gt(val, max_val_second)) { + max_val_second = val; + } + } + + // QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert + // group and sum them as the group weight to select expert groups + T max_sum = max_val + max_val_second; + +// argmin reduce +#pragma unroll + for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + T other_max_sum = static_cast( + __shfl_xor_sync(0xFFFFFFFF, static_cast(max_sum), mask, + params.THREADS_PER_ROW)); + int other_expert = + __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW); + + // higher indices win + if (cmp_gt(max_sum, other_max_sum) || + (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) { + max_sum = other_max_sum; + expert = other_expert; + } + } + + // clear the max value in the thread + if (k_idx < params.THREADS_PER_ROW - topk_group) { + int const thread_to_clear_in_group = expert / params.VPT; + + if (thread_group_idx == thread_to_clear_in_group) { +#pragma unroll + for (int ii = 0; ii < params.VPT; ++ii) { + bias_chunk[ii] = static_cast(FLT_MAX); + } + } + } + } + + __syncthreads(); + + ////////////////////// Topk ////////////////////// + float output_sum = 0.0f; + for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) { + // local argmax + T max_val = bias_chunk[0]; + int expert = first_elt_read_by_thread; + + if (!cmp_eq(max_val, static_cast(FLT_MAX))) { +#pragma unroll + for (int ii = 1; ii < params.VPT; ++ii) { + T val = bias_chunk[ii]; + if (cmp_gt(val, max_val)) { + max_val = val; + expert = first_elt_read_by_thread + ii; + } + } + } else { + max_val = static_cast(-FLT_MAX); + } + + // argmax reduce +#pragma unroll + for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + T other_max = static_cast( + __shfl_xor_sync(0xFFFFFFFF, 
static_cast(max_val), mask, + params.THREADS_PER_ROW)); + int other_expert = + __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW); + + // lower indices to win + if (cmp_gt(other_max, max_val) || + (cmp_eq(other_max, max_val) && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + int thread_to_clear_in_group = expert / params.VPT; + int64_t idx = topk * thread_row + k_idx; + + if (thread_group_idx == thread_to_clear_in_group) { + int expert_to_clear_in_thread = expert % params.VPT; + + // clear the max value in the thread + bias_chunk[expert_to_clear_in_thread] = static_cast(-FLT_MAX); + + // store output + output_ptr[idx] = + static_cast(row_chunk[expert_to_clear_in_thread]); + indices_ptr[idx] = static_cast(expert); + } + + // accumulate sum for all elements + if (thread_group_idx == 0) { + output_sum += output_ptr[idx]; + } + + __syncthreads(); + } + + if (thread_group_idx == 0 && num_fused_shared_experts > 0) { + int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion; + int64_t expert_offset = 0; + + indices_ptr[last_idx] = + static_cast(params.NUM_EXPERTS + expert_offset); + + // Set the weight to the sum of all weights divided by routed_scaling_factor + output_ptr[last_idx] = output_sum / routed_scaling_factor; + + if (num_fused_shared_experts > 1) { + for (int i = 1; i < num_fused_shared_experts; ++i) { + ++last_idx; + ++expert_offset; + indices_ptr[last_idx] = + static_cast(params.NUM_EXPERTS + expert_offset); + // Set the weight to the sum of all weights divided by + // routed_scaling_factor + output_ptr[last_idx] = output_sum / routed_scaling_factor; + } + } + } + __syncthreads(); + + ////////////////////// Rescale Output ////////////////////// + if (thread_group_idx == 0) { +#pragma unroll + for (int ii = 0; ii < topk; ++ii) { + int64_t const idx = topk * thread_row + ii; + output_ptr[idx] = output_ptr[idx] / output_sum; + if (apply_routed_scaling_factor_on_output) { + output_ptr[idx] *= routed_scaling_factor; + } + } + } +} + +//------------------------------------------------------------------------------ +// Templated Kernel Version (using compile-time constants) +//------------------------------------------------------------------------------ +template +struct KernelParams { + static constexpr int VPT = VPT_; + static constexpr int NUM_EXPERTS = NUM_EXPERTS_; + static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_; + static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_; + static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_; + static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_; +}; + +template +__global__ void moe_fused_gate_kernel( + void* input, void* bias, float* output_ptr, int32_t* indices_ptr, + int64_t num_rows, int64_t topk_group, int64_t topk, + int64_t num_fused_shared_experts, double routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { + KernelParams + params; + moe_fused_gate_impl(input, bias, output_ptr, indices_ptr, num_rows, + topk_group, topk, num_fused_shared_experts, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); +} + +// Macro to compute compile-time constants and launch the kernel. +#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \ + do { \ + constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \ + /* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \ + constexpr int ROWS_PER_WARP = \ + ((EXPERT_GROUP) <= WARP_SIZE) ? 
(WARP_SIZE / (EXPERT_GROUP)) : 1; \ + constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \ + moe_fused_gate_kernel \ + <<>>( \ + input.data_ptr(), bias.data_ptr(), output.data_ptr(), \ + indices.data_ptr(), num_rows, topk_group, topk, \ + num_fused_shared_experts, routed_scaling_factor, \ + apply_routed_scaling_factor_on_output); \ + dispatched = true; \ + } while (0) + +//------------------------------------------------------------------------------ +// Dynamic Kernel Version (parameters computed at runtime) +//------------------------------------------------------------------------------ +struct KernelParamsDynamic { + int VPT; + int NUM_EXPERTS; + int THREADS_PER_ROW; + int ROWS_PER_WARP; + int ROWS_PER_CTA; + int WARPS_PER_CTA; +}; + +template +__global__ void moe_fused_gate_kernel_dynamic( + void* input, void* bias, float* output_ptr, int32_t* indices_ptr, + int64_t num_rows, int64_t num_experts, int64_t num_expert_group, + int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, + double routed_scaling_factor, bool apply_routed_scaling_factor_on_output) { + KernelParamsDynamic params; + params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256 + params.VPT = num_experts / + num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32 + params.THREADS_PER_ROW = + num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, + // this is 8 + params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6 + params.ROWS_PER_WARP = std::max( + 1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32 + params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP; + + moe_fused_gate_impl(input, bias, output_ptr, indices_ptr, num_rows, + topk_group, topk, num_fused_shared_experts, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, params); +} + +//------------------------------------------------------------------------------ +// Host Launcher Function +//------------------------------------------------------------------------------ +std::vector moe_fused_gate( + at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, + int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, + double routed_scaling_factor, bool apply_routed_scaling_factor_on_output) { + TORCH_CHECK(input.dtype() == bias.dtype(), + "input and bias should have the same dtype"); + + int64_t num_rows = input.size(0); + int32_t num_experts = input.size(1); + auto options = + torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto output = torch::empty({num_rows, topk}, options); + auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32)); + + // Compute grid dimensions based on runtime value for num_expert_group. + int64_t rows_per_warp = std::max(1, WARP_SIZE / num_expert_group); + int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp; + int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + dim3 block_dim(WARP_SIZE, WARPS_PER_CTA); + + // Check 1: Ensure that num_experts is a power of 2. + TORCH_CHECK((num_experts & (num_experts - 1)) == 0, + "num_experts must be a power of 2, but got ", num_experts); + + // Check 2: Ensure that num_experts is divisible by num_expert_group. 
(this + // also means num_expert_group is power of 2) + TORCH_CHECK(num_experts % num_expert_group == 0, + "num_experts must be divisible by num_expert_group, but got ", + num_experts, " / ", num_expert_group); + + int computed_vpt = num_experts / num_expert_group; + // Check 3: Ensure that num_experts/num_expert_group does not exceed + // MAX_VPT=32. Maximum VPT indicate max value per threads we can process. + TORCH_CHECK(computed_vpt <= MAX_VPT, + "Per group experts: num_experts / num_expert_group = (", + computed_vpt, ") exceeds the maximum supported (", MAX_VPT, ")"); + + // Dispatch to templated kernel for known compile-time configurations. + // We currently only support for: + // Case 1: 256 experts, with 8 or 16 groups. + // Case 2: 128 experts, with 4 or 8 groups. + // Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32 + bool dispatched = false; + switch (num_experts) { + case 256: + if (num_expert_group == 8) { + // This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 + // = 4, ROWS_PER_CTA = 6 * 4 = 24. + if (input.scalar_type() == at::kBFloat16) { + LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8); + } else if (input.scalar_type() == at::kHalf) { + LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8); + } else if (input.scalar_type() == at::kFloat) { + LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8); + } + } else if (num_expert_group == 16) { + // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 + // * 2 = 12. + if (input.scalar_type() == at::kBFloat16) { + LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16); + } else if (input.scalar_type() == at::kHalf) { + LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16); + } else if (input.scalar_type() == at::kFloat) { + LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16); + } + } + break; + case 128: + if (num_expert_group == 4) { + // VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 + // = 12. + if (input.scalar_type() == at::kBFloat16) { + LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4); + } else if (input.scalar_type() == at::kHalf) { + LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4); + } else if (input.scalar_type() == at::kFloat) { + LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4); + } + } else if (num_expert_group == 8) { + // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 + // = 24. + if (input.scalar_type() == at::kBFloat16) { + LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8); + } else if (input.scalar_type() == at::kHalf) { + LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8); + } else if (input.scalar_type() == at::kFloat) { + LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8); + } + } + break; + default: + break; + } + if (!dispatched) { + // Fallback to the dynamic kernel if none of the supported combinations + // match. 
currently only support num_experts / num_expert_group <= 32 for + // dynamic kernels + if (input.scalar_type() == at::kBFloat16) { + moe_fused_gate_kernel_dynamic + <<>>( + input.data_ptr(), bias.data_ptr(), output.data_ptr(), + indices.data_ptr(), num_rows, num_experts, + num_expert_group, topk_group, topk, num_fused_shared_experts, + routed_scaling_factor, apply_routed_scaling_factor_on_output); + } else if (input.scalar_type() == at::kHalf) { + moe_fused_gate_kernel_dynamic + <<>>( + input.data_ptr(), bias.data_ptr(), output.data_ptr(), + indices.data_ptr(), num_rows, num_experts, + num_expert_group, topk_group, topk, num_fused_shared_experts, + routed_scaling_factor, apply_routed_scaling_factor_on_output); + } else if (input.scalar_type() == at::kFloat) { + moe_fused_gate_kernel_dynamic + <<>>( + input.data_ptr(), bias.data_ptr(), output.data_ptr(), + indices.data_ptr(), num_rows, num_experts, + num_expert_group, topk_group, topk, num_fused_shared_experts, + routed_scaling_factor, apply_routed_scaling_factor_on_output); + } else { + TORCH_CHECK(false, "Unsupported data type for moe_fused_gate"); + } + } + return {output, indices}; +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 11c6875f7f1d..0238b73f85d9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -27,6 +27,12 @@ void moe_lora_align_block_size( torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, torch::Tensor lora_ids); + +std::vector moe_fused_gate( + torch::Tensor& input, torch::Tensor& bias, int64_t num_expert_group, + int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts, + double routed_scaling_factor, bool apply_routed_scaling_factor_on_output); + #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index e0a8280722f3..570388e25e60 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -48,6 +48,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor !adapter_enabled," " Tensor !lora_ids) -> () "); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); + m.def( + "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int " + "topk_group, int topk, int " + "num_fused_shared_experts, float routed_scaling_factor, " + "bool apply_routed_scaling_factor_on_output) -> " + "(Tensor[])"); + m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); #ifndef USE_ROCM m.def( diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index b0ff1e64e321..b04b48ab0295 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -86,7 +86,7 @@ ] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16] # [128, 256] +E = [2, 8, 16, 258] # [128, 256] TOP_KS = [1, 2, 6] SEEDS = [0] @@ -149,6 +149,9 @@ def test_w8a8_block_fp8_fused_moe( a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) + if E == 258: + monkeypatch.setenv("VLLM_USE_CUDA_FUSION_SHARED_EXPERTS", "1") + w1, w2, quant_config = make_test_quant_config( E, N, diff --git a/tests/kernels/moe/test_moe_fused_gate.py b/tests/kernels/moe/test_moe_fused_gate.py new file mode 100644 index 000000000000..6c2699d50300 --- /dev/null +++ b/tests/kernels/moe/test_moe_fused_gate.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk + + +@pytest.mark.parametrize( + "seq_length", + list(range(1, 10)) + + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536], +) +@pytest.mark.parametrize( + "dtype", + [torch.float32], # torch.float16, torch.bfloat16 - aren't working correctly yet +) +@pytest.mark.parametrize( + "params", + [ + # (128, 4, 2, 4), + (256, 8, 4, 8), # deepseek v3 + # (512, 16, 8, 16), + ], +) +@pytest.mark.parametrize( + "num_fused_shared_experts", + [ + 0, + 1, + ], +) +def test_moe_fused_gate_combined( + seq_length, dtype, params, num_fused_shared_experts, monkeypatch +): + num_experts, num_expert_group, topk_group, topk = params + topk += 1 if num_fused_shared_experts > 0 else 0 + + torch.manual_seed(seq_length) + tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda() + scores = tensor.clone() + bias = torch.rand(num_experts).to(dtype).cuda() + routed_scaling_factor = 2.5 + + output, indices = ops.moe_fused_gate( + tensor, + bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + topk=topk, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + apply_routed_scaling_factor_on_output=True, + ) + + monkeypatch.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + ref_vllm_output, ref_vllm_indices = grouped_topk( + hidden_states=scores, + gating_output=scores, + topk=topk, + renormalize=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func="sigmoid", + e_score_correction_bias=bias, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + ) + + if num_fused_shared_experts > 0: + original_indices = indices.clone() + original_ref_indices = ref_vllm_indices.clone() + indices = indices[:, :-1] + ref_vllm_indices = ref_vllm_indices[:, :-1] + + valid_min = num_experts + valid_max = num_experts + num_fused_shared_experts + shared_indices = original_indices[:, -1] + shared_ref_indices = original_ref_indices[:, -1] + if shared_indices is not None: + assert torch.all( + (shared_indices >= valid_min) & (shared_indices < valid_max) + ), ( + "Shared expert indices out of range: ", + f"found values outside [{valid_min}, {valid_max})", + ) + if shared_ref_indices is not None: + assert torch.all( + (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max) + ), ( + "Shared expert reference indices out of range: ", + f"found values outside [{valid_min}, {valid_max})", + ) + + vllm_idx_check = torch.allclose( + ref_vllm_indices.sort()[0].to(torch.int32), + indices.sort()[0].to(torch.int32), + rtol=1e-04, + atol=1e-05, + ) + vllm_output_check = torch.allclose( + ref_vllm_output.sort()[0].to(torch.float32), + output.sort()[0].to(torch.float32), + rtol=1e-04, + atol=1e-03, + ) + + assert vllm_idx_check, ( + f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, " + f"params {params}, num_fused_shared_experts {num_fused_shared_experts}" + ) + assert vllm_output_check, ( + f"Output mismatch at seq_length {seq_length}, dtype {dtype}, " + f"params {params}, num_fused_shared_experts {num_fused_shared_experts}" + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e60158898685..750bc1b931b1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1936,6 +1936,52 @@ def moe_lora_align_block_size( ) +def moe_fused_gate( + input_tensor: torch.Tensor, + bias: torch.Tensor, + num_expert_group: int, + 
topk_group: int, + topk: int, + num_fused_shared_experts: int = 0, + routed_scaling_factor: float = 0.0, + apply_routed_scaling_factor_on_output: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._moe_C.moe_fused_gate( + input_tensor, + bias, + num_expert_group, + topk_group, + topk, + num_fused_shared_experts, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, + ) + + +if hasattr(torch.ops._moe_C, "moe_fused_gate"): + + @register_fake("_moe_C::moe_fused_gate") + def _moe_fused_gate_fake( + input_tensor: torch.Tensor, + bias: torch.Tensor, + num_expert_group: int, + topk_group: int, + topk: int, + num_fused_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + apply_routed_scaling_factor_on_output: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty( + (input_tensor.size(0), topk), + dtype=torch.float32, + device=input_tensor.device, + ), torch.empty( + (input_tensor.size(0), topk), + dtype=torch.int32, + device=input_tensor.device, + ) + + def moe_wna16_gemm( input: torch.Tensor, output: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index d0912863e644..338ce89febc8 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -236,6 +236,8 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_USE_CUDA_FUSION_SHARED_EXPERTS: bool = False + VLLM_USE_FUSED_MOE_ROUTER: bool = False def get_default_cache_root(): @@ -1550,6 +1552,14 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Enable the fusion of the shared experts of the model with other experts. + "VLLM_USE_CUDA_FUSION_SHARED_EXPERTS": lambda: bool( + int(os.getenv("VLLM_USE_CUDA_FUSION_SHARED_EXPERTS", "0")) + ), + # Use the fused grouped top-k MoE expert selection router + "VLLM_USE_FUSED_MOE_ROUTER": lambda: bool( + int(os.getenv("VLLM_USE_FUSED_MOE_ROUTER", "0")) + ), } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index df208eae2e71..8da6a67b5561 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1168,9 +1168,20 @@ def grouped_topk( scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, + num_fused_shared_experts: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: + use_fused_moe_grouped_topk = envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + enable_fused_shared_experts = num_fused_shared_experts > 0 + if enable_fused_shared_experts and use_fused_moe_grouped_topk: + logger.info( + "Fused MoE grouped topk is enabled with fused shared experts.", + "Only one of these options can be used at a time", + "Fused MoE grouped topk is disabled.", + ) + use_fused_moe_grouped_topk = False + if ( - envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + use_fused_moe_grouped_topk and current_platform.is_cuda() and num_expert_group <= 32 and topk <= 32 @@ -1198,6 +1209,7 @@ def grouped_topk( raise ValueError(f"Unsupported scoring function: {scoring_func}") num_token = scores.size(0) + num_experts = scores.size(1) if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights @@ -1226,18 +1238,45 @@ def grouped_topk( tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] + topk_ids = torch.topk( + tmp_scores, + k=topk, + dim=-1, + sorted=(use_sorted or enable_fused_shared_experts), + )[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: topk_weights, topk_ids = torch.topk( - tmp_scores, k=topk, dim=-1, sorted=use_sorted + tmp_scores, + k=topk, + dim=-1, + sorted=(use_sorted or enable_fused_shared_experts), ) + if enable_fused_shared_experts: + assert routed_scaling_factor is not None, "With num_fused_shared_experts>0" + ", routed_scaling_factor need to be provided" + topk_ids[:, -1] = torch.randint( + low=num_experts, + high=num_experts + num_fused_shared_experts, + size=(topk_ids.size(0),), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + if routed_scaling_factor != 1.0: + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor + ) + if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if not enable_fused_shared_experts: + topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True) + else: + topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True) + topk_weights = topk_weights / topk_weights_sum - if routed_scaling_factor != 1.0: + if not enable_fused_shared_experts and routed_scaling_factor != 1.0: topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 902a77987d61..c426714a6850 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum -from functools import partial from typing import Literal, cast, get_args, overload import torch @@ -57,6 +57,8 @@ from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): + from vllm._custom_ops import moe_fused_gate + from .fused_moe import eplb_map_to_physical_and_record, fused_experts else: fused_experts = None # type: ignore @@ -96,6 +98,10 @@ def _eplb_map_to_physical_and_record( logger = init_logger(__name__) +def is_power_of_two(n): + return n > 0 and math.log2(n).is_integer() + + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" @@ -410,6 +416,11 @@ def __init__( dp_size_=dp_size_, vllm_parallel_config=vllm_config.parallel_config, ) + enable_fused_shared_experts = envs.VLLM_USE_CUDA_FUSION_SHARED_EXPERTS + if enable_fused_shared_experts: + assert n_shared_experts is not None + num_experts += n_shared_experts + top_k += n_shared_experts self.global_num_experts = num_experts + num_redundant_experts self.logical_num_experts = num_experts @@ -443,7 +454,7 @@ def __init__( vllm_config.parallel_config.expert_placement_strategy ) - # ROCm aiter shared experts fusion + # ROCm aiter and CUDA shared experts fusion self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.aiter_fmoe_shared_expert_enabled = ( 
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() @@ -451,16 +462,19 @@ def __init__( self.num_fused_shared_experts = ( n_shared_experts - if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled + if n_shared_experts is not None + and (self.aiter_fmoe_shared_expert_enabled or enable_fused_shared_experts) else 0 ) if ( not self.aiter_fmoe_shared_expert_enabled + and not enable_fused_shared_experts and self.num_fused_shared_experts != 0 ): raise ValueError( "n_shared_experts is only supported on ROCm aiter when " "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled" + "and on CUDA when VLLM_USE_CUDA_FUSION_SHARED_EXPERTS is enabled" ) # Determine expert maps @@ -508,6 +522,15 @@ def __init__( self.global_num_experts, get_compressed_expert_map(self.expert_map), ) + if self.num_fused_shared_experts > 0: + logger.warning( + "With EP enabled and share expert fusion enabled" + ", share expert replica should be same as ep_size" + "got share expert replica = %d" + "and ep_size = %d", + self.num_fused_shared_experts, + self.ep_size, + ) else: self.local_num_experts, self.expert_map, self.expert_mask = ( self.global_num_experts, @@ -516,10 +539,10 @@ def __init__( ) self.top_k = top_k - - self._init_aiter_shared_experts_topK_buffer( - vllm_config=vllm_config, dp_size=dp_size_ - ) + if self.aiter_fmoe_shared_expert_enabled: + self._init_aiter_shared_experts_topK_buffer( + vllm_config=vllm_config, dp_size=dp_size_ + ) assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -1578,24 +1601,40 @@ def select_experts( if rocm_aiter_ops.is_fused_moe_enabled(): if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): assert self.num_fused_shared_experts == 0 - grouped_topk_impl = partial( - rocm_aiter_grouped_topk, - num_fused_shared_experts=self.num_fused_shared_experts, - ) + grouped_topk_impl = rocm_aiter_grouped_topk else: grouped_topk_impl = grouped_topk - topk_weights, topk_ids = grouped_topk_impl( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.e_score_correction_bias, - ) + if ( + envs.VLLM_USE_FUSED_MOE_ROUTER + and self.e_score_correction_bias is not None + and is_power_of_two(self.e_score_correction_bias.shape[0]) + ): + topk_weights, topk_ids = moe_fused_gate( + input_tensor=router_logits.to(dtype=torch.float32), + bias=self.e_score_correction_bias.data.to(dtype=torch.float32), + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + topk=self.top_k, + num_fused_shared_experts=self.num_fused_shared_experts, + routed_scaling_factor=self.routed_scaling_factor + if self.routed_scaling_factor is not None + else 1.0, + apply_routed_scaling_factor_on_output=False, + ) + else: + topk_weights, topk_ids = grouped_topk_impl( + hidden_states=hidden_states, + gating_output=router_logits, + topk=self.top_k, + renormalize=self.renormalize, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + num_fused_shared_experts=self.num_fused_shared_experts, + ) elif self.e_score_correction_bias is not None: topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, diff --git 
a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 48e5a8907f92..987af5e06e73 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -316,6 +316,7 @@ def apply( expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, + num_fused_shared_experts=layer.num_fused_shared_experts, ) def get_fused_moe_quant_config( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 73cac2556c55..9140732cce01 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -98,6 +98,8 @@ elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops +import vllm.envs as envs + logger = init_logger(__name__) @@ -277,6 +279,7 @@ def __init__( self.enable_eplb = parallel_config.enable_eplb self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -289,6 +292,7 @@ def __init__( self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.is_fusion_moe_shared_experts_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + or envs.VLLM_USE_CUDA_FUSION_SHARED_EXPERTS ) if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled: self.shared_experts = None @@ -304,7 +308,9 @@ def __init__( reduce_results=False, prefix=f"{prefix}.shared_experts", ) - + used_inside_scaling = ( + self.is_rocm_aiter_moe_enabled or self.is_fusion_moe_shared_experts_enabled + ) self.experts = SharedFusedMoE( shared_experts=self.shared_experts, gate=self.gate, @@ -323,7 +329,7 @@ def __init__( # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 - if not self.is_rocm_aiter_moe_enabled + if not used_inside_scaling else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, @@ -1432,8 +1438,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rocm_aiter_moe_shared_expert_enabled = ( + is_fusion_moe_shared_experts_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + or envs.VLLM_USE_CUDA_FUSION_SHARED_EXPERTS ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -1454,8 +1461,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: else: stacked_params_mapping.extend(mla_params_mapping) - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) + if is_fusion_moe_shared_experts_enabled: + logger.info( + "Cloning %s replicas of the shared expert into MoE", + self.num_shared_experts, + ) + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -1463,7 +1474,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts - if rocm_aiter_moe_shared_expert_enabled + if 
is_fusion_moe_shared_experts_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, @@ -1480,7 +1491,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue # skip spec decode layers for main model is_fusion_moe_shared_experts_layer = ( - rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) + is_fusion_moe_shared_experts_enabled and ("mlp.shared_experts" in name) ) for param_name, weight_name, shard_id in stacked_params_mapping:
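
Usage sketch (not part of the patch): a minimal example of invoking the new custom op directly, mirroring tests/kernels/moe/test_moe_fused_gate.py with the DeepSeek-V3 gating configuration (256 routed experts, 8 expert groups, topk_group=4, topk=8). Per the host launcher above, num_experts must be a power of two and divisible by num_expert_group, with num_experts / num_expert_group <= 32; the new test currently only verifies float32 inputs. The token count below is an arbitrary example value.

    import torch
    from vllm import _custom_ops as ops

    # DeepSeek-V3-style gating shapes: 256 routed experts in 8 groups.
    num_tokens, num_experts = 128, 256
    scores = torch.rand(num_tokens, num_experts, dtype=torch.float32, device="cuda")
    bias = torch.rand(num_experts, dtype=torch.float32, device="cuda")

    # Returns float32 routing weights and int32 expert indices,
    # both shaped (num_tokens, topk).
    weights, indices = ops.moe_fused_gate(
        scores,
        bias,
        num_expert_group=8,
        topk_group=4,
        topk=8,
        num_fused_shared_experts=0,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=True,
    )

At the layer level, the fused router path is gated behind VLLM_USE_FUSED_MOE_ROUTER=1, and CUDA shared-expert fusion behind VLLM_USE_CUDA_FUSION_SHARED_EXPERTS=1, per the additions to vllm/envs.py in this patch.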