@@ -63,18 +63,22 @@ def cuda_malloc_supported():
6363 return True
6464
6565
# Detect the installed torch version WITHOUT importing torch itself.
# Importing torch at startup is slow and may initialize CUDA; instead we
# locate torch's on-disk package and execute only its version.py module.
version = ""

try:
    torch_spec = importlib.util.find_spec("torch")
    for folder in torch_spec.submodule_search_locations:
        ver_file = os.path.join(folder, "version.py")
        if os.path.isfile(ver_file):
            spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # torch/version.py defines __version__ (e.g. "2.1.0+cu121")
            version = module.__version__
except Exception:
    # Best-effort: torch may be absent or laid out unexpectedly; leave
    # version as "". Was a bare `except:` — narrowed so Ctrl-C / SystemExit
    # are not swallowed during startup.
    pass
# Enable cuda_malloc by default when the user did not request it explicitly,
# but only for torch >= 2 built with CUDA ("+cu" in the version string).
# NOTE(review): the `try:` opened below is closed by an `except` clause that
# lies outside this diff hunk — do not treat this span as a complete block.
6680if not args .cuda_malloc :
6781 try :
# The following removed ("-") lines are the old inline version detection,
# now hoisted to module level above so `version` is computed once.
68- version = ""
69- torch_spec = importlib .util .find_spec ("torch" )
70- for folder in torch_spec .submodule_search_locations :
71- ver_file = os .path .join (folder , "version.py" )
72- if os .path .isfile (ver_file ):
73- spec = importlib .util .spec_from_file_location ("torch_version_import" , ver_file )
74- module = importlib .util .module_from_spec (spec )
75- spec .loader .exec_module (module )
76- version = module .__version__
77-
# `version` here is the module-level string detected at import time; it is ""
# when detection failed, in which case int(version[0]) raises and the hidden
# except presumably leaves args.cuda_malloc unset — TODO confirm against the
# full file.
7882 if int (version [0 ]) >= 2 and "+cu" in version : # enable by default for torch version 2.0 and up only on cuda torch
7983 if PerformanceFeature .AutoTune not in args .fast : # Autotune has issues with cuda malloc
8084 args .cuda_malloc = cuda_malloc_supported ()
@@ -90,3 +94,6 @@ def cuda_malloc_supported():
# Append the cudaMallocAsync backend to the allocator config and export it.
# NOTE(review): `env_var` is built (and the surrounding condition opened)
# above this hunk — not visible here.
9094 env_var += ",backend:cudaMallocAsync"
9195
# PyTorch reads PYTORCH_CUDA_ALLOC_CONF at CUDA-allocator initialization, so
# this must run before torch is imported anywhere.
9296 os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = env_var
97+
def get_torch_version_noimport():
    """Return the torch version string detected at startup.

    The value comes from the module-level ``version`` captured by reading
    torch's ``version.py`` directly, so calling this never imports torch.
    Returns ``""`` (as a str) when detection failed.
    """
    detected = version
    return str(detected)
0 commit comments