
Commit ca37da4 ("fix")
1 parent: de7773f
1 file changed (+20, -14)

collector/vllm/collect_moe.py

Lines changed: 20 additions & 14 deletions
@@ -10,13 +10,13 @@
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
 from vllm.version import __version__ as vllm_version
 
-# from helper import get_sm_version, log_perf
+from helper import get_sm_version, log_perf
 
-def get_sm_version():
-    return 86
+# def get_sm_version():
+#     return 86
 
-def log_perf(*args, **kwargs):
-    pass
+# def log_perf(*args, **kwargs):
+#     pass
 
 aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112
 
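Note on the stub swap above: the commit re-enables the real helper module and comments out the local stand-ins. If helper might be missing from the path, a guarded import keeps the script runnable; a minimal sketch (hypothetical, not part of this commit), reusing the old stubs as fallbacks:

# Hypothetical fallback, not in this commit: prefer the real helpers,
# and fall back to the previous local stubs if helper is unavailable.
try:
    from helper import get_sm_version, log_perf
except ImportError:
    def get_sm_version():
        # Old stub behavior: assume an SM 8.6 (Ampere-class) GPU.
        return 86

    def log_perf(*args, **kwargs):
        # No-op stand-in for the perf logger.
        pass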
@@ -34,7 +34,7 @@ def balanced_logits(num_tokens, num_experts, topk):
             h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts
 
     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
+    router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits
 
 
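For context on the change above: balanced_logits spreads each token's top-k picks evenly across the experts, one-hot encodes the picks, and softmaxes the per-token counts into router logits; the edit only moves the softmax input from bfloat16 to float16, in line with the fp16 switch later in this commit. A self-contained sketch of the shape flow (the loop body is reconstructed from the single context line, so treat the details as approximate):

import torch
import torch.nn.functional as F

def balanced_logits_sketch(num_tokens, num_experts, topk):
    stride = num_experts / topk  # assumed; the definition is outside this hunk
    sel = torch.empty(num_tokens, topk)
    for token_i in range(num_tokens):
        for i in range(topk):
            sel[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts
    # [num_tokens, topk] -> one-hot [num_tokens, topk, num_experts] -> counts [num_tokens, num_experts]
    expert_map = F.one_hot(sel.long(), num_classes=num_experts).sum(1)
    # The diff softmaxes expert_map.half() on GPU; this sketch softmaxes in fp32
    # so it also runs on CPU, then casts to fp16 as the fp16 pipeline expects.
    return F.softmax(expert_map.float(), dim=1).half()

print(balanced_logits_sketch(num_tokens=4, num_experts=8, topk=2))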
@@ -146,7 +146,7 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_g
 
 
     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
-    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
+    router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits
 
 
@@ -227,6 +227,10 @@ def get_moe_test_cases():
         if inter_s % tp != 0:
             continue
 
+        # vllm does not support TP when EP is enabled.
+        if tp > 1 and ep > 1:
+            continue
+
         for power_law_alpha in alpha_list:
             test_cases.append(
                 [
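The new guard prunes invalid parallelism combinations before any cases are emitted: per the added comment, vLLM does not support tensor parallelism while expert parallelism is enabled, so tp > 1 together with ep > 1 would fail at runtime. The same filtering in isolation, with hypothetical sweep lists (the real ones are defined elsewhere in get_moe_test_cases):

from itertools import product

tp_list, ep_list, inter_s = [1, 2, 4], [1, 2, 4], 14336  # hypothetical values

valid = [
    (tp, ep)
    for tp, ep in product(tp_list, ep_list)
    if inter_s % tp == 0           # TP must divide the intermediate size
    and not (tp > 1 and ep > 1)    # vLLM: no TP once EP is enabled
]
print(valid)  # [(1, 1), (1, 2), (1, 4), (2, 1), (4, 1)]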
@@ -245,7 +249,7 @@ def get_moe_test_cases():
                 ]
             )
 
-    return test_cases[:20]
+    return test_cases
 
 def run_moe_torch(
     moe_type,
@@ -267,7 +271,7 @@ def run_moe_torch(
    torch.set_default_device(device)
 
     # Configure quantization parameters
-    dtype = torch.bfloat16
+    dtype = torch.float16
     quant_config = None
 
     if moe_type == "fp8":
@@ -300,14 +304,14 @@ def run_moe_torch(
         local_num_experts,
         2 * local_inter_size,
         hidden_size,
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
         device=device
     )
     w2 = torch.randn(
         local_num_experts,
         hidden_size,
         local_inter_size,
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
         device=device
     )
 
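The shapes above follow the usual fused-MoE layout: w13 holds the gate and up projections fused along one dimension (hence 2 * local_inter_size) and w2 is the down projection. The derivation of the local_* sizes is outside this hunk; a hedged sketch of the typical EP/TP split, with the division points as assumptions:

import torch

num_experts, inter_size, hidden_size = 8, 14336, 4096  # Mixtral-like example
ep, tp = 2, 1
local_num_experts = num_experts // ep  # EP shards the expert set across ranks
local_inter_size = inter_size // tp    # TP shards the FFN intermediate dim

# Shapes matching the torch.randn calls above:
w13 = torch.empty(local_num_experts, 2 * local_inter_size, hidden_size, dtype=torch.float16)
w2 = torch.empty(local_num_experts, hidden_size, local_inter_size, dtype=torch.float16)
print(w13.shape, w2.shape)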
@@ -319,7 +323,7 @@ def run_moe_torch(
     for num_tokens_idx, num_tokens in enumerate(num_tokens_lists):
         print("num_tokens", num_tokens)
         print("topk", topk)
-        hidden_states = torch.randn([num_tokens, hidden_size]).bfloat16().to(device)
+        hidden_states = torch.randn([num_tokens, hidden_size]).half().to(device)
 
         # Generate topk_weights and topk_ids
         num_iter = 10 if distributed == "power_law" else 1
@@ -335,7 +339,7 @@ def run_moe_torch(
                 moe_ep_size,
                 power_law_alpha,
                 return_first_gpu_only=True
-            ).bfloat16().to(device)
+            ).half().to(device)
             weights, ids = torch.topk(logits, local_topk, dim=-1)
             topk_weights_list.append(F.softmax(weights, dim=-1))
             topk_ids_list.append(ids)
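power_law_logits_v3's internals are untouched here; only its output dtype changes. Conceptually it produces skewed routing in which a few experts absorb most tokens, which is why the power_law path averages over num_iter = 10 draws. A purely illustrative Zipf-style sampler (hypothetical, not the function's actual algorithm):

import torch
import torch.nn.functional as F

num_tokens, num_experts, topk, alpha = 16, 8, 2, 1.01

# Zipf-like popularity: the expert at rank r gets weight r ** (-alpha).
popularity = torch.arange(1, num_experts + 1, dtype=torch.float).pow(-alpha)
popularity = popularity / popularity.sum()

# Skewed top-k expert ids per token, sampled without replacement.
ids = torch.multinomial(popularity.repeat(num_tokens, 1), topk)
weights = F.softmax(torch.rand(num_tokens, topk), dim=-1)  # placeholder gate weights
print(ids[:4], weights[:4])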
@@ -344,7 +348,7 @@ def run_moe_torch(
 
         elif distributed == "balanced":
             local_num_tokens = math.ceil(num_tokens / moe_ep_size)
-            actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).bfloat16().to(device)
+            actual_logits = balanced_logits(local_num_tokens, local_num_experts, local_topk).half().to(device)
             topk_weights, topk_ids = torch.topk(actual_logits, local_topk, dim=-1)
             topk_weights = F.softmax(topk_weights, dim=-1)
             print("actual num_tokens: ", actual_logits.shape[0])
@@ -459,6 +463,8 @@ def run_iterations(use_cuda_graph=False):
     # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 4, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
     # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 2048, 768, 8, 128, 2, 32, 'QWEN3_30B_A3B', 'moe_perf.txt', 'power_law', 1.01]]
 
+    test_cases = [['float16',[65536],4096,14336,2,8,1,1,'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
+
     print(f"Total test cases: {len(test_cases)}")
 
     for test_case in test_cases[:40]:
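The newly enabled case pins the benchmark to a single Mixtral-8x7B shape at 65536 tokens. A hedged reading of the positional fields, inferred from the commented-out examples and Mixtral-8x7B's published architecture (hidden 4096, intermediate 14336, 8 experts, top-2); the exact field order, especially tp vs. ep, is an assumption:

mixtral_case = [
    'float16',          # dtype
    [65536],            # num_tokens values to sweep
    4096,               # hidden_size
    14336,              # intermediate_size
    2,                  # topk (experts activated per token)
    8,                  # num_experts
    1,                  # tp (assumed tensor-parallel size)
    1,                  # ep (assumed expert-parallel size)
    'MOE_Mixtral8x7B',  # model tag
    'moe_perf.txt',     # perf log file
    'power_law',        # routing distribution
    1.01,               # power-law alpha
]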
