Revert "Make distributed modules importable even when backend not built (#159889)"

This reverts commit 626cb7df8161dd4ecb4fe43b60f37ce9076f56b1.

Reverted https://github.com/pytorch/pytorch/pull/159889 on behalf of https://github.com/jeanschmidt due to breaking internal builds; it can't be landed with a forward fix because of internal tooling problems ([comment](https://github.com/pytorch/pytorch/pull/159889#issuecomment-3246677982))
PyTorch MergeBot
2025-09-02 20:24:01 +00:00
parent 82f63c8f6d
commit 420c52ecf3
22 changed files with 221 additions and 639 deletions

torch/distributed/device_mesh.py

@@ -11,14 +11,35 @@ from itertools import chain, zip_longest
 from typing import Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.distributed import is_available
 from torch.utils._typing_utils import not_none
 
 __all__ = ["init_device_mesh", "DeviceMesh"]
 
-if True:  # just to temporarily avoid reindentation
-    from torch.distributed._distributed_c10d import Backend as C10dBackend
+if not is_available():
+    import sys
+
+    # We need to create the stubs when distributed is not available.
+    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
+    # since it would try to import ``torch.distributed.device_mesh`` or
+    # ``torch.distributed.init_device_mesh`` but cannot find them.
+
+    class _DeviceMeshStub:
+        pass
+
+    def _init_device_mesh_stub():
+        pass
+
+    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
+    sys.modules[
+        "torch.distributed.device_mesh"
+    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
+
+else:
+    from torch._C._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _resolve_process_group,
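
For context on the pattern this hunk restores: assigning placeholder attributes onto the module's entry in sys.modules lets `from torch.distributed.device_mesh import DeviceMesh` still resolve when the distributed backend is unavailable, which keeps the doc tests importable. A minimal self-contained sketch of the same technique, using a hypothetical `mypkg.feature` module in place of the PyTorch names:

import sys
import types

# Register a hypothetical package and submodule whose real implementation
# is unavailable ("mypkg.feature" stands in for torch.distributed.device_mesh).
sys.modules["mypkg"] = types.ModuleType("mypkg")
feature = types.ModuleType("mypkg.feature")
sys.modules["mypkg.feature"] = feature

class _FeatureStub:  # placeholder class, mirrors _DeviceMeshStub
    pass

def _init_feature_stub():  # placeholder function, mirrors _init_device_mesh_stub
    pass

# Attach the stubs to the registered module object so name lookups succeed.
feature.Feature = _FeatureStub
feature.init_feature = _init_feature_stub

# The import now resolves to the stub instead of raising ImportError.
from mypkg.feature import Feature
assert Feature is _FeatureStub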
@@ -505,16 +526,15 @@ if True:  # just to temporarily avoid reindentation
             # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
             # NOTE: This device selection would only work for homogeneous hardware.
             num_devices_per_host = device_handle.device_count()
-            if num_devices_per_host:
-                if (
-                    world_size > num_devices_per_host
-                    and world_size % num_devices_per_host != 0
-                ):
-                    raise RuntimeError(
-                        f"DeviceMesh only support homogeneous hardware, but found "
-                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
-                    )
-                device_handle.set_device(get_rank() % num_devices_per_host)
+            if (
+                world_size > num_devices_per_host
+                and world_size % num_devices_per_host != 0
+            ):
+                raise RuntimeError(
+                    f"DeviceMesh only support homogeneous hardware, but found "
+                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+                )
+            device_handle.set_device(get_rank() % num_devices_per_host)
         return _get_default_group()
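
The hunk above restores the pre-revert device-selection heuristic: each rank binds to a local device via rank modulo devices-per-host, which is only well defined when every host has the same device count. A standalone sketch of that logic, with a hypothetical select_local_device helper and plain integers in place of the real get_rank(), get_world_size(), and device_handle.device_count() calls:

def select_local_device(rank: int, world_size: int, num_devices_per_host: int) -> int:
    # Uneven ratios (e.g. 6 ranks over 4 devices) cannot map evenly, so reject them.
    if world_size > num_devices_per_host and world_size % num_devices_per_host != 0:
        raise RuntimeError(
            f"DeviceMesh only supports homogeneous hardware, but found "
            f"{world_size} ranks and {num_devices_per_host} devices!"
        )
    # Each rank binds to the device matching its position within its host.
    return rank % num_devices_per_host

# 8 ranks across hosts with 4 devices each: ranks cycle through devices 0-3.
assert [select_local_device(r, 8, 4) for r in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]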