@@ -215,23 +215,63 @@ def save_error_report(errors, filename):
215215 with open (filename , "w" ) as f :
216216 json .dump (errors , f , indent = 2 )
217217
218-
def get_sm_version():
    """Return the CUDA compute capability (SM version) as ``major * 10 + minor``.

    Two probes are tried in order:

    1. PyTorch: the capability of ``torch.cuda.current_device()`` when
       ``torch.cuda.is_available()`` is true.
    2. cuda-python driver API: the capability of device ordinal 0.

    Returns:
        int: e.g. ``80`` for compute capability 8.0, ``90`` for 9.0.

    Raises:
        RuntimeError: if neither probe can determine the capability. The
            message carries the cuda-python failure and the reason the
            PyTorch probe was skipped or failed.
    """
    # Keep the reason the PyTorch probe did not produce a value so it can be
    # reported if the fallback fails as well (it used to be silently dropped).
    torch_failure = None
    try:
        import torch

        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability(
                torch.cuda.current_device()
            )
            return major * 10 + minor
        torch_failure = "torch.cuda.is_available() returned False"
    except Exception as exc:  # best-effort probe: any failure means "fall back"
        torch_failure = exc

    # Fallback to cuda-python.
    try:
        from cuda import cuda

        def _unwrap(call_name, result):
            # Driver calls return a tuple ``(status, *values)``; raise on any
            # non-success status instead of returning a meaningless value.
            status, *values = result
            if status != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(
                    f"{call_name} failed with error code: {status}"
                )
            return values[0] if values else None

        _unwrap("cuInit", cuda.cuInit(0))
        cu_device = _unwrap("cuDeviceGet", cuda.cuDeviceGet(0))

        # Query the target architecture; both attribute reads are checked.
        attrs = cuda.CUdevice_attribute
        sm_major = _unwrap(
            "cuDeviceGetAttribute",
            cuda.cuDeviceGetAttribute(
                attrs.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
            ),
        )
        sm_minor = _unwrap(
            "cuDeviceGetAttribute",
            cuda.cuDeviceGetAttribute(
                attrs.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
            ),
        )

        return sm_major * 10 + sm_minor
    except Exception as e:
        raise RuntimeError(
            "Cannot get SM version: both PyTorch and cuda-python failed. "
            f"Error: {e} (PyTorch: {torch_failure})"
        ) from e
258+
259+ # def get_sm_version():
260+ # # Init
261+ # (err,) = cuda.cuInit(0)
262+
263+ # # Device
264+ # err, cu_device = cuda.cuDeviceGet(0)
265+
266+ # # Get target architecture
267+ # err, sm_major = cuda.cuDeviceGetAttribute(
268+ # cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cu_device
269+ # )
270+ # err, sm_minor = cuda.cuDeviceGetAttribute(
271+ # cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cu_device
272+ # )
273+
274+ # return sm_major * 10 + sm_minor
235275
236276
237277def create_test_case_id (test_case , test_type , module_name ):
0 commit comments