
Commit b7db778

Add MoE support for vllm
1 parent e4de7f0 commit b7db778

9 files changed, +196 -156 lines changed

collector/collect.py

Lines changed: 8 additions & 1 deletion
@@ -128,6 +128,12 @@ def worker(queue, device_id: int, func, progress_value, lock, error_queue=None,
         for handler in worker_logger.handlers:
             handler.flush()
 
+        # This error could be fatal and require a process restart.
+        if isinstance(e, torch.AcceleratorError):
+            # Exiting with a non-zero code would add an additional error to
+            # the summary, which we don't want.
+            exit(0)
+
 
 def parallel_run(tasks, func, num_processes, module_name="unknown"):
     """parallel runner with error collection"""
@@ -329,6 +335,7 @@ def collect_ops(
                     "traceback": traceback.format_exc(),
                 }
             )
+            return all_errors
 
     return all_errors
 
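The extra return makes collect_ops surface its error list as soon as the failure branch runs, instead of only through the final return. Illustrative shape of that control flow (run_op and the op names are stand-ins, not from this commit):

import traceback


def run_op(op):
    # Stand-in for the real per-op collection; fails on the marker op.
    if op == "bad":
        raise RuntimeError("boom")


def collect(ops):
    all_errors = []
    for op in ops:
        try:
            run_op(op)
        except Exception:
            all_errors.append({"op": op, "traceback": traceback.format_exc()})
            return all_errors  # early exit, as in the hunk above
    return all_errors


print(len(collect(["ok", "bad", "never-reached"])))  # 1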
@@ -427,7 +434,7 @@ def collect_sglang(num_processes: int, ops: list[str] | None = None):
 
 def collect_vllm(num_processes: int, ops: list[str] | None = None):
     """
-    Collect performance data for VLLM
+    Collect performance data for VLLM
     """
 
     try:
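
Call-side view of the collector, for orientation. A hypothetical driver invocation, assuming collect_vllm hands back the error list that collect_ops builds; the op name "fused_moe" is illustrative and not taken from this commit:

from collector.collect import collect_vllm

# Restrict collection to one (assumed) MoE op across 8 worker processes.
errors = collect_vllm(num_processes=8, ops=["fused_moe"])
print(f"collection finished with {len(errors)} recorded errors")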

collector/helper.py

Lines changed: 10 additions & 33 deletions
@@ -10,10 +10,6 @@
 import sys
 import traceback
 
-try:
-    from cuda import cuda
-except:
-    from cuda.bindings import driver as cuda
 from datetime import datetime
 from pathlib import Path
 
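Removing the module-level try/except import means helper.py no longer needs cuda-python just to be imported; get_sm_version (next hunk) imports it lazily inside the fallback branch instead. A minimal sketch of that pattern (hypothetical helper name):

def _cuda_driver():
    # Imported only when the fallback actually runs, so machines without
    # cuda-python can still import this module.
    try:
        from cuda import cuda
    except ImportError as e:
        raise RuntimeError("cuda-python is required for the driver fallback") from e
    return cuda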
@@ -215,63 +211,44 @@ def save_error_report(errors, filename):
     with open(filename, "w") as f:
         json.dump(errors, f, indent=2)
 
+
 def get_sm_version():
     """Get CUDA compute capability (SM version)"""
     try:
         import torch
+
         if torch.cuda.is_available():
             device = torch.cuda.current_device()
             capability = torch.cuda.get_device_capability(device)
             return capability[0] * 10 + capability[1]
     except Exception:
         pass
-
+
     # fallback to cuda-python
     try:
         from cuda import cuda
+
         # Init
         (err,) = cuda.cuInit(0)
         if err != 0:
             raise RuntimeError(f"cuInit failed with error code: {err}")
-
+
         # Device
         err, cu_device = cuda.cuDeviceGet(0)
         if err != 0:
             raise RuntimeError(f"cuDeviceGet failed with error code: {err}")
-
+
         # Get target architecture
         err, sm_major = cuda.cuDeviceGetAttribute(
-            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
-            cu_device
+            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
         )
         err, sm_minor = cuda.cuDeviceGetAttribute(
-            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
-            cu_device
+            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
         )
-
+
         return sm_major * 10 + sm_minor
     except Exception as e:
-        raise RuntimeError(
-            f"Cannot get SM version: both PyTorch and cuda-python failed. "
-            f"Error: {e}"
-        ) from e
-
-# def get_sm_version():
-#     # Init
-#     (err,) = cuda.cuInit(0)
-
-#     # Device
-#     err, cu_device = cuda.cuDeviceGet(0)
-
-#     # Get target architecture
-#     err, sm_major = cuda.cuDeviceGetAttribute(
-#         cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
-#     )
-#     err, sm_minor = cuda.cuDeviceGetAttribute(
-#         cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
-#     )
-
-#     return sm_major * 10 + sm_minor
+        raise RuntimeError(f"Cannot get SM version: both PyTorch and cuda-python failed. Error: {e}") from e
 
 
 def create_test_case_id(test_case, test_type, module_name):
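
The consolidated get_sm_version keeps its contract: PyTorch first, cuda-python second, and an integer like 80 or 90 either way. A hypothetical call site that gates collection on the result (thresholds are illustrative; the import path mirrors the file location):

from collector.helper import get_sm_version

sm = get_sm_version()  # e.g. 80 on A100, 90 on H100
if sm >= 90:
    print("Hopper or newer: collect the FP8 paths too")
else:
    print(f"SM {sm}: FP16/BF16 kernels only")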
