Commit ff4e819

revert
1 parent 7ee6c37 commit ff4e819

File tree: 1 file changed (+26, −10)

collector/vllm/collect_moe.py
Lines changed: 26 additions & 10 deletions
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
 from vllm.version import __version__ as vllm_version
 
 from helper import get_sm_version, log_perf
@@ -97,8 +98,10 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
     max_ep_idx = torch.argmax(res).item()
 
+    # Number of experts per GPU
+    ep_group_size = num_experts // ep
+
     if max_ep_idx != 0:
-        ep_group_size = num_experts // ep
         num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
         num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
             num_tokens_per_expert_reshaped[max_ep_idx].clone(),
@@ -109,9 +112,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
     revised_num_tokens = num_tokens
     revised_topk = topk
     if return_first_gpu_only:
-        # Number of experts per GPU
-        ep_group_size = num_experts // ep
-
         # How many experts will be run on the first GPU.
         # Can't exceed the number of experts per GPU.
         revised_topk = min(topk, ep_group_size)
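
For reference, the hoisted `ep_group_size` is simply the per-rank expert count, which both the swap branch and the `return_first_gpu_only` branch now read. A quick worked example with illustrative values (loosely matching the Mixtral8x7B test case further down):

    num_experts, ep, topk = 8, 2, 2          # illustrative values, not taken from the diff
    ep_group_size = num_experts // ep        # 4 experts per GPU
    revised_topk = min(topk, ep_group_size)  # 2; capped at the experts hosted on one GPU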
@@ -294,7 +294,8 @@ def run_moe_torch(
     local_inter_size = inter_size // moe_tp_size
 
     # How many experts will be run on this GPU
-    local_topk = min(topk, local_num_experts)
+    # local_topk = min(topk, local_num_experts)
+    local_topk = topk
 
     # Create weight tensors
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
@@ -314,6 +315,9 @@ def run_moe_torch(
         device=device,
     )
 
+    # Maps global expert index to local expert index.
+    _, expert_map = determine_expert_map(moe_ep_size, 0, num_experts)
+
     if dtype == torch.float8_e4m3fn:
         w1 = w1.to(dtype)
         w2 = w2.to(dtype)
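
As the added comment says, `determine_expert_map` partitions the global expert set across EP ranks and returns `(local_num_experts, expert_map)`, where `expert_map` translates a global expert index into this rank's local index. A minimal sketch of the idea, not vLLM's implementation (the helper name and the contiguous-partition assumption are illustrative):

    import torch

    def sketch_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
        # Contiguous partition: rank r owns experts [r * n_local, (r + 1) * n_local).
        n_local = global_num_experts // ep_size
        # Experts hosted on other ranks map to -1; owned experts map to 0..n_local-1.
        expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
        start = ep_rank * n_local
        expert_map[start:start + n_local] = torch.arange(n_local, dtype=torch.int32)
        return n_local, expert_map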
@@ -333,23 +337,29 @@ def run_moe_torch(
         for _ in range(num_iter):
             logits = (
                 power_law_logits_v3(
-                    num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    # num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    num_tokens,
+                    num_experts,
+                    topk,
+                    moe_ep_size,
+                    power_law_alpha,
+                    return_first_gpu_only=False,
                 )
                 .half()
                 .to(device)
             )
-            weights, ids = torch.topk(logits, local_topk, dim=-1)
+            weights, ids = torch.topk(logits, topk, dim=-1)
+            # weights, ids = torch.topk(logits, local_topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)
 
         print("actual num_tokens: ", [topk_ids.shape[0] for topk_ids in topk_ids_list])
 
     elif distributed == "balanced":
-        local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-        actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
+        # actual_logits = balanced_logits(num_tokens, local_num_experts, local_topk).half().to(device)
+        actual_logits = balanced_logits(num_tokens, num_experts, topk).half().to(device)
         topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
         topk_weights = F.softmax(topk_weights, dim=-1)
-        print("actual num_tokens: ", actual_logits.shape[0])
 
     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")
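
Net effect of the revert in this block: routing selects top-k over all `num_experts` again rather than a locally capped `local_topk`, and tokens whose chosen experts live on other EP ranks are filtered downstream via `expert_map`. A rough sketch of that filtering step, reusing `sketch_expert_map` from above (illustrative only; the real masking happens inside `fused_experts`):

    import torch
    import torch.nn.functional as F

    num_tokens, num_experts, topk = 4, 8, 2
    _, expert_map = sketch_expert_map(ep_size=2, ep_rank=0, global_num_experts=num_experts)

    logits = torch.randn(num_tokens, num_experts)
    weights, ids = torch.topk(logits, topk, dim=-1)  # ids are GLOBAL expert indices
    weights = F.softmax(weights, dim=-1)
    local_ids = expert_map[ids.long()]               # translate to local indices
    handled_here = local_ids >= 0                    # False => expert lives on another rank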
@@ -372,6 +382,8 @@ def run_single_iteration():
             ti,
             inplace=True,
             quant_config=quant_config,
+            global_num_experts=num_experts,
+            expert_map=expert_map,
         )
     else:
         _ = fused_experts(
@@ -382,6 +394,8 @@ def run_single_iteration():
             topk_ids,
             inplace=True,
             quant_config=quant_config,
+            global_num_experts=num_experts,
+            expert_map=expert_map,
         )
 
 def run_iterations(use_cuda_graph=False):
@@ -453,6 +467,8 @@ def run_iterations(use_cuda_graph=False):
 test_cases = get_moe_test_cases()
 print(f"Total test cases: {len(test_cases)}")
 
+# test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
+
 for test_case in test_cases:
     print(f"Running test case: {test_case}")
     try:
