mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Protect import for device_mesh (#3742)
This commit is contained in:
@ -17,9 +17,8 @@ import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -191,6 +190,11 @@ class ParallelismConfig:
|
||||
Args:
|
||||
device_type (`str`): The type of device for which to build the mesh, e
|
||||
"""
|
||||
if is_torch_version(">=", "2.2.0"):
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
else:
|
||||
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
|
||||
|
||||
mesh = self._get_mesh()
|
||||
if len(mesh) == 0:
|
||||
return None
|
||||
|
@ -76,7 +76,7 @@ class TestParallelismConfig:
|
||||
|
||||
return mesh
|
||||
|
||||
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
|
||||
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
|
||||
yield mock_init_mesh
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Reference in New Issue
Block a user