 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-import math
 import os

 import torch
 import torch.nn.functional as F
+from common_test_cases import get_common_moe_test_cases
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
 from vllm.version import __version__ as vllm_version

-from helper import get_sm_version, log_perf
-
-# def get_sm_version():
-#     return 86
-
-# def log_perf(*args, **kwargs):
-#     pass
+from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3

 aic_debug = int(os.getenv("aic_moe_debug", "0"))  # noqa: SIM112


-def balanced_logits(num_tokens, num_experts, topk):
-    """Generate balanced distribution router logits"""
-    h_selected_experts = -torch.ones([num_tokens, topk])
-    stride = math.ceil(num_experts / topk)
-
-    for token_i in range(num_tokens):
-        for i in range(topk):
-            if num_tokens >= stride:
-                h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
-            else:
-                h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts
-
-    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.half(), dim=1)
-    return router_logits
-
-
-def sample_power_law(size, alpha, xmin, xmax):
-    """Sample from power law distribution"""
-    u = torch.rand(size)
-    inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha))
-    return inv_cdf
-
-
-def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_gpu_only=False):
-    """Generate power law distributed router logits (simulating real-world load imbalance scenarios)"""
-    if num_tokens * topk > num_experts:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
-    else:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
-
-    target_sum = num_tokens * topk
-
-    original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()
-
-    target_distribution = original_distribution * target_sum
-
-    num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)
-
-    current_sum = num_tokens_per_expert.sum().item()
-    delta = target_sum - current_sum
-    if delta != 0:
-        sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)
-
-        if delta > 0:
-            for i in range(delta):
-                expert_idx = sorted_indices[i % len(sorted_indices)]
-                num_tokens_per_expert[expert_idx] += 1
-        else:
-            for i in range(-delta):
-                expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
-                if num_tokens_per_expert[expert_idx] > 0:
-                    num_tokens_per_expert[expert_idx] -= 1
-                else:
-                    num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1
-
-    if len(num_tokens_per_expert) > 1:
-        sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
-        assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"
-
-    # Ensure the busiest expert group in EP dimension is placed on the first rank
-    with torch.no_grad():
-        conv1d = torch.nn.Conv1d(
-            in_channels=1,
-            out_channels=1,
-            kernel_size=num_experts // ep,
-            stride=num_experts // ep,
-            padding=0,
-            bias=False,
-        )
-        conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
-        conv1d.weight.copy_(conv1d_weights)
-
-        res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
-        max_ep_idx = torch.argmax(res).item()
-
-    if max_ep_idx != 0:
-        ep_group_size = num_experts // ep
-        num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
-        num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
-            num_tokens_per_expert_reshaped[max_ep_idx].clone(),
-            num_tokens_per_expert_reshaped[0].clone(),
-        )
-        num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)
-
-    revised_num_tokens = num_tokens
-    revised_topk = topk
-    if return_first_gpu_only:
-        # Number of experts per GPU
-        ep_group_size = num_experts // ep
-
-        # How many experts will be run on the first GPU.
-        # Can't exceed the number of experts per GPU.
-        revised_topk = min(topk, ep_group_size)
-
-        # Only generate token -> expert assignments for the first GPU.
-        num_tokens_per_expert = num_tokens_per_expert[:ep_group_size]
-
-        # Bump up the total number of tokens on the first GPU
-        # to be a multiple of revised_topk.
-        tokens_on_first_gpu = torch.sum(num_tokens_per_expert).item()
-        num_extra_tokens = (revised_topk - (tokens_on_first_gpu % revised_topk)) % revised_topk
-        for i in range(num_extra_tokens):
-            num_tokens_per_expert[i % len(num_tokens_per_expert)] += 1
-        tokens_on_first_gpu = torch.sum(num_tokens_per_expert).item()
-        assert tokens_on_first_gpu % revised_topk == 0
-
-        # Now revised_num_tokens represents only the tokens on the first GPU.
-        revised_num_tokens = tokens_on_first_gpu // revised_topk
-
-    if aic_debug == 2:
-        print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())
-
-    _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
-    expert_assignments = []
-    num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
-    for expert_id in num_tokens_per_expert_sorted_index_lists:
-        expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])
-
-    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
-    h_selected_experts = expert_assignments.reshape(revised_topk, revised_num_tokens).T
-
-    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.half(), dim=1)
-    return router_logits
-
-
 def get_moe_test_cases():
     """Generate MoE test cases"""
-    num_tokens = [
-        1,
-        2,
-        4,
-        8,
-        16,
-        32,
-        48,
-        64,
-        80,
-        96,
-        128,
-        160,
-        192,
-        256,
-        320,
-        384,
-        512,
-        768,
-        1024,
-        1536,
-        2048,
-        3072,
-        4096,
-        6144,
-        8192,
-        12288,
-        16384,
-        20480,
-        32768,
-        65536,
-    ]
-    tp_list = [1, 2, 4, 8, 16, 32]
-    ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    alpha_list = [1.01, 1.2]
-
-    # Model configurations: [hidden_size, inter_size, topk, num_experts, model_name]
-    model_config_list = [
-        [4096, 14336, 2, 8, "MOE_Mixtral8x7B"],  # mixtral_8x7b
-        [6144, 16384, 2, 8, "MOE_Mixtral8x22B"],  # mixtral_8x22b
-        [7168, 2048, 8, 256, "DEEPSEEK_V3"],  # deepseekv3
-        [2048, 768, 8, 128, "QWEN3_30B_A3B"],  # qwen3-moe, 30b-a3b
-        [4096, 1536, 8, 128, "QWEN3_235B"],  # qwen3-moe, 235b-a22b
-        [6144, 2560, 8, 160, "QWEN3_480B"],  # qwen3-moe, 480b-a35b
-        [7168, 2048, 8, 384, "KIMI_K2"],  # kimi k2
-    ]

     # Quantization types supported by vLLM
     moe_list = ["float16"]
-
     if get_sm_version() > 86:
         moe_list += ["fp8"]

     test_cases = []

-    for num_gpu in num_gpu_list:
+    for common_moe_testcase in get_common_moe_test_cases():
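+        # Only the power-law (load-imbalanced) routing cases are swept here;
+        # the balanced distribution cases are filtered out below.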
+        if common_moe_testcase.token_expert_distribution != "power_law":
+            continue
+
+        model_name = common_moe_testcase.model_name
+        if model_name in ["GPT_OSS_20B", "GPT_OSS_120B"]:
+            continue
+
+        # vllm does not support TP when EP is enabled.
+        if common_moe_testcase.tp > 1 and common_moe_testcase.ep > 1:
+            continue
+
         for moe_type in moe_list:
-            for model_config in model_config_list:
-                hs, inter_s, topk, num_experts, model_name = model_config
-                for tp in tp_list:
-                    # QWEN3_30B_A3B: exclude tp >= 8 as they are not used in actual deployments
-                    if model_name == "QWEN3_30B_A3B" and tp >= 8:
-                        continue
-                    for ep in ep_list:
-                        if tp * ep != num_gpu:
-                            continue
-                        if ep > num_experts:
-                            continue
-                        if num_experts % ep != 0:
-                            continue
-                        # Ensure inter_s can be divided by tp
-                        if inter_s % tp != 0:
-                            continue
-
-                        # vllm does not support TP when EP is enabled.
-                        if tp > 1 and ep > 1:
-                            continue
-
-                        for power_law_alpha in alpha_list:
-                            test_cases.append(
-                                [
-                                    moe_type,
-                                    num_tokens,
-                                    hs,
-                                    inter_s,
-                                    topk,
-                                    num_experts,
-                                    tp,
-                                    ep,
-                                    model_name,
-                                    "moe_perf.txt",
-                                    "power_law",
-                                    power_law_alpha,
-                                ]
-                            )
+            test_cases.append(
+                [
+                    moe_type,
+                    common_moe_testcase.num_tokens_list,
+                    common_moe_testcase.hidden_size,
+                    common_moe_testcase.inter_size,
+                    common_moe_testcase.topk,
+                    common_moe_testcase.num_experts,
+                    common_moe_testcase.tp,
+                    common_moe_testcase.ep,
+                    common_moe_testcase.model_name,
+                    "moe_perf.txt",
+                    common_moe_testcase.token_expert_distribution,
+                    common_moe_testcase.power_law_alpha,
+                ]
+            )

     return test_cases

@@ -293,9 +102,6 @@ def run_moe_torch(
     local_num_experts = num_experts // moe_ep_size
     local_inter_size = inter_size // moe_tp_size

-    # How many experts will be run on this GPU
-    local_topk = min(topk, local_num_experts)
-
     # Create weight tensors
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
     # w2: down projection weights [num_experts, hidden_size, inter_size]
@@ -314,6 +120,9 @@ def run_moe_torch(
         device=device,
     )

+    # Maps global expert index to local expert index.
+    _, expert_map = determine_expert_map(moe_ep_size, 0, num_experts)
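+    # Per vLLM's determine_expert_map: rank 0's experts get local indices,
+    # non-local experts map to -1, and the map is None when moe_ep_size == 1.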
+
     if dtype == torch.float8_e4m3fn:
         w1 = w1.to(dtype)
         w2 = w2.to(dtype)
@@ -333,23 +142,25 @@ def run_moe_torch(
         for _ in range(num_iter):
             logits = (
                 power_law_logits_v3(
-                    num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    num_tokens,
+                    num_experts,
+                    topk,
+                    moe_ep_size,
+                    power_law_alpha,
                 )
                 .half()
                 .to(device)
             )
-            weights, ids = torch.topk(logits, local_topk, dim=-1)
+            weights, ids = torch.topk(logits, topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)

         print("actual num_tokens: ", [topk_ids.shape[0] for topk_ids in topk_ids_list])

     elif distributed == "balanced":
-        local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-        actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
-        topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
+        actual_logits = balanced_logits(num_tokens, num_experts, topk).half().to(device)
+        topk_weights, topk_ids = torch.topk(actual_logits, topk, dim=-1)
         topk_weights = F.softmax(topk_weights, dim=-1)
-        print("actual num_tokens: ", actual_logits.shape[0])

     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")
@@ -372,6 +183,8 @@ def run_single_iteration():
                 ti,
                 inplace=True,
                 quant_config=quant_config,
+                global_num_experts=num_experts,
+                expert_map=expert_map,
             )
         else:
             _ = fused_experts(
@@ -382,6 +195,8 @@ def run_single_iteration():
                 topk_ids,
                 inplace=True,
                 quant_config=quant_config,
+                global_num_experts=num_experts,
+                expert_map=expert_map,
             )

     def run_iterations(use_cuda_graph=False):