@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
 from vllm.version import __version__ as vllm_version
 
 from helper import get_sm_version, log_perf
@@ -97,8 +98,10 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
     max_ep_idx = torch.argmax(res).item()
 
+    # Number of experts per GPU
+    ep_group_size = num_experts // ep
+
     if max_ep_idx != 0:
-        ep_group_size = num_experts // ep
         num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
         num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
             num_tokens_per_expert_reshaped[max_ep_idx].clone(),
@@ -109,9 +112,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     revised_num_tokens = num_tokens
     revised_topk = topk
     if return_first_gpu_only:
-        # Number of experts per GPU
-        ep_group_size = num_experts // ep
-
         # How many experts will be run on the first GPU.
         # Can't exceed the number of experts per GPU.
         revised_topk = min(topk, ep_group_size)
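For context, the hoisted `ep_group_size` now feeds both the group-swap above and the `return_first_gpu_only` path. Below is a minimal sketch of the swap with made-up token counts; it is not the benchmark's code, and it stands in for the real `max_ep_idx` selection, which comes from a conv1d over the per-expert counts (equivalent to finding the busiest group when the kernel spans one EP group):

```python
import torch

# Minimal sketch: move the busiest expert-parallel group into slot 0 so that
# rank 0 measures the worst-case load. Token counts are illustrative.
num_experts, ep = 8, 2
ep_group_size = num_experts // ep  # experts per GPU
num_tokens_per_expert = torch.tensor([1, 2, 3, 4, 10, 20, 30, 40])

groups = num_tokens_per_expert.view(ep, ep_group_size)  # view shares storage
max_ep_idx = torch.argmax(groups.sum(dim=-1)).item()    # busiest group -> 1
if max_ep_idx != 0:
    groups[0], groups[max_ep_idx] = groups[max_ep_idx].clone(), groups[0].clone()

print(num_tokens_per_expert)  # tensor([10, 20, 30, 40,  1,  2,  3,  4])
```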
@@ -294,7 +294,8 @@ def run_moe_torch(
     local_inter_size = inter_size // moe_tp_size
 
     # How many experts will be run on this GPU
-    local_topk = min(topk, local_num_experts)
+    # local_topk = min(topk, local_num_experts)
+    local_topk = topk
 
     # Create weight tensors
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
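The dropped clamp only mattered under the old per-GPU routing, where top-k could not exceed the experts resident on one rank. With global routing via the expert map (added below), every token keeps its full top-k. A tiny illustration with assumed values, not taken from the test cases:

```python
# Assumed values for illustration only.
topk, local_num_experts = 8, 4
old_local_topk = min(topk, local_num_experts)  # old behaviour: clamp to 4
new_local_topk = topk                          # new behaviour: keep 8
print(old_local_topk, new_local_topk)
```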
@@ -314,6 +315,9 @@ def run_moe_torch(
         device=device,
     )
 
+    # Maps global expert index to local expert index.
+    _, expert_map = determine_expert_map(moe_ep_size, 0, num_experts)
+
    if dtype == torch.float8_e4m3fn:
        w1 = w1.to(dtype)
        w2 = w2.to(dtype)
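A sketch of the mapping the benchmark relies on here. This is a hypothetical re-implementation for illustration, assuming `determine_expert_map(ep_size, ep_rank, global_num_experts)` returns `(local_num_experts, expert_map)` with -1 marking experts owned by other ranks; it is not necessarily vLLM's exact code:

```python
import torch

def expert_map_sketch(ep_size: int, ep_rank: int, global_num_experts: int):
    # Hypothetical stand-in for determine_expert_map, for illustration only:
    # expert_map[g] is the local index of global expert g on this rank,
    # or -1 if expert g lives on another expert-parallel rank.
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start : start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    return local_num_experts, expert_map

# Rank 0 of 2 with 8 global experts:
print(expert_map_sketch(2, 0, 8))
# (4, tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32))
```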
@@ -333,23 +337,29 @@ def run_moe_torch(
         for _ in range(num_iter):
             logits = (
                 power_law_logits_v3(
-                    num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    # num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    num_tokens,
+                    num_experts,
+                    topk,
+                    moe_ep_size,
+                    power_law_alpha,
+                    return_first_gpu_only=False,
                 )
                 .half()
                 .to(device)
             )
-            weights, ids = torch.topk(logits, local_topk, dim=-1)
+            weights, ids = torch.topk(logits, topk, dim=-1)
+            # weights, ids = torch.topk(logits, local_topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)
 
         print("actual num_tokens: ", [topk_ids.shape[0] for topk_ids in topk_ids_list])
 
     elif distributed == "balanced":
-        local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-        actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
+        # actual_logits = balanced_logits(num_tokens, local_num_experts, local_topk).half().to(device)
+        actual_logits = balanced_logits(num_tokens, num_experts, topk).half().to(device)
         topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
         topk_weights = F.softmax(topk_weights, dim=-1)
 
     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")
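Why taking top-k over all `num_experts` is now safe: the ids produced here are global expert indices, and the `expert_map` passed to `fused_experts` below lets a rank translate them and skip tokens routed elsewhere. A self-contained sketch of that masking, illustrative rather than the benchmark's code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, topk = 4, 8, 2
logits = torch.randn(num_tokens, num_experts)

weights, ids = torch.topk(logits, topk, dim=-1)  # ids are *global* expert indices
weights = F.softmax(weights, dim=-1)

# Rank 0 owns experts 0-3 (see the expert_map sketch above).
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1])
local_ids = expert_map[ids]  # -1 marks (token, expert) pairs owned by other ranks
print(local_ids)             # this rank only computes entries >= 0
```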
@@ -372,6 +382,8 @@ def run_single_iteration():
                 ti,
                 inplace=True,
                 quant_config=quant_config,
+                global_num_experts=num_experts,
+                expert_map=expert_map,
             )
         else:
             _ = fused_experts(
@@ -382,6 +394,8 @@ def run_single_iteration():
                 topk_ids,
                 inplace=True,
                 quant_config=quant_config,
+                global_num_experts=num_experts,
+                expert_map=expert_map,
             )
 
     def run_iterations(use_cuda_graph=False):
@@ -453,6 +467,8 @@ def run_iterations(use_cuda_graph=False):
 test_cases = get_moe_test_cases()
 print(f"Total test cases: {len(test_cases)}")
 
+# test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
+
 for test_case in test_cases:
     print(f"Running test case: {test_case}")
     try: