Commit d2ea116

Revert "feat: Add MoE support for VLLM (#118)"
This reverts commit b17ba13.
1 parent: b17ba13

13 files changed: +33 −524 lines

collector/.gitignore

Lines changed: 0 additions & 3 deletions
This file was deleted.

collector/collect.py

Lines changed: 2 additions & 15 deletions

@@ -128,12 +128,6 @@ def worker(queue, device_id: int, func, progress_value, lock, error_queue=None,
         for handler in worker_logger.handlers:
             handler.flush()
 
-        # This error could be fatal and require a process restart.
-        if isinstance(e, torch.AcceleratorError):
-            # Exiting with non-zero code will add an additional error to the summary,
-            # which we don't want.
-            exit(0)
-
 
 def parallel_run(tasks, func, num_processes, module_name="unknown"):
     """parallel runner with error collection"""
@@ -433,7 +427,7 @@ def collect_sglang(num_processes: int, ops: list[str] | None = None):
 
 def collect_vllm(num_processes: int, ops: list[str] | None = None):
     """
-    Collect performance data for VLLM
+    Collect performance data for VLLM v1.
     """
 
     try:
@@ -447,7 +441,7 @@ def collect_vllm(num_processes: int, ops: list[str] | None = None):
 
     collections = [
         # GEMM collections
-        # vllm GEMM collection for fp16, fp8, fp8_block, nvfp4, awq, and gptq
+        # vllm v1 GEMM collection for fp16, fp8, fp8_block, nvfp4, awq, and gptq
         {
             "name": "vllm",
             "type": "gemm",
@@ -470,13 +464,6 @@
             "get_func": "get_generation_attention_test_cases",
             "run_func": "run_attention_torch",
         },
-        {
-            "name": "vllm",
-            "type": "moe",
-            "module": "collector.vllm.collect_moe",
-            "get_func": "get_moe_test_cases",
-            "run_func": "run_moe_torch",
-        },
     ]
 
     all_errors = collect_ops(num_processes, collections, ops, version)
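
Each entry in `collections` is a plain dict naming a module plus a pair of functions that generate and run test cases; the removed MoE entry had the same shape as the surviving GEMM and attention ones. A minimal sketch of how such an entry could be resolved at run time (`collect_ops` itself is not shown in this diff, so the dispatcher below is an assumption, not its actual code):

import importlib

def resolve_collection(entry: dict):
    # Import the collector module named by the entry and look up its
    # case-generator and runner functions by name.
    module = importlib.import_module(entry["module"])
    return getattr(module, entry["get_func"]), getattr(module, entry["run_func"])

# For the removed MoE entry this would have resolved to
# collector.vllm.collect_moe.get_moe_test_cases / run_moe_torch.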

collector/common_test_cases.py

Lines changed: 0 additions & 132 deletions
This file was deleted.

collector/helper.py

Lines changed: 17 additions & 33 deletions

@@ -10,6 +10,11 @@
 import signal
 import sys
 import traceback
+
+try:
+    from cuda import cuda
+except:
+    from cuda.bindings import driver as cuda
 from datetime import datetime
 from pathlib import Path
 
@@ -213,42 +218,21 @@ def save_error_report(errors, filename):
 
 
 def get_sm_version():
-    """Get CUDA compute capability (SM version)"""
-    try:
-        import torch
+    # Init
+    (err,) = cuda.cuInit(0)
 
-        if torch.cuda.is_available():
-            device = torch.cuda.current_device()
-            capability = torch.cuda.get_device_capability(device)
-            return capability[0] * 10 + capability[1]
-    except Exception:
-        pass
-
-    # fallback to cuda-python
-    try:
-        from cuda import cuda
+    # Device
+    err, cu_device = cuda.cuDeviceGet(0)
 
-        # Init
-        (err,) = cuda.cuInit(0)
-        if err != 0:
-            raise RuntimeError(f"cuInit failed with error code: {err}")
-
-        # Device
-        err, cu_device = cuda.cuDeviceGet(0)
-        if err != 0:
-            raise RuntimeError(f"cuDeviceGet failed with error code: {err}")
-
-        # Get target architecture
-        err, sm_major = cuda.cuDeviceGetAttribute(
-            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
-        )
-        err, sm_minor = cuda.cuDeviceGetAttribute(
-            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
-        )
+    # Get target architecture
+    err, sm_major = cuda.cuDeviceGetAttribute(
+        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
+    )
+    err, sm_minor = cuda.cuDeviceGetAttribute(
+        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
+    )
 
-        return sm_major * 10 + sm_minor
-    except Exception as e:
-        raise RuntimeError(f"Cannot get SM version: both PyTorch and cuda-python failed. Error: {e}") from e
+    return sm_major * 10 + sm_minor
 
 
 def create_test_case_id(test_case, test_type, module_name):
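
The restored `get_sm_version` encodes the compute capability as major * 10 + minor, so capability 9.0 reports as 90 and 8.0 as 80, and it relies on the module-level cuda import fallback added at the top of the file. A self-contained sketch of the same computation with the per-call error checks that the pre-revert version carried; the checks are re-added here for illustration only, since the restored helper ignores `err`:

try:
    from cuda import cuda  # older cuda-python layout
except ImportError:
    from cuda.bindings import driver as cuda  # newer cuda-python layout


def get_sm_version_checked() -> int:
    # Initialize the driver and take device 0, surfacing any failure.
    (err,) = cuda.cuInit(0)
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuInit failed: {err}")
    err, dev = cuda.cuDeviceGet(0)
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuDeviceGet failed: {err}")
    # Read the compute capability and pack it as major*10 + minor.
    err, major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev
    )
    err, minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev
    )
    return major * 10 + minor  # e.g. 90 on SM90 (Hopper)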
