Skip to content

Commit b149d9c

Browse files
committed
add fusion of shared expert and fused_moe_gate
Signed-off-by: Barbara Suslova <[email protected]>
1 parent d381eb9 commit b149d9c

File tree

22 files changed

+950
-47
lines changed

22 files changed

+950
-47
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,8 @@ set(VLLM_MOE_EXT_SRC
884884
"csrc/moe/torch_bindings.cpp"
885885
"csrc/moe/moe_align_sum_kernels.cu"
886886
"csrc/moe/moe_lora_align_sum_kernels.cu"
887-
"csrc/moe/topk_softmax_kernels.cu")
887+
"csrc/moe/topk_softmax_kernels.cu"
888+
"csrc/moe/moe_fused_gate.cu")
888889

889890
if(VLLM_GPU_LANG STREQUAL "CUDA")
890891
list(APPEND VLLM_MOE_EXT_SRC
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
5+
from vllm._custom_ops import moe_fused_gate
6+
from vllm.model_executor.layers.fused_moe.fused_moe import (
7+
grouped_topk as vllm_compiled_grouped_topk,
8+
)
9+
from vllm.triton_utils import triton
10+
11+
12+
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
13+
return vllm_compiled_grouped_topk(
14+
hidden_states=scores,
15+
gating_output=scores,
16+
topk=topk,
17+
renormalize=True,
18+
num_expert_group=num_expert_group,
19+
topk_group=topk_group,
20+
scoring_func="sigmoid",
21+
e_score_correction_bias=bias,
22+
)
23+
24+
25+
def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk):
26+
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
27+
28+
29+
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
30+
configs = [(sq,) for sq in seq_length_range]
31+
32+
33+
@triton.testing.perf_report(
34+
triton.testing.Benchmark(
35+
x_names=["seq_length"],
36+
x_vals=[list(_) for _ in configs],
37+
line_arg="provider",
38+
line_vals=["original", "kernel"],
39+
line_names=["Original", "SGL Kernel"],
40+
styles=[("blue", "-"), ("red", "-")],
41+
ylabel="us",
42+
plot_name="moe-fused-gate-performance",
43+
args={},
44+
)
45+
)
46+
def benchmark(seq_length, provider):
47+
dtype = torch.bfloat16
48+
device = torch.device("cuda")
49+
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
50+
51+
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
52+
bias = torch.rand(num_experts, device=device, dtype=dtype)
53+
54+
quantiles = [0.5, 0.2, 0.8]
55+
56+
if provider == "original":
57+
ms, min_ms, max_ms = triton.testing.do_bench(
58+
lambda: biased_grouped_topk_org(
59+
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
60+
),
61+
quantiles=quantiles,
62+
)
63+
elif provider == "kernel":
64+
ms, min_ms, max_ms = triton.testing.do_bench(
65+
lambda: biased_grouped_topk_org_kernel(
66+
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
67+
),
68+
quantiles=quantiles,
69+
)
70+
71+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
72+
73+
74+
if __name__ == "__main__":
75+
benchmark.run(print_data=True)

0 commit comments

Comments
 (0)