@@ -106,7 +106,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     )
     num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)
 
-
     revised_num_tokens = num_tokens
     revised_topk = topk
     if return_first_gpu_only:
@@ -144,7 +143,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
     h_selected_experts = expert_assignments.reshape(revised_topk, revised_num_tokens).T
 
-
     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
     router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits
@@ -188,7 +186,7 @@ def get_moe_test_cases():
     ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     alpha_list = [1.01, 1.2]
-
+
     # Model configurations: [hidden_size, inter_size, topk, num_experts, model_name]
     model_config_list = [
         [4096, 14336, 2, 8, "MOE_Mixtral8x7B"],  # mixtral_8x7b
@@ -204,7 +202,7 @@ def get_moe_test_cases():
     moe_list = ["float16"]
 
     if get_sm_version() > 86:
-        moe_list += ["fp8", ]
+        moe_list += ["fp8"]
 
     test_cases = []
 
@@ -251,6 +249,7 @@ def get_moe_test_cases():
 
     return test_cases
 
+
 def run_moe_torch(
     moe_type,
     num_tokens_lists,
@@ -301,18 +300,18 @@ def run_moe_torch(
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
     # w2: down projection weights [num_experts, hidden_size, inter_size]
     w1 = torch.randn(
-            local_num_experts,
-            2 * local_inter_size,
-            hidden_size,
-            dtype=torch.float16,
-            device=device
+        local_num_experts,
+        2 * local_inter_size,
+        hidden_size,
+        dtype=torch.float16,
+        device=device,
     )
     w2 = torch.randn(
-            local_num_experts,
-            hidden_size,
-            local_inter_size,
-            dtype=torch.float16,
-            device=device
+        local_num_experts,
+        hidden_size,
+        local_inter_size,
+        dtype=torch.float16,
+        device=device,
     )
 
     if dtype == torch.float8_e4m3fn:
@@ -332,14 +331,13 @@ def run_moe_torch(
         topk_ids_list = []
 
         for _ in range(num_iter):
-            logits = power_law_logits_v3(
-                num_tokens,
-                num_experts,
-                topk,
-                moe_ep_size,
-                power_law_alpha,
-                return_first_gpu_only=True
-            ).half().to(device)
+            logits = (
+                power_law_logits_v3(
+                    num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                )
+                .half()
+                .to(device)
+            )
             weights, ids = torch.topk(logits, local_topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)
@@ -356,7 +354,6 @@ def run_moe_torch(
     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")
 
-
     num_warmups = 3
     num_runs = 6
     if distributed == "power_law":
@@ -418,11 +415,11 @@ def run_iterations(use_cuda_graph=False):
 
         try:
             latency = run_iterations(use_cuda_graph=False)
-        except torch.OutOfMemoryError as e:
+        except torch.OutOfMemoryError:
             # If OOM, check if we had at least one successful run.
             if num_tokens_idx > 0:
                 break
-            raise e
+            raise
 
         print(f"moe latency: {latency}")
 
@@ -454,24 +451,13 @@ def run_iterations(use_cuda_graph=False):
 
 if __name__ == "__main__":
     test_cases = get_moe_test_cases()
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.2]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'balanced', 1.01]]
-    # test_cases = [['float16', [128, 256, 320], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-    test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 6144, 2560, 8, 160, 1, 1, 'QWEN3_480B', 'moe_perf.txt', 'power_law', 1.2]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 1, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 4, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 2048, 768, 8, 128, 2, 32, 'QWEN3_30B_A3B', 'moe_perf.txt', 'power_law', 1.01]]
-
-    test_cases = [['float16',[65536],4096,14336,2,8,1,1,'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-
     print(f"Total test cases: {len(test_cases)}")
-
-    for test_case in test_cases[:40]:
+
+    for test_case in test_cases:
         print(f"Running test case: {test_case}")
         try:
             run_moe_torch(*test_case)
         except Exception as e:
             print(f"Test case failed: {test_case}")
             print(f"Error: {e}")
-            continue
+            continue
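
For orientation, below is a minimal, self-contained sketch of the routing path this benchmark exercises: power-law router logits feeding the same torch.topk plus F.softmax selection used in run_moe_torch above. The helper name power_law_router_logits and its sampling-based construction are illustrative stand-ins, not the repository's power_law_logits_v3, which builds expert assignments from precomputed per-expert token counts before forming the one-hot expert map (see the first two hunks).

# Illustrative sketch only (not part of the diff). power_law_router_logits is a
# hypothetical helper that approximates the skewed routing distribution by sampling;
# the real power_law_logits_v3 derives assignments from per-expert token counts.
import torch
import torch.nn.functional as F


def power_law_router_logits(num_tokens, num_experts, topk, alpha=1.01):
    # Zipf-like expert popularity: expert of rank r gets probability proportional to r ** -alpha.
    ranks = torch.arange(1, num_experts + 1, dtype=torch.float32)
    probs = ranks.pow(-alpha)
    probs = probs / probs.sum()
    # Sample topk expert ids per token, biased toward the "hot" low-rank experts.
    assignments = torch.multinomial(probs, num_tokens * topk, replacement=True)
    selected = assignments.reshape(topk, num_tokens).T
    # One-hot expert map turned into soft router logits, mirroring the hunks above.
    expert_map = F.one_hot(selected, num_classes=num_experts).sum(1)
    return F.softmax(expert_map.float(), dim=1)  # the benchmark uses .half() and moves to GPU


if __name__ == "__main__":
    logits = power_law_router_logits(num_tokens=1024, num_experts=8, topk=2)
    # Same selection pattern as run_moe_torch: top-k expert ids plus renormalized weights.
    weights, ids = torch.topk(logits, k=2, dim=-1)
    weights = F.softmax(weights, dim=-1)
    print(ids.shape, weights.shape)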