Skip to content

Commit d73e921

Browse files
committed
Protect import for device_mesh (#3742)
1 parent f4593e3 commit d73e921

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/accelerate/parallelism_config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from dataclasses import dataclass
1818
from typing import TYPE_CHECKING, Optional, Union
1919

20-
from torch.distributed.device_mesh import init_device_mesh
21-
2220
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
21+
from accelerate.utils.versions import is_torch_version
2322

2423

2524
if TYPE_CHECKING:
@@ -191,6 +190,11 @@ def build_device_mesh(self, device_type: str):
191190
Args:
192191
device_type (`str`): The type of device for which to build the mesh, e
193192
"""
193+
if is_torch_version(">=", "2.2.0"):
194+
from torch.distributed.device_mesh import init_device_mesh
195+
else:
196+
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
197+
194198
mesh = self._get_mesh()
195199
if len(mesh) == 0:
196200
return None

tests/test_dataclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def mock_flatten(name):
7676

7777
return mesh
7878

79-
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
79+
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
8080
yield mock_init_mesh
8181

8282
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)