 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
 from vllm.version import __version__ as vllm_version

-# from helper import get_sm_version, log_perf
+from helper import get_sm_version, log_perf

-def get_sm_version():
-    return 86
+# def get_sm_version():
+#     return 86

-def log_perf(*args, **kwargs):
-    pass
+# def log_perf(*args, **kwargs):
+#     pass

 aic_debug = int(os.getenv("aic_moe_debug", "0"))  # noqa: SIM112

@@ -34,7 +34,7 @@ def balanced_logits(num_tokens, num_experts, topk):
             h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts

     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
+    router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits

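Note: balanced_logits synthesizes perfectly uniform routing for the benchmark. A minimal standalone sketch of the construction, assuming stride = num_experts / topk (the assignment of stride sits outside this hunk):

    import torch
    import torch.nn.functional as F

    num_tokens, num_experts, topk = 4, 8, 2
    stride = num_experts / topk  # assumed; spaces each token's picks evenly
    picks = torch.empty(num_tokens, topk, dtype=torch.long)
    for token_i in range(num_tokens):
        for i in range(topk):
            picks[token_i, i] = int(token_i * stride / num_tokens + i * stride) % num_experts
    # Summing the one-hot picks gives per-token expert counts; softmaxing the
    # counts yields logits whose top-k recovers exactly the picked experts.
    expert_map = F.one_hot(picks, num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.half(), dim=1)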
@@ -146,7 +146,7 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_gpu_only):


     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
+    router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits

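Note: the body of power_law_logits_v3 is not shown in this diff; it skews expert popularity by alpha before the same one-hot/softmax tail as above. A hypothetical sketch of one way to draw such a skew (not necessarily the v3 algorithm):

    import torch

    num_tokens, num_experts, topk, alpha = 16, 8, 2, 1.01
    # Expert e gets probability proportional to (e + 1) ** (-alpha).
    probs = torch.arange(1, num_experts + 1, dtype=torch.float).pow(-alpha)
    probs = (probs / probs.sum()).expand(num_tokens, -1).contiguous()
    # Sample topk distinct experts per token under that popularity.
    h_selected_experts = torch.multinomial(probs, topk, replacement=False)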
@@ -227,6 +227,10 @@ def get_moe_test_cases():
             if inter_s % tp != 0:
                 continue

+            # vllm does not support TP when EP is enabled.
+            if tp > 1 and ep > 1:
+                continue
+
             for power_law_alpha in alpha_list:
                 test_cases.append(
                     [
@@ -245,7 +249,7 @@ def get_moe_test_cases():
                     ]
                 )

-    return test_cases[:20]
+    return test_cases

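Note: with the [:20] cap removed, get_moe_test_cases returns the full sweep; the driver at the bottom of this file still slices test_cases[:40], so at most 40 cases actually run.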
 def run_moe_torch(
     moe_type,
@@ -267,7 +271,7 @@ def run_moe_torch(
     torch.set_default_device(device)

     # Configure quantization parameters
-    dtype = torch.bfloat16
+    dtype = torch.float16
     quant_config = None

     if moe_type == "fp8":
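Note: the benchmark dtype flips from bfloat16 to float16 here; every tensor construction below (weights, hidden states, synthesized logits) follows the same swap.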
@@ -300,14 +304,14 @@ def run_moe_torch(
         local_num_experts,
         2 * local_inter_size,
         hidden_size,
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
         device=device
     )
     w2 = torch.randn(
         local_num_experts,
         hidden_size,
         local_inter_size,
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
         device=device
     )

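Note: the first randn above (whose opening line sits outside this hunk) builds w13 with a 2 * local_inter_size leading dimension, consistent with the fused-MoE layout that stacks the gate and up projections into one tensor; w2 is the down projection back to hidden_size.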
@@ -319,7 +323,7 @@ def run_moe_torch(
     for num_tokens_idx, num_tokens in enumerate(num_tokens_lists):
         print("num_tokens", num_tokens)
         print("topk", topk)
-        hidden_states = torch.randn([num_tokens, hidden_size]).bfloat16().to(device)
+        hidden_states = torch.randn([num_tokens, hidden_size]).half().to(device)

         # Generate topk_weights and topk_ids
         num_iter = 10 if distributed == "power_law" else 1
@@ -335,7 +339,7 @@ def run_moe_torch(
                 moe_ep_size,
                 power_law_alpha,
                 return_first_gpu_only=True
-            ).bfloat16().to(device)
+            ).half().to(device)
             weights, ids = torch.topk(logits, local_topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)
@@ -344,7 +348,7 @@ def run_moe_torch(

         elif distributed == "balanced":
             local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-            actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).bfloat16().to(device)
+            actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
             topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
             topk_weights = F.softmax(topk_weights, dim=-1)
             print("actual num_tokens: ", actual_logits.shape[0])
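Note: both branches renormalize after selection: topk runs on the raw logits first, then softmax over only the selected columns, so each token's topk weights sum to 1.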
@@ -459,6 +463,8 @@ def run_iterations(use_cuda_graph=False):
     # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 4, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
     # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 2048, 768, 8, 128, 2, 32, 'QWEN3_30B_A3B', 'moe_perf.txt', 'power_law', 1.01]]

+    test_cases = [['float16', [65536], 4096, 14336, 2, 8, 1, 1, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
+
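Note: field order inferred from the commented examples above (the list layout is not spelled out in this hunk): dtype, num-tokens list, hidden_size, intermediate_size, topk, num_experts, then what appear to be tp and ep, model tag, output file, token distribution, and power-law alpha. The new case reads as Mixtral 8x7B: hidden 4096, intermediate 14336, top-2 of 8 experts, TP=1, EP=1, a single 65536-token batch.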
     print(f"Total test cases: {len(test_cases)}")

     for test_case in test_cases[:40]: