Protect import for device_mesh (#3742)

This commit is contained in:
Marc Sun
2025-08-22 15:44:56 +02:00
committed by GitHub
parent 5fe4460ccd
commit 5dd3d0b690
2 changed files with 7 additions and 3 deletions

View File

@ -17,9 +17,8 @@ import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from torch.distributed.device_mesh import init_device_mesh
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
from accelerate.utils.versions import is_torch_version
if TYPE_CHECKING:
@ -191,6 +190,11 @@ class ParallelismConfig:
Args:
device_type (`str`): The type of device for which to build the mesh, e
"""
if is_torch_version(">=", "2.2.0"):
from torch.distributed.device_mesh import init_device_mesh
else:
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
mesh = self._get_mesh()
if len(mesh) == 0:
return None

View File

@ -76,7 +76,7 @@ class TestParallelismConfig:
return mesh
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
yield mock_init_mesh
@pytest.mark.parametrize(