Skip to content

Commit e4de7f0

Browse files
davilu-nvidia and ilyasher
authored and committed
update vllm moe collect and helper for sm ver detection
1 parent 5765ea2 commit e4de7f0

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

collector/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
moe_perf.txt
2+
*.log
3+
moe_*/

collector/collect.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def collect_sglang(num_processes: int, ops: list[str] | None = None):
427427

428428
def collect_vllm(num_processes: int, ops: list[str] | None = None):
429429
"""
430-
Collect performance data for VLLM v1.
430+
Collect performance data for VLLM
431431
"""
432432

433433
try:
@@ -441,7 +441,7 @@ def collect_vllm(num_processes: int, ops: list[str] | None = None):
441441

442442
collections = [
443443
# GEMM collections
444-
# vllm v1 GEMM collection for fp16, fp8, fp8_block, nvfp4, awq, and gptq
444+
# vllm GEMM collection for fp16, fp8, fp8_block, nvfp4, awq, and gptq
445445
{
446446
"name": "vllm",
447447
"type": "gemm",
@@ -464,6 +464,13 @@ def collect_vllm(num_processes: int, ops: list[str] | None = None):
464464
"get_func": "get_generation_attention_test_cases",
465465
"run_func": "run_attention_torch",
466466
},
467+
{
468+
"name": "vllm",
469+
"type": "moe",
470+
"module": "collector.vllm.collect_moe",
471+
"get_func": "get_moe_test_cases",
472+
"run_func": "run_moe_torch",
473+
},
467474
]
468475

469476
all_errors = collect_ops(num_processes, collections, ops, version)

collector/helper.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,23 +215,63 @@ def save_error_report(errors, filename):
215215
with open(filename, "w") as f:
216216
json.dump(errors, f, indent=2)
217217

218-
219218
def get_sm_version():
    """Return the CUDA compute capability (SM version) of device 0 as an int.

    The value is ``major * 10 + minor`` (e.g. 90 for an SM 9.0 device).
    Tries PyTorch first (usually already initialized in callers), then falls
    back to querying the driver directly via cuda-python.

    Returns:
        int: compute capability encoded as major*10 + minor.

    Raises:
        RuntimeError: if neither PyTorch nor cuda-python can query the device.
    """
    # Preferred path: torch already exposes the device capability cheaply.
    try:
        import torch

        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            major, minor = torch.cuda.get_device_capability(device)
            return major * 10 + minor
    except Exception:
        # torch missing or CUDA unusable through it -- fall through to the
        # driver-level fallback below.
        pass

    # Fallback: query the CUDA driver API via cuda-python.
    try:
        from cuda import cuda

        def _check(err, what):
            # cuda-python calls return a CUresult; CUDA_SUCCESS == 0.
            if err != 0:
                raise RuntimeError(f"{what} failed with error code: {err}")

        (err,) = cuda.cuInit(0)
        _check(err, "cuInit")

        err, cu_device = cuda.cuDeviceGet(0)
        _check(err, "cuDeviceGet")

        err, sm_major = cuda.cuDeviceGetAttribute(
            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            cu_device,
        )
        _check(err, "cuDeviceGetAttribute(COMPUTE_CAPABILITY_MAJOR)")

        err, sm_minor = cuda.cuDeviceGetAttribute(
            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            cu_device,
        )
        _check(err, "cuDeviceGetAttribute(COMPUTE_CAPABILITY_MINOR)")

        return sm_major * 10 + sm_minor
    except Exception as e:
        raise RuntimeError(
            f"Cannot get SM version: both PyTorch and cuda-python failed. "
            f"Error: {e}"
        ) from e
235275

236276

237277
def create_test_case_id(test_case, test_type, module_name):

0 commit comments

Comments
 (0)