Skip to content

Commit d7a0aef

Browse files
Set OCL_SET_SVM_SIZE on AMD. (#11139)
1 parent 913f86b commit d7a0aef

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

cuda_malloc.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,22 @@ def cuda_malloc_supported():
6363
return True
6464

6565

66+
version = ""
67+
68+
try:
69+
torch_spec = importlib.util.find_spec("torch")
70+
for folder in torch_spec.submodule_search_locations:
71+
ver_file = os.path.join(folder, "version.py")
72+
if os.path.isfile(ver_file):
73+
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
74+
module = importlib.util.module_from_spec(spec)
75+
spec.loader.exec_module(module)
76+
version = module.__version__
77+
except:
78+
pass
79+
6680
if not args.cuda_malloc:
6781
try:
68-
version = ""
69-
torch_spec = importlib.util.find_spec("torch")
70-
for folder in torch_spec.submodule_search_locations:
71-
ver_file = os.path.join(folder, "version.py")
72-
if os.path.isfile(ver_file):
73-
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
74-
module = importlib.util.module_from_spec(spec)
75-
spec.loader.exec_module(module)
76-
version = module.__version__
77-
7882
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
7983
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
8084
args.cuda_malloc = cuda_malloc_supported()
@@ -90,3 +94,6 @@ def cuda_malloc_supported():
9094
env_var += ",backend:cudaMallocAsync"
9195

9296
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
97+
98+
def get_torch_version_noimport():
99+
return str(version)

main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ def execute_script(script_path):
167167
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
168168

169169
import cuda_malloc
170+
if "rocm" in cuda_malloc.get_torch_version_noimport():
171+
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
172+
170173

171174
if 'torch' in sys.modules:
172175
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")

0 commit comments

Comments
 (0)