
Commit 2f5b658

Revert return_first_gpu_only and address comments
Signed-off-by: Ilya Sherstyuk <[email protected]>
1 parent c9ab33b commit 2f5b658

6 files changed: +76 -240 lines changed


collector/collect.py

Lines changed: 0 additions & 1 deletion
@@ -335,7 +335,6 @@ def collect_ops(
                         "traceback": traceback.format_exc(),
                     }
                 )
-                return all_errors

     return all_errors

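Note: the deleted line returned from inside the error-collection loop, so only the first failure was ever reported. A minimal sketch of the control-flow fix (hypothetical skeleton, not the actual collect_ops):

def collect_ops_sketch(ops):
    """Hypothetical skeleton illustrating the fix: collect every error, return once."""
    all_errors = []
    for op in ops:
        try:
            op()
        except Exception:
            all_errors.append({"op": getattr(op, "__name__", repr(op))})
            # Returning here (the deleted line) would stop after the first failing op.
    # Returning after the loop reports all failures.
    return all_errors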
collector/vllm/collect_moe.py

Lines changed: 46 additions & 231 deletions
@@ -1,251 +1,60 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-import math
 import os

 import torch
 import torch.nn.functional as F
+from common_test_cases import get_common_moe_test_cases
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
 from vllm.version import __version__ as vllm_version

-from helper import get_sm_version, log_perf
-
-# def get_sm_version():
-#     return 86
-
-# def log_perf(*args, **kwargs):
-#     pass
+from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3

 aic_debug = int(os.getenv("aic_moe_debug", "0"))  # noqa: SIM112


-def balanced_logits(num_tokens, num_experts, topk):
-    """Generate balanced distribution router logits"""
-    h_selected_experts = -torch.ones([num_tokens, topk])
-    stride = math.ceil(num_experts / topk)
-
-    for token_i in range(num_tokens):
-        for i in range(topk):
-            if num_tokens >= stride:
-                h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
-            else:
-                h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts
-
-    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.half(), dim=1)
-    return router_logits
-
-
-def sample_power_law(size, alpha, xmin, xmax):
-    """Sample from power law distribution"""
-    u = torch.rand(size)
-    inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha))
-    return inv_cdf
-
-
-def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_gpu_only=False):
-    """Generate power law distributed router logits (simulating real-world load imbalance scenarios)"""
-    if num_tokens * topk > num_experts:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
-    else:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
-
-    target_sum = num_tokens * topk
-
-    original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()
-
-    target_distribution = original_distribution * target_sum
-
-    num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)
-
-    current_sum = num_tokens_per_expert.sum().item()
-    delta = target_sum - current_sum
-    if delta != 0:
-        sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)
-
-        if delta > 0:
-            for i in range(delta):
-                expert_idx = sorted_indices[i % len(sorted_indices)]
-                num_tokens_per_expert[expert_idx] += 1
-        else:
-            for i in range(-delta):
-                expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
-                if num_tokens_per_expert[expert_idx] > 0:
-                    num_tokens_per_expert[expert_idx] -= 1
-                else:
-                    num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1
-
-    if len(num_tokens_per_expert) > 1:
-        sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
-        assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"
-
-    # Ensure the busiest expert group in EP dimension is placed on the first rank
-    with torch.no_grad():
-        conv1d = torch.nn.Conv1d(
-            in_channels=1,
-            out_channels=1,
-            kernel_size=num_experts // ep,
-            stride=num_experts // ep,
-            padding=0,
-            bias=False,
-        )
-        conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
-        conv1d.weight.copy_(conv1d_weights)
-
-        res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
-        max_ep_idx = torch.argmax(res).item()
-
-        if max_ep_idx != 0:
-            ep_group_size = num_experts // ep
-            num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
-            num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
-                num_tokens_per_expert_reshaped[max_ep_idx].clone(),
-                num_tokens_per_expert_reshaped[0].clone(),
-            )
-            num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)
-
-    revised_num_tokens = num_tokens
-    revised_topk = topk
-    if return_first_gpu_only:
-        # Number of experts per GPU
-        ep_group_size = num_experts // ep
-
-        # How many experts will be run on the first GPU.
-        # Can't exceed the number of experts per GPU.
-        revised_topk = min(topk, ep_group_size)
-
-        # Only generate token -> expert assignments for the first GPU.
-        num_tokens_per_expert = num_tokens_per_expert[:ep_group_size]
-
-        # Bump up the total number of tokens on the first GPU
-        # to be a multiple of revised_topk.
-        tokens_on_first_gpu = torch.sum(num_tokens_per_expert).item()
-        num_extra_tokens = (revised_topk - (tokens_on_first_gpu % revised_topk)) % revised_topk
-        for i in range(num_extra_tokens):
-            num_tokens_per_expert[i % len(num_tokens_per_expert)] += 1
-        tokens_on_first_gpu = torch.sum(num_tokens_per_expert).item()
-        assert tokens_on_first_gpu % revised_topk == 0
-
-        # Now revised_num_tokens represents only the tokens on the first GPU.
-        revised_num_tokens = tokens_on_first_gpu // revised_topk
-
-    if aic_debug == 2:
-        print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())
-
-    _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
-    expert_assignments = []
-    num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
-    for expert_id in num_tokens_per_expert_sorted_index_lists:
-        expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])
-
-    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
-    h_selected_experts = expert_assignments.reshape(revised_topk, revised_num_tokens).T
-
-    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.half(), dim=1)
-    return router_logits
-
-
 def get_moe_test_cases():
     """Generate MoE test cases"""
-    num_tokens = [
-        1,
-        2,
-        4,
-        8,
-        16,
-        32,
-        48,
-        64,
-        80,
-        96,
-        128,
-        160,
-        192,
-        256,
-        320,
-        384,
-        512,
-        768,
-        1024,
-        1536,
-        2048,
-        3072,
-        4096,
-        6144,
-        8192,
-        12288,
-        16384,
-        20480,
-        32768,
-        65536,
-    ]
-    tp_list = [1, 2, 4, 8, 16, 32]
-    ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    alpha_list = [1.01, 1.2]
-
-    # Model configurations: [hidden_size, inter_size, topk, num_experts, model_name]
-    model_config_list = [
-        [4096, 14336, 2, 8, "MOE_Mixtral8x7B"],  # mixtral_8x7b
-        [6144, 16384, 2, 8, "MOE_Mixtral8x22B"],  # mixtral_8x22b
-        [7168, 2048, 8, 256, "DEEPSEEK_V3"],  # deepseekv3
-        [2048, 768, 8, 128, "QWEN3_30B_A3B"],  # qwen3-moe, 30b-a3b
-        [4096, 1536, 8, 128, "QWEN3_235B"],  # qwen3-moe, 235b-a22b
-        [6144, 2560, 8, 160, "QWEN3_480B"],  # qwen3-moe, 480b-a35b
-        [7168, 2048, 8, 384, "KIMI_K2"],  # kimi k2
-    ]

     # Quantization types supported by vLLM
     moe_list = ["float16"]
-
     if get_sm_version() > 86:
         moe_list += ["fp8"]

     test_cases = []

-    for num_gpu in num_gpu_list:
+    for common_moe_testcase in get_common_moe_test_cases():
+        if common_moe_testcase.token_expert_distribution != "power_law":
+            continue
+
+        model_name = common_moe_testcase.model_name
+        if model_name in ["GPT_OSS_20B", "GPT_OSS_120B"]:
+            continue
+
+        # vllm does not support TP when EP is enabled.
+        if common_moe_testcase.tp > 1 and common_moe_testcase.ep > 1:
+            continue
+
         for moe_type in moe_list:
-            for model_config in model_config_list:
-                hs, inter_s, topk, num_experts, model_name = model_config
-                for tp in tp_list:
-                    # QWEN3_30B_A3B: exclude tp >= 8 as they are not used in actual deployments
-                    if model_name == "QWEN3_30B_A3B" and tp >= 8:
-                        continue
-                    for ep in ep_list:
-                        if tp * ep != num_gpu:
-                            continue
-                        if ep > num_experts:
-                            continue
-                        if num_experts % ep != 0:
-                            continue
-                        # Ensure inter_s can be divided by tp
-                        if inter_s % tp != 0:
-                            continue
-
-                        # vllm does not support TP when EP is enabled.
-                        if tp > 1 and ep > 1:
-                            continue
-
-                        for power_law_alpha in alpha_list:
-                            test_cases.append(
-                                [
-                                    moe_type,
-                                    num_tokens,
-                                    hs,
-                                    inter_s,
-                                    topk,
-                                    num_experts,
-                                    tp,
-                                    ep,
-                                    model_name,
-                                    "moe_perf.txt",
-                                    "power_law",
-                                    power_law_alpha,
-                                ]
-                            )
+            test_cases.append(
+                [
+                    moe_type,
+                    common_moe_testcase.num_tokens_list,
+                    common_moe_testcase.hidden_size,
+                    common_moe_testcase.inter_size,
+                    common_moe_testcase.topk,
+                    common_moe_testcase.num_experts,
+                    common_moe_testcase.tp,
+                    common_moe_testcase.ep,
+                    common_moe_testcase.model_name,
+                    "moe_perf.txt",
+                    common_moe_testcase.token_expert_distribution,
+                    common_moe_testcase.power_law_alpha,
+                ]
+            )

     return test_cases

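Note: get_common_moe_test_cases() comes from the shared common_test_cases module and its class is not shown in this diff. A hypothetical sketch of the shape it appears to yield, with field names inferred only from the attribute accesses above (not the actual definition):

from dataclasses import dataclass, field


@dataclass
class CommonMoeTestCaseSketch:
    # Hypothetical shape inferred from get_moe_test_cases(); not the real class.
    model_name: str                 # e.g. "DEEPSEEK_V3"
    hidden_size: int
    inter_size: int
    topk: int
    num_experts: int
    tp: int                         # MoE tensor-parallel size
    ep: int                         # MoE expert-parallel size
    token_expert_distribution: str  # e.g. "power_law" or "balanced"
    power_law_alpha: float = 1.2
    num_tokens_list: list = field(default_factory=list)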
@@ -293,9 +102,6 @@ def run_moe_torch(
     local_num_experts = num_experts // moe_ep_size
     local_inter_size = inter_size // moe_tp_size

-    # How many experts will be run on this GPU
-    local_topk = min(topk, local_num_experts)
-
     # Create weight tensors
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
     # w2: down projection weights [num_experts, hidden_size, inter_size]
@@ -314,6 +120,9 @@ def run_moe_torch(
         device=device,
     )

+    # Maps global expert index to local expert index.
+    _, expert_map = determine_expert_map(moe_ep_size, 0, num_experts)
+
     if dtype == torch.float8_e4m3fn:
         w1 = w1.to(dtype)
         w2 = w2.to(dtype)
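Note: determine_expert_map(moe_ep_size, 0, num_experts) builds rank 0's global-to-local expert mapping. A rough, illustrative sketch of the kind of map it is expected to return, assuming contiguous expert partitioning (illustration only, not vLLM's actual implementation):

import torch


def sketch_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    # Illustrative stand-in: experts are split into ep_size contiguous groups;
    # experts owned by ep_rank map to local ids 0..local_num_experts-1, every
    # other expert maps to -1.
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map


# sketch_expert_map(4, 0, 8) -> (2, tensor([ 0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))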
@@ -333,23 +142,25 @@ def run_moe_torch(
         for _ in range(num_iter):
             logits = (
                 power_law_logits_v3(
-                    num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+                    num_tokens,
+                    num_experts,
+                    topk,
+                    moe_ep_size,
+                    power_law_alpha,
                 )
                 .half()
                 .to(device)
             )
-            weights, ids = torch.topk(logits, local_topk, dim=-1)
+            weights, ids = torch.topk(logits, topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)

         print("actual num_tokens: ", [topk_ids.shape[0] for topk_ids in topk_ids_list])

     elif distributed == "balanced":
-        local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-        actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
-        topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
+        actual_logits = balanced_logits(num_tokens, num_experts, topk).half().to(device)
+        topk_weights, topk_ids = torch.topk(actual_logits, topk, dim=-1)
         topk_weights = F.softmax(topk_weights, dim=-1)
-        print("actual num_tokens: ", actual_logits.shape[0])

     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")
@@ -372,6 +183,8 @@ def run_single_iteration():
                     ti,
                     inplace=True,
                     quant_config=quant_config,
+                    global_num_experts=num_experts,
+                    expert_map=expert_map,
                 )
         else:
             _ = fused_experts(
@@ -382,6 +195,8 @@ def run_single_iteration():
                 topk_ids,
                 inplace=True,
                 quant_config=quant_config,
+                global_num_experts=num_experts,
+                expert_map=expert_map,
             )

     def run_iterations(use_cuda_graph=False):
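Note: with global_num_experts and expert_map passed in, the top-k ids handed to fused_experts appear to stay global, and experts hosted on other EP ranks resolve to -1 through the map. A small illustration under the contiguous-partition assumption sketched above (not the library's internals):

import torch

# Illustration only: translate global top-k expert ids to this rank's local ids.
expert_map = torch.tensor([0, 1, -1, -1, -1, -1, -1, -1])  # EP rank 0 of 4, 8 experts
topk_ids = torch.tensor([[0, 5], [1, 2], [7, 0]])          # global expert ids per token

local_ids = expert_map[topk_ids]
# tensor([[ 0, -1],
#         [ 1, -1],
#         [-1,  0]])  -> -1 entries belong to other EP ranks and contribute nothing locally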

src/aiconfigurator/sdk/operations.py

Lines changed: 4 additions & 2 deletions
@@ -358,13 +358,15 @@ def query(self, database: PerfDatabase, **kwargs):
             else:
                 comm_latency = 0
         elif database.backend == common.BackendName.vllm.value:
-            assert self._moe_tp_size == 1, "vllm does not support moe_tp_size > 1"
+            assert self._moe_tp_size == 1 or self._moe_ep_size == 1, (
+                "vllm does not support MoE TP and MoE EP at the same time"
+            )

             comm_latency = 0

         # Add allreduce latency when TP > 1
         if self._attention_tp_size > 1:
-            comm_latency += database.query_allreduce(common.CommQuantMode.half, self.num_gpus, volume)
+            comm_latency += database.query_custom_allreduce(common.CommQuantMode.half, self.num_gpus, volume)

         if self._attention_dp_size > 1:
             comm_latency += database.query_nccl(

src/aiconfigurator/sdk/pareto_analysis.py

Lines changed: 3 additions & 2 deletions
@@ -69,8 +69,9 @@ def enumerate_parallel_config(
                         not enable_wideep and moe_ep > 1
                     ):  # wideep only has ep
                         continue
-                    elif backend == common.BackendName.vllm:
-                        pass  # TODO
+                    elif backend == common.BackendName.vllm and moe_tp > 1 and moe_ep > 1:
+                        continue  # vllm does not support moe_tp > 1 and moe_ep > 1 at the same time
+
                     parallel_config_list.append([tp, pp, dp, moe_tp, moe_ep])
                 else:
                     if tp * pp in num_gpu_list:

0 commit comments