Commit 5921c78

Add SDK support for vllm moe

1 parent ca37da4

File tree: 7 files changed, +68 -50 lines

collector/vllm/collect_moe.py

Lines changed: 25 additions & 39 deletions

@@ -106,7 +106,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_gpu_only
     )
     num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)

-
     revised_num_tokens = num_tokens
     revised_topk = topk
     if return_first_gpu_only:

@@ -144,7 +143,6 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, return_first_gpu_only
     expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
     h_selected_experts = expert_assignments.reshape(revised_topk, revised_num_tokens).T

-
     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
     router_logits = F.softmax(expert_map.half(), dim=1)
     return router_logits

@@ -188,7 +186,7 @@ def get_moe_test_cases():
     ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     alpha_list = [1.01, 1.2]
-
+
     # Model configurations: [hidden_size, inter_size, topk, num_experts, model_name]
     model_config_list = [
         [4096, 14336, 2, 8, "MOE_Mixtral8x7B"],  # mixtral_8x7b

@@ -204,7 +202,7 @@ def get_moe_test_cases():
     moe_list = ["float16"]

     if get_sm_version() > 86:
-        moe_list += ["fp8",]
+        moe_list += ["fp8"]

     test_cases = []

@@ -251,6 +249,7 @@ def get_moe_test_cases():

     return test_cases

+
 def run_moe_torch(
     moe_type,
     num_tokens_lists,

@@ -301,18 +300,18 @@ def run_moe_torch(
     # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
     # w2: down projection weights [num_experts, hidden_size, inter_size]
     w1 = torch.randn(
-        local_num_experts,
-        2 * local_inter_size,
-        hidden_size,
-        dtype=torch.float16,
-        device=device
+        local_num_experts,
+        2 * local_inter_size,
+        hidden_size,
+        dtype=torch.float16,
+        device=device,
     )
     w2 = torch.randn(
-        local_num_experts,
-        hidden_size,
-        local_inter_size,
-        dtype=torch.float16,
-        device=device
+        local_num_experts,
+        hidden_size,
+        local_inter_size,
+        dtype=torch.float16,
+        device=device,
     )

     if dtype == torch.float8_e4m3fn:

@@ -332,14 +331,13 @@ def run_moe_torch(
     topk_ids_list = []

     for _ in range(num_iter):
-        logits = power_law_logits_v3(
-            num_tokens,
-            num_experts,
-            topk,
-            moe_ep_size,
-            power_law_alpha,
-            return_first_gpu_only=True
-        ).half().to(device)
+        logits = (
+            power_law_logits_v3(
+                num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, return_first_gpu_only=True
+            )
+            .half()
+            .to(device)
+        )
         weights, ids = torch.topk(logits, local_topk, dim=-1)
         topk_weights_list.append(F.softmax(weights, dim=-1))
         topk_ids_list.append(ids)

@@ -356,7 +354,6 @@ def run_moe_torch(
     else:
         raise ValueError(f"Unsupported distributed mode: {distributed}")

-
     num_warmups = 3
     num_runs = 6
     if distributed == "power_law":

@@ -418,11 +415,11 @@ def run_iterations(use_cuda_graph=False):

     try:
         latency = run_iterations(use_cuda_graph=False)
-    except torch.OutOfMemoryError as e:
+    except torch.OutOfMemoryError:
         # If OOM, check if we had at least one successful run.
         if num_tokens_idx > 0:
             break
-        raise e
+        raise

     print(f"moe latency: {latency}")
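
Note on the exception handling above: a bare raise re-raises the active exception unchanged, whereas raise e re-raises the same object but appends the re-raise site to its traceback. A minimal illustration (allocate() is a hypothetical stand-in, not part of the collector):

import torch

def allocate():
    return torch.empty((1 << 40,), device="cuda")  # deliberately oversized

try:
    x = allocate()
except torch.OutOfMemoryError:
    # A bare raise preserves the original traceback exactly,
    # which is why the diff drops the `as e` / `raise e` form.
    raise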

@@ -454,24 +451,13 @@ def run_iterations(use_cuda_graph=False):

 if __name__ == "__main__":
     test_cases = get_moe_test_cases()
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.2]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'balanced', 1.01]]
-    # test_cases = [['float16', [128, 256, 320], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-    test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 4096, 14336, 2, 8, 1, 2, 'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 6144, 2560, 8, 160, 1, 1, 'QWEN3_480B', 'moe_perf.txt', 'power_law', 1.2]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 1, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 7168, 2048, 8, 384, 1, 4, 'KIMI_K2', 'moe_perf.txt', 'power_law', 1.01]]
-    # test_cases = [['float16', [1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 20480, 32768, 65536], 2048, 768, 8, 128, 2, 32, 'QWEN3_30B_A3B', 'moe_perf.txt', 'power_law', 1.01]]
-
-    test_cases = [['float16',[65536],4096,14336,2,8,1,1,'MOE_Mixtral8x7B', 'moe_perf.txt', 'power_law', 1.01]]
-
     print(f"Total test cases: {len(test_cases)}")
-
-    for test_case in test_cases[:40]:
+
+    for test_case in test_cases:
         print(f"Running test case: {test_case}")
         try:
             run_moe_torch(*test_case)
         except Exception as e:
             print(f"Test case failed: {test_case}")
             print(f"Error: {e}")
-                continue
+            continue
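
power_law_logits_v3 appears above only in fragments. As a rough, self-contained sketch of the idea the visible lines imply — skew expert popularity with a power law, pick top-k experts per token, then turn the one-hot expert map into router logits via softmax — something like the following; the function name and sampling scheme here are illustrative assumptions, not the collector's exact algorithm:

import torch
import torch.nn.functional as F

def power_law_router_logits(num_tokens, num_experts, topk, alpha):
    # Assumed scheme: expert popularity p_i ~ i^(-alpha) over popularity-ranked experts.
    probs = torch.arange(1, num_experts + 1, dtype=torch.float32).pow(-alpha)
    probs = (probs / probs.sum()).expand(num_tokens, -1).contiguous()
    # Draw topk distinct experts per token from the skewed distribution.
    selected = torch.multinomial(probs, topk, replacement=False)
    # One-hot expert map summed over the topk picks, then softmax into logits,
    # mirroring the F.one_hot(...).sum(1) / F.softmax(...) lines in the diff.
    expert_map = F.one_hot(selected, num_classes=num_experts).sum(1)
    return F.softmax(expert_map.float(), dim=1)

logits = power_law_router_logits(num_tokens=16, num_experts=8, topk=2, alpha=1.2)
weights, ids = torch.topk(logits, k=2, dim=-1)  # same selection step as run_moe_torch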

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -146,6 +146,7 @@ ignore = [
     "RUF059", # unpacked variable is never used
     "UP007", # require using X | Y for type annotations
     "UP045", # require using X | None for type annotations
+    "B023", # Function definition does not bind loop variable
 ]

 [tool.ruff.lint.isort]
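
For context on the new suppression: ruff's B023 fires when a closure defined inside a loop references the loop variable, which Python binds late. A minimal reproduction of the pattern, with the conventional default-argument fix:

# B023: every lambda closes over the same `i`, so all see its final value.
callbacks = [lambda: i for i in range(3)]
print([f() for f in callbacks])  # [2, 2, 2]

# Binding the current value as a default argument captures it eagerly.
callbacks = [lambda i=i: i for i in range(3)]
print([f() for f in callbacks])  # [0, 1, 2]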

src/aiconfigurator/sdk/operations.py

Lines changed: 18 additions & 2 deletions

@@ -348,8 +348,22 @@ def query(self, database: PerfDatabase, **kwargs):
             else:
                 comm_latency = 0
         elif database.backend == common.BackendName.vllm.value:
-            raise NotImplementedError("Need to implement MoE dispatch for vllm")
-        else:  # sglang
+            assert self._moe_tp_size == 1, "vllm does not support moe_tp_size > 1"
+
+            comm_latency = 0
+
+            # Add allreduce latency when TP > 1
+            if self._attention_tp_size > 1:
+                comm_latency += database.query_allreduce(common.CommQuantMode.half, self.num_gpus, volume)
+
+            if self._attention_dp_size > 1:
+                comm_latency += database.query_nccl(
+                    common.CommQuantMode.half,
+                    self.num_gpus,
+                    "all_gather" if self._pre_dispatch else "reduce_scatter",
+                    volume * self._attention_dp_size,
+                )
+        elif database.backend == common.BackendName.sglang.value:
             if self._moe_backend == "deepep_moe":
                 if self._is_context:
                     comm_latency = database.query_deepep_normal(

@@ -370,6 +384,8 @@ def query(self, database: PerfDatabase, **kwargs):
                     )
             else:
                 raise NotImplementedError(f"MoE backend {self._moe_backend} not implemented")
+        else:
+            raise NotImplementedError(f"Backend {database.backend} not implemented")
         return comm_latency * self._scale_factor

     def get_weights(self, **kwargs):
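
Read as a formula, the new vllm branch models MoE communication cost as one half-precision allreduce over volume when attention TP > 1, plus, when attention DP > 1, an all_gather before dispatch or a reduce_scatter after combine over volume * dp (each DP rank contributes its own tokens). A stub paraphrase, with db standing in for PerfDatabase (illustrative only, not the SDK API):

def vllm_moe_comm_latency(db, attn_tp, attn_dp, num_gpus, volume, pre_dispatch):
    latency = 0.0
    if attn_tp > 1:
        latency += db.query_allreduce("half", num_gpus, volume)
    if attn_dp > 1:
        op = "all_gather" if pre_dispatch else "reduce_scatter"
        latency += db.query_nccl("half", num_gpus, op, volume * attn_dp)
    return latency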

src/aiconfigurator/sdk/perf_database.py

Lines changed: 9 additions & 0 deletions

@@ -1121,6 +1121,8 @@ def __init__(self, system: str, backend: str, version: str, systems_dir: str = "
             self._custom_allreduce_data = load_custom_allreduce_data(
                 os.path.join(data_dir, common.PerfDataFilename.custom_allreduce.value)
             )
+            self._moe_data, _ = load_moe_data(os.path.join(data_dir, common.PerfDataFilename.moe.value))
+            self._nccl_data = load_nccl_data(nccl_data_dir)
         else:  # TRTLLM
             self._gemm_data = load_gemm_data(os.path.join(data_dir, common.PerfDataFilename.gemm.value))
             self._context_attention_data = load_context_attention_data(

@@ -2590,6 +2592,13 @@ def get_sol(
             num_left, num_right = self._nearest_1d_point_helper(num_tokens, list(moe_dict.keys()), inner_only=False)
             lat = self._interp_1d([num_left, num_right], [moe_dict[num_left], moe_dict[num_right]], num_tokens)
             return lat
+        elif self.backend == common.BackendName.vllm.value:
+            moe_dict = self._moe_data[quant_mode][workload_distribution][topk][num_experts][hidden_size][
+                inter_size
+            ][moe_tp_size][moe_ep_size]
+            num_left, num_right = self._nearest_1d_point_helper(num_tokens, list(moe_dict.keys()), inner_only=False)
+            latency = self._interp_1d([num_left, num_right], [moe_dict[num_left], moe_dict[num_right]], num_tokens)
+            return latency
         else:
             raise NotImplementedError(f"backend {self.backend} not supported for moe")
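
The vllm branch reuses the lookup pattern of the other backends: index the nested measurement dict, find the two measured num_tokens nearest the query, and interpolate linearly between them. A standalone sketch of that pattern (_nearest_1d_point_helper and _interp_1d are internal to PerfDatabase; their exact behavior is assumed here):

from bisect import bisect_left

def interp_moe_latency(moe_dict, num_tokens):
    # moe_dict: measured num_tokens -> latency, as in self._moe_data[...].
    keys = sorted(moe_dict)
    if len(keys) == 1:
        return moe_dict[keys[0]]
    i = bisect_left(keys, num_tokens)
    i = min(max(i, 1), len(keys) - 1)  # clamp: bracket the query or extrapolate at the edges
    left, right = keys[i - 1], keys[i]
    t = (num_tokens - left) / (right - left)
    return moe_dict[left] + t * (moe_dict[right] - moe_dict[left])

print(interp_moe_latency({128: 0.9, 256: 1.4, 512: 2.5}, 320))  # 1.675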

src/aiconfigurator/sdk/task.py

Lines changed: 5 additions & 5 deletions

@@ -436,7 +436,6 @@ def _get_quant_mode(
     use_specific_quant_mode: str | None = None,
 ) -> tuple[str, str, str, str, str]:
     gemm_quant_mode = "fp8_block"
-    moe_quant_mode = "fp8_block"
     kvcache_quant_mode = "fp8"
     fmha_quant_mode = "float16" if model_name in ["DEEPSEEK_V3", "KIMI_K2"] else "fp8"
     comm_quant_mode = "half"

@@ -469,13 +468,14 @@ def _get_quant_mode(

     if model_name in ["DEEPSEEK_V3", "KIMI_K2"]:
         fmha_quant_mode = "float16"
+
     if (
         any(keyword in model_name for keyword in ["MOE_Mixtral", "QWEN2", "LLAMA"])
         and sm_version < 100
         and sm_version >= 89
     ):
-        gemm_quant_mode = "fp8"
-        moe_quant_mode = "fp8"
+        gemm_quant_mode = fp8_gemm_quant
+        moe_quant_mode = fp8_gemm_quant

     if use_specific_quant_mode is not None:
         if use_specific_quant_mode != "w4afp8":

@@ -730,8 +730,8 @@ def validate(self):
         """
         Check that the task can be run by AIC.
         """
-        if check_is_moe(self.model_name) and self.backend_name == "vllm":
-            raise NotImplementedError("AIConfigurator does not yet support MOE models for VLLM backend.")
+        if self.backend_name == "vllm" and get_model_family(self.model_name) == "DEEPSEEK":
+            raise NotImplementedError("AIConfigurator does not yet support DEEPSEEK models for VLLM backend.")

     def pretty(self) -> str:
         def _convert(obj: Any) -> Any:

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adb518cbf8558ed461469cde875d81a9e36702945d0046f0fd45aae3cc98b2d1
+size 1048377

tools/sanity_check/validate_database.ipynb

Lines changed: 7 additions & 4 deletions

@@ -14,8 +14,9 @@
     "from aiconfigurator.sdk import common\n",
     "from aiconfigurator.sdk.perf_database import get_database\n",
     "\n",
-    "system = \"gb200_sxm\"\n",
-    "database = get_database(system=system, backend=\"trtllm\", version=\"1.0.0rc6\")"
+    "system = \"h100_sxm\"\n",
+    "# database = get_database(system=system, backend=\"trtllm\", version=\"1.0.0rc3\")\n",
+    "database = get_database(system=system, backend=\"vllm\", version=\"0.11.0\")"
    ]
   },
   {

@@ -563,6 +564,8 @@
     "    tp_ep_list = []\n",
     "    for tp in tp_list:\n",
     "        for ep in ep_list:\n",
+    "            if database.backend == \"vllm\" and tp > 1 and ep > 1:\n",
+    "                continue\n",
     "            if tp * ep >= 4 and tp * ep <= 16:\n",
     "                tp_ep_list.append([tp, ep])\n",
     "    m_list = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 4]\n",
@@ -775,7 +778,7 @@
    ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3 (ipykernel)",
+    "display_name": "myenv",
     "language": "python",
     "name": "python3"
    },

@@ -789,7 +792,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.3"
+   "version": "3.10.17"
   }
  },
  "nbformat": 4,
