Make distributed modules importable even when backend not built (#159889)

This PR is greatly simplified now that it is stacked on top of a PR that always builds with distributed enabled. We only need to stub out functions that may not be defined because a particular backend was not enabled in the build.
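
In rough outline, the fallback pattern (a minimal sketch of what the new torch.distributed._distributed_c10d wrapper in this diff does; not the verbatim implementation) is:

    try:
        # Real C++ binding, present only when the backend was compiled in.
        from torch._C._distributed_c10d import ProcessGroupNCCL

        _NCCL_AVAILABLE = True
    except ImportError:
        _NCCL_AVAILABLE = False
        # Pure-Python stub keeps the module importable without the backend.
        from torch.distributed._C_stubs import ProcessGroupNCCL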

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449
Edward Z. Yang
2025-09-04 12:58:51 -04:00
committed by PyTorch MergeBot
parent 95ee0bfea9
commit ef3be6726f
21 changed files with 630 additions and 224 deletions

View File

@@ -13,6 +13,8 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
 fi
 popd
 
+python -mpip install -r requirements.txt
+
 # enable debug asserts in serialization
 export TORCH_SERIALIZATION_DEBUG=1

View File

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore


class TestFakeDTensor(TestCase):
    def test_fake_dtensor_operations(self):
        # Use FakeTensorMode to handle CUDA tensors without actual CUDA
        fake_mode = FakeTensorMode()
        world_size = 4

        fake_store = FakeStore()
        torch.distributed.init_process_group(
            "fake", store=fake_store, rank=0, world_size=world_size
        )
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            (2, world_size // 2),
        )

        # Create fake CUDA tensor using FakeTensorMode
        with fake_mode:
            x = torch.randn(1, 1, device="cuda")
            x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)])

            # Test basic DTensor operations
            self.assertIsInstance(x, DTensor)

            # Test sum operation
            r = x.sum(1)
            self.assertIsInstance(r, DTensor)


if __name__ == "__main__":
    run_tests()

View File

@@ -7,7 +7,7 @@ import sys
 from dataclasses import dataclass
 from multiprocessing.context import SpawnProcess
 from typing import Any, Optional
-from unittest import skipUnless
+from unittest import skipIf, skipUnless
 from unittest.mock import mock_open, patch
 
 import torch
@@ -22,7 +22,7 @@ from torch.numa.binding import (
     AffinityMode,
     NumaOptions,
 )
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
 
 
 @dataclass(frozen=True)
@@ -680,6 +680,7 @@ class NumaBindingTest(TestCase):
             set(range(0, 2)),
         )
 
+    @skipIf(IS_MACOS, "sched_getaffinity doesn't exist")
     def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
         self._add_mock_hardware(
             num_sockets=1,

View File

@@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend):
 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...
+def _dump_nccl_trace_json(
+    includeCollectives: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...
+def _dump_nccl_trace(
+    includeCollectives: Optional[bool] = ...,
+    includeStackTraces: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...

View File

@@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
"""
Python stubs for backend-specific distributed components.

Since _C._distributed_c10d always exists now, this module only provides
stubs for backend-specific functionality that may not be available in all builds
(e.g., NCCL, UCC, MPI, Gloo, etc.).
"""

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

from torch._C._distributed_c10d import Store


if TYPE_CHECKING:
    from datetime import timedelta

import torch


# Store classes
class HashStore(Store):
    """Stub HashStore for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        self._data = {}

    def set(self, key: str, value: str):
        self._data[key] = value

    def get(self, key: str) -> bytes:
        return self._data.get(key, "").encode()


# Backend-specific process group stubs
class ProcessGroupMPI:
    """Stub ProcessGroupMPI for non-MPI builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupNCCL:
    """Stub ProcessGroupNCCL for non-NCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupGloo:
    """Stub ProcessGroupGloo for non-Gloo builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupUCC:
    """Stub ProcessGroupUCC for non-UCC builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupXCCL:
    """Stub ProcessGroupXCCL for non-XCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class _ProcessGroupWrapper:
    """Stub _ProcessGroupWrapper for non-Gloo builds."""

    def __init__(self, process_group, *args, **kwargs):
        self._process_group = process_group

    def __getattr__(self, name):
        return getattr(self._process_group, name)


# NCCL-specific function stubs
_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None


def _hash_tensors(tensors):
    """Stub function to hash tensors - returns dummy hash."""
    return 0


def _dump_nccl_trace_json(
    includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None
) -> bytes:
    """Stub function that returns empty JSON trace."""
    return b"{}"


def _dump_nccl_trace(
    includeCollectives: Optional[bool] = None,
    includeStackTraces: Optional[bool] = None,
    onlyActive: Optional[bool] = None,
) -> bytes:
    """Stub function that returns empty pickle trace."""
    return b""


# NVSHMEM/SymmetricMemory stubs
def _is_nvshmem_available() -> bool:
    """Stub function that returns False indicating NVSHMEM is not available."""
    return False


def _nvshmemx_cumodule_init(module: int) -> None:
    """Stub function for NVSHMEM CU module initialization."""


class _SymmetricMemory:
    """Stub _SymmetricMemory class for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None):
        """Stub that returns a regular tensor."""
        return torch.empty(size, dtype=dtype, device=device)

    @classmethod
    def rendezvous(cls, tensor, group_name=None):
        """Stub that returns None."""
        return None

    @classmethod
    def set_group_info(cls, *args, **kwargs):
        """Stub that does nothing."""

    @classmethod
    def set_backend(cls, name):
        """Stub that does nothing."""

    @classmethod
    def get_backend(cls, device):
        """Stub that returns None."""
        return None

    @classmethod
    def has_multicast_support(cls, device_type, device_index):
        """Stub that returns False."""
        return False

View File

@@ -30,132 +30,124 @@ DistNetworkError = torch._C._DistNetworkError
 DistStoreError = torch._C._DistStoreError
 QueueEmptyError = torch._C._DistQueueEmptyError
 
-if is_available():
-    from torch._C._distributed_c10d import (
-        _broadcast_coalesced,
-        _compute_bucket_assignment_by_size,
-        _ControlCollectives,
-        _DEFAULT_FIRST_BUCKET_BYTES,
-        _make_nccl_premul_sum,
-        _register_builtin_comm_hook,
-        _register_comm_hook,
-        _StoreCollectives,
-        _test_python_store,
-        _verify_params_across_processes,
-        Backend as _Backend,
-        BuiltinCommHookType,
-        DebugLevel,
-        FileStore,
-        get_debug_level,
-        GradBucket,
-        Logger,
-        PrefixStore,
-        ProcessGroup as ProcessGroup,
-        Reducer,
-        set_debug_level,
-        set_debug_level_from_env,
-        Store,
-        TCPStore,
-        Work as _Work,
-    )
-
-    class _DistributedPdb(pdb.Pdb):
-        """
-        Supports using PDB from inside a multiprocessing child process.
-
-        Usage:
-        _DistributedPdb().set_trace()
-        """
-
-        def interaction(self, *args, **kwargs):
-            _stdin = sys.stdin
-            try:
-                sys.stdin = open("/dev/stdin")
-                pdb.Pdb.interaction(self, *args, **kwargs)
-            finally:
-                sys.stdin = _stdin
-
-    _breakpoint_cache: dict[int, typing.Any] = {}
-
-    def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
-        """
-        Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
-        done with the breakpoint before continuing.
-
-        Args:
-            rank (int): Which rank to break on. Default: ``0``
-            skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
-        """
-        if skip > 0:
-            key = hash(str(traceback.format_exc()))
-            counter = _breakpoint_cache.get(key, 0) + 1
-            _breakpoint_cache[key] = counter
-            if counter <= skip:
-                log.warning("Skip the breakpoint, counter=%d", counter)
-                return
-
-        # avoid having the default timeout (if short) interrupt your debug session
-        if timeout_s is not None:
-            for group in torch.distributed.distributed_c10d._pg_map:
-                torch.distributed.distributed_c10d._set_pg_timeout(
-                    timedelta(seconds=timeout_s), group
-                )
-
-        if get_rank() == rank:
-            pdb = _DistributedPdb()
-            pdb.message(
-                "\n!!! ATTENTION !!!\n\n"
-                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
-            )
-            pdb.set_trace()
-
-        # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
-        # and hit the (default) CPU/CUDA implementation of barrier.
-        meta_in_tls = torch._C._meta_in_tls_dispatch_include()
-        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-        torch._C._set_meta_in_tls_dispatch_include(False)
-        try:
-            barrier()
-        finally:
-            torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
-            del guard
-
-    if sys.platform != "win32":
-        from torch._C._distributed_c10d import HashStore
-
-    from .device_mesh import DeviceMesh, init_device_mesh
-
-    # Variables prefixed with underscore are not auto imported
-    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
-    # this.
-    from .distributed_c10d import *  # noqa: F403
-    from .distributed_c10d import (
-        _all_gather_base,
-        _coalescing_manager,
-        _CoalescingManager,
-        _create_process_group_wrapper,
-        _get_process_group_name,
-        _rank_not_in_group,
-        _reduce_scatter_base,
-        _time_estimator,
-        get_node_local_rank,
-    )
-    from .remote_device import _remote_device
-    from .rendezvous import (
-        _create_store_from_options,
-        register_rendezvous_handler,
-        rendezvous,
-    )
-
-    set_debug_level_from_env()
-
-else:
-    # This stub is sufficient to get
-    #   python test/test_public_bindings.py -k test_correct_module_names
-    # working even when USE_DISTRIBUTED=0. Feel free to add more
-    # stubs as necessary.
-    # We cannot define stubs directly because they confuse pyre
-
-    class _ProcessGroupStub:
-        pass
-
-    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
+from torch.distributed._distributed_c10d import (
+    _broadcast_coalesced,
+    _compute_bucket_assignment_by_size,
+    _ControlCollectives,
+    _DEFAULT_FIRST_BUCKET_BYTES,
+    _make_nccl_premul_sum,
+    _register_builtin_comm_hook,
+    _register_comm_hook,
+    _StoreCollectives,
+    _test_python_store,
+    _verify_params_across_processes,
+    Backend as _Backend,
+    BuiltinCommHookType,
+    DebugLevel,
+    FileStore,
+    get_debug_level,
+    GradBucket,
+    Logger,
+    PrefixStore,
+    ProcessGroup as ProcessGroup,
+    Reducer,
+    set_debug_level,
+    set_debug_level_from_env,
+    Store,
+    TCPStore,
+    Work as _Work,
+)
+
+
+class _DistributedPdb(pdb.Pdb):
+    """
+    Supports using PDB from inside a multiprocessing child process.
+
+    Usage:
+    _DistributedPdb().set_trace()
+    """
+
+    def interaction(self, *args, **kwargs):
+        _stdin = sys.stdin
+        try:
+            sys.stdin = open("/dev/stdin")
+            pdb.Pdb.interaction(self, *args, **kwargs)
+        finally:
+            sys.stdin = _stdin
+
+
+_breakpoint_cache: dict[int, typing.Any] = {}
+
+
+def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
+    """
+    Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
+    done with the breakpoint before continuing.
+
+    Args:
+        rank (int): Which rank to break on. Default: ``0``
+        skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
+    """
+    if skip > 0:
+        key = hash(str(traceback.format_exc()))
+        counter = _breakpoint_cache.get(key, 0) + 1
+        _breakpoint_cache[key] = counter
+        if counter <= skip:
+            log.warning("Skip the breakpoint, counter=%d", counter)
+            return
+
+    # avoid having the default timeout (if short) interrupt your debug session
+    if timeout_s is not None:
+        for group in torch.distributed.distributed_c10d._pg_map:
+            torch.distributed.distributed_c10d._set_pg_timeout(
+                timedelta(seconds=timeout_s), group
+            )
+
+    if get_rank() == rank:
+        pdb = _DistributedPdb()
+        pdb.message(
+            "\n!!! ATTENTION !!!\n\n"
+            f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
+        )
+        pdb.set_trace()
+
+    # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
+    # and hit the (default) CPU/CUDA implementation of barrier.
+    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
+    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
+    torch._C._set_meta_in_tls_dispatch_include(False)
+    try:
+        barrier()
+    finally:
+        torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
+        del guard
+
+
+if sys.platform != "win32":
+    from torch.distributed._distributed_c10d import HashStore
+
+from .device_mesh import DeviceMesh, init_device_mesh
+
+# Variables prefixed with underscore are not auto imported
+# See the comment in `distributed_c10d.py` above `_backend` on why we expose
+# this.
+from .distributed_c10d import *  # noqa: F403
+from .distributed_c10d import (
+    _all_gather_base,
+    _coalescing_manager,
+    _CoalescingManager,
+    _create_process_group_wrapper,
+    _get_process_group_name,
+    _rank_not_in_group,
+    _reduce_scatter_base,
+    _time_estimator,
+    get_node_local_rank,
+)
+from .remote_device import _remote_device
+from .rendezvous import (
+    _create_store_from_options,
+    register_rendezvous_handler,
+    rendezvous,
+)
+
+set_debug_level_from_env()

View File

@@ -10,7 +10,7 @@ from datetime import timedelta
 from typing import Protocol, Union
 
 import torch
-from torch._C._distributed_c10d import (
+from torch.distributed._distributed_c10d import (
     _current_process_group,
     _set_process_group,
     ProcessGroup,

View File

@@ -0,0 +1,238 @@
# mypy: disable-error-code="assignment"
# noqa: F401
"""
Centralized module for importing and re-exporting torch._C._distributed_c10d components.

IMPORTANT PATTERN:
Never access torch._C._distributed_c10d directly in code. Always import from and use
torch.distributed._distributed_c10d which is guaranteed to have all functions available.

Example:
    # WRONG: torch._C._distributed_c10d._set_global_rank(rank)
    # RIGHT:
    from torch.distributed._distributed_c10d import _set_global_rank

    _set_global_rank(rank)
"""

from typing import TYPE_CHECKING

# Import all core distributed components from the C extension
# NB: This list has to be spelled out because the _C module doesn't have __all__
from torch._C._distributed_c10d import (
    _allow_inflight_collective_as_graph_input,
    _broadcast_coalesced,
    _compute_bucket_assignment_by_size,
    _ControlCollectives,
    _current_process_group,
    _DEFAULT_FIRST_BUCKET_BYTES,
    _DEFAULT_PG_TIMEOUT,
    _DistributedBackendOptions,
    _make_nccl_premul_sum,
    _register_builtin_comm_hook,
    _register_comm_hook,
    _register_process_group,
    _register_work,
    _resolve_process_group,
    _set_allow_inflight_collective_as_graph_input,
    _set_global_rank,
    _set_process_group,
    _StoreCollectives,
    _test_python_store,
    _unregister_all_process_groups,
    _unregister_process_group,
    _verify_params_across_processes,
    _WorkerServer,
    AllgatherOptions,
    AllreduceCoalescedOptions,
    AllreduceOptions,
    AllToAllOptions,
    Backend,
    BarrierOptions,
    BroadcastOptions,
    BuiltinCommHookType,
    DebugLevel,
    FakeProcessGroup,
    FakeWork,
    FileStore,
    GatherOptions,
    get_debug_level,
    GradBucket,
    Logger,
    PrefixStore,
    ProcessGroup,
    ReduceOp,
    ReduceOptions,
    Reducer,
    ReduceScatterOptions,
    ScatterOptions,
    set_debug_level,
    set_debug_level_from_env,
    Store,
    TCPStore,
    Work,
)

# Backend-specific components that may not be available
_MPI_AVAILABLE = False
_NCCL_AVAILABLE = False
_GLOO_AVAILABLE = False
_UCC_AVAILABLE = False
_XCCL_AVAILABLE = False

# HashStore
try:
    from torch._C._distributed_c10d import HashStore
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import HashStore

# NVSHMEM/SymmetricMemory components
try:
    from torch._C._distributed_c10d import (
        _is_nvshmem_available,
        _nvshmemx_cumodule_init,
        _SymmetricMemory,
    )
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import (
            _is_nvshmem_available,
            _nvshmemx_cumodule_init,
            _SymmetricMemory,
        )

# MPI backend
try:
    from torch._C._distributed_c10d import ProcessGroupMPI

    _MPI_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupMPI

# NCCL backend
try:
    from torch._C._distributed_c10d import (
        _DEFAULT_PG_NCCL_TIMEOUT,
        _dump_nccl_trace,
        _dump_nccl_trace_json,
        _hash_tensors,
        ProcessGroupNCCL,
    )

    _NCCL_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import (
            _DEFAULT_PG_NCCL_TIMEOUT,
            _dump_nccl_trace,
            _dump_nccl_trace_json,
            _hash_tensors,
            ProcessGroupNCCL,
        )

# Gloo backend
try:
    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo

    _GLOO_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import _ProcessGroupWrapper, ProcessGroupGloo

# UCC backend
try:
    from torch._C._distributed_c10d import ProcessGroupUCC

    _UCC_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupUCC

# XCCL backend
try:
    from torch._C._distributed_c10d import ProcessGroupXCCL

    _XCCL_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupXCCL

# Provide backwards compatibility by making all symbols available at module level
__all__ = [
    # Basic components
    "_broadcast_coalesced",
    "_compute_bucket_assignment_by_size",
    "_ControlCollectives",
    "_DEFAULT_FIRST_BUCKET_BYTES",
    "_DEFAULT_PG_TIMEOUT",
    "_DEFAULT_PG_NCCL_TIMEOUT",
    "_make_nccl_premul_sum",
    "_register_builtin_comm_hook",
    "_register_comm_hook",
    "_StoreCollectives",
    "_test_python_store",
    "_verify_params_across_processes",
    "_allow_inflight_collective_as_graph_input",
    "_register_work",
    "_set_allow_inflight_collective_as_graph_input",
    "_is_nvshmem_available",
    "_nvshmemx_cumodule_init",
    "_SymmetricMemory",
    "_hash_tensors",
    "_set_global_rank",
    "_dump_nccl_trace",
    "_dump_nccl_trace_json",
    "Backend",
    "BuiltinCommHookType",
    "DebugLevel",
    "FakeProcessGroup",
    "FileStore",
    "get_debug_level",
    "GradBucket",
    "HashStore",
    "Logger",
    "PrefixStore",
    "ProcessGroup",
    "Reducer",
    "ReduceOp",
    "set_debug_level",
    "set_debug_level_from_env",
    "Store",
    "TCPStore",
    "Work",
    "FakeWork",
    # Additional distributed_c10d components
    "_DistributedBackendOptions",
    "_register_process_group",
    "_resolve_process_group",
    "_unregister_all_process_groups",
    "_unregister_process_group",
    "_current_process_group",
    "_set_process_group",
    "_WorkerServer",
    "AllgatherOptions",
    "AllreduceCoalescedOptions",
    "AllreduceOptions",
    "AllToAllOptions",
    "BarrierOptions",
    "BroadcastOptions",
    "GatherOptions",
    "ReduceOptions",
    "ReduceScatterOptions",
    "ScatterOptions",
    # Process group implementations
    "ProcessGroupMPI",
    "ProcessGroupNCCL",
    "ProcessGroupGloo",
    "ProcessGroupUCC",
    "ProcessGroupXCCL",
    "_ProcessGroupWrapper",
    # Availability flags
    "_MPI_AVAILABLE",
    "_NCCL_AVAILABLE",
    "_GLOO_AVAILABLE",
    "_UCC_AVAILABLE",
    "_XCCL_AVAILABLE",
]

View File

@@ -7,6 +7,10 @@ from typing import Any, cast, Optional, TYPE_CHECKING, Union
 import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as c10d
+from torch.distributed._distributed_c10d import (
+    _allow_inflight_collective_as_graph_input,
+    _set_allow_inflight_collective_as_graph_input,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.fx.experimental.proxy_tensor import get_proxy_mode
@@ -853,15 +857,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
     will be registered in the work registry, and the wait_tensor() in compiled region called on
     the output tensor of the collective will wait on the correct work object.
     """
-    previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
+    previous = _allow_inflight_collective_as_graph_input()
 
     try:
-        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
+        _set_allow_inflight_collective_as_graph_input(value)
         yield
     finally:
-        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
-            previous
-        )
+        _set_allow_inflight_collective_as_graph_input(previous)
 
 
 def _make_all_gather_out_tensor(input, group_size):

View File

@@ -4,7 +4,7 @@ import copy
 import torch
 import torch.distributed as dist
 import torch.distributed._shard.sharding_spec as shard_spec
-from torch._C._distributed_c10d import ProcessGroup
+from torch.distributed._distributed_c10d import ProcessGroup
 from torch.distributed._shard.metadata import ShardMetadata
 from torch.distributed._shard.sharding_spec._internals import (
     get_chunked_dim_size,

View File

@@ -4,7 +4,7 @@ from typing import cast
 import torch
 import torch.distributed as dist
-from torch._C._distributed_c10d import ReduceOp
+from torch.distributed._distributed_c10d import ReduceOp
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
 from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

View File

@@ -15,7 +15,12 @@ import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch._C._autograd import DeviceType
-from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
+from torch.distributed._distributed_c10d import (
+    _register_work,
+    _SymmetricMemory,
+    ProcessGroup,
+    Work as _Work,
+)
 
 
 _group_name_to_store: dict[str, c10d.Store] = {}
@@ -1488,7 +1493,7 @@ def _low_contention_all_gather(
         src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
         chunks[remote_rank].copy_(src_buf)
     symm_mem.barrier()
-    torch._C._distributed_c10d._register_work(output, Work())
+    _register_work(output, Work())
     return output
@@ -1536,7 +1541,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
         ret = ret.mean(dim=0)
     else:
         raise ValueError(f"reduce_op ({reduce_op}) is not supported")
-    torch._C._distributed_c10d._register_work(ret, Work())
+    _register_work(ret, Work())
     return ret
@@ -1571,7 +1576,7 @@ def _low_contention_reduce_scatter_with_workspace(
         ret = ret.mean(dim=0)
     else:
         raise ValueError(f"reduce_op ({reduce_op}) is not supported")
-    torch._C._distributed_c10d._register_work(ret, Work())
+    _register_work(ret, Work())
     return ret
@@ -1649,7 +1654,6 @@ from typing import overload, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
-    from torch._C._distributed_c10d import ProcessGroup
     from torch.types import _device, _dtype, _int
@@ -1727,8 +1731,6 @@ def rendezvous(
         group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
             participating processes. This can be either a group name or a process group object.
     """
-    from torch._C._distributed_c10d import ProcessGroup
-
     if isinstance(group, str):
         group_name = group
     elif isinstance(group, ProcessGroup):
@@ -1746,11 +1748,7 @@ def is_nvshmem_available() -> bool:
     Check if NVSHMEM is available in current build and on current system.
     """
-    try:
-        from torch._C._distributed_c10d import _is_nvshmem_available
-    except ImportError:
-        # Not all builds have NVSHMEM support.
-        return False
+    from torch.distributed._distributed_c10d import _is_nvshmem_available
 
     # Check if NVSHMEM is available on current system.
     return _is_nvshmem_available()

View File

@@ -75,7 +75,7 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]:
     """
     import triton
 
-    from torch._C._distributed_c10d import _nvshmemx_cumodule_init
+    from torch.distributed._distributed_c10d import _nvshmemx_cumodule_init
 
     if lib_dir is not None:
         lib_path = os.path.join(lib_dir, "libnvshmem_device.bc")

View File

@@ -2,7 +2,9 @@ import random
 from typing import Any
 
 import torch
-from torch._C._distributed_c10d import (
+
+# Import centralized distributed components
+from torch.distributed._distributed_c10d import (
     _resolve_process_group,
     FakeWork,
     ProcessGroup,

View File

@@ -1,7 +1,11 @@
 from datetime import timedelta
 from typing import Optional
 
-from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+# Import from centralized fallback module - no ImportError handling needed
+from torch.distributed._distributed_c10d import (
+    _DEFAULT_PG_NCCL_TIMEOUT,
+    _DEFAULT_PG_TIMEOUT,
+)
 
 
 __all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
@@ -16,11 +20,4 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
 # Later, we could consider merging them back together at the c++ layer if we can align on a same value.
 # (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1).
-try:
-    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
-
-    default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
-except ImportError:
-    # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
-    # if anyone is actually trying to use nccl in this state, it should error.
-    default_pg_nccl_timeout = None
+default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT

View File

@@ -11,35 +11,14 @@ from itertools import chain, zip_longest
 from typing import Optional, TYPE_CHECKING, Union
 
 import torch
-from torch.distributed import is_available
 from torch.utils._typing_utils import not_none
 
 
 __all__ = ["init_device_mesh", "DeviceMesh"]
 
-if not is_available():
-    import sys
-
-    # We need to create the stubs when distributed is not available.
-    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
-    # since it would try to import ``torch.distributed.device_mesh`` or
-    # ``torch.distributed.init_device_mesh`` but cannot find them.
-
-    class _DeviceMeshStub:
-        pass
-
-    def _init_device_mesh_stub():
-        pass
-
-    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
-    sys.modules[
-        "torch.distributed.device_mesh"
-    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
-
-else:
-    from torch._C._distributed_c10d import Backend as C10dBackend
+if True:  # just to temporarily avoid reindentation
+    from torch.distributed._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _resolve_process_group,
@@ -526,15 +505,16 @@ else:
             # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
             # NOTE: This device selection would only work for homogeneous hardware.
             num_devices_per_host = device_handle.device_count()
-            if (
-                world_size > num_devices_per_host
-                and world_size % num_devices_per_host != 0
-            ):
-                raise RuntimeError(
-                    f"DeviceMesh only support homogeneous hardware, but found "
-                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
-                )
-            device_handle.set_device(get_rank() % num_devices_per_host)
+            if num_devices_per_host:
+                if (
+                    world_size > num_devices_per_host
+                    and world_size % num_devices_per_host != 0
+                ):
+                    raise RuntimeError(
+                        f"DeviceMesh only support homogeneous hardware, but found "
+                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+                    )
+                device_handle.set_device(get_rank() % num_devices_per_host)
         return _get_default_group()

View File

@@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from typing_extensions import deprecated
 
 import torch
+import torch.distributed._distributed_c10d as _c10d
 from torch._C import _DistStoreError as DistStoreError
-from torch._C._distributed_c10d import (
+from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
+from torch.distributed._distributed_c10d import (  # Process group implementations; Availability flags
     _DistributedBackendOptions,
+    _GLOO_AVAILABLE,
+    _MPI_AVAILABLE,
+    _NCCL_AVAILABLE,
+    _ProcessGroupWrapper,
     _register_process_group,
     _resolve_process_group,
+    _UCC_AVAILABLE,
     _unregister_all_process_groups,
     _unregister_process_group,
+    _XCCL_AVAILABLE,
     AllgatherOptions,
     AllreduceCoalescedOptions,
     AllreduceOptions,
@@ -37,6 +45,11 @@ from torch._C._distributed_c10d import (
     get_debug_level,
     PrefixStore,
     ProcessGroup,
+    ProcessGroupGloo,
+    ProcessGroupMPI,
+    ProcessGroupNCCL,
+    ProcessGroupUCC,
+    ProcessGroupXCCL,
     ReduceOp,
     ReduceOptions,
     ReduceScatterOptions,
@@ -44,7 +57,6 @@ from torch._C._distributed_c10d import (
     Store,
     Work,
 )
-from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
 from torch.monitor import _WaitCounter
 from torch.overrides import handle_torch_function, has_torch_function
 from torch.utils._typing_utils import not_none
@@ -131,17 +143,11 @@ __all__ = [
     "split_group",
 ]
 
-_MPI_AVAILABLE = True
-_NCCL_AVAILABLE = True
-_GLOO_AVAILABLE = True
-_UCC_AVAILABLE = True
-_XCCL_AVAILABLE = True
-
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
 
-# Change __module__ of all imported types from torch._C._distributed_c10d that are public
+# Change __module__ of all imported types from the distributed wrapper that are public
 def _export_c_types() -> None:
     _public_types_to_change_module = [
@@ -167,45 +173,26 @@ def _export_c_types() -> None:
 
 _export_c_types()
 
-try:
-    from torch._C._distributed_c10d import ProcessGroupMPI
-
+# Add process groups to __all__ and set their module based on availability
+if _MPI_AVAILABLE:
     ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupMPI"]
-except ImportError:
-    _MPI_AVAILABLE = False
 
-try:
-    from torch._C._distributed_c10d import ProcessGroupNCCL
-
+if _NCCL_AVAILABLE:
     ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupNCCL"]
-except ImportError:
-    _NCCL_AVAILABLE = False
 
-try:
-    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
-
+if _GLOO_AVAILABLE:
    ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupGloo"]
-except ImportError:
-    _GLOO_AVAILABLE = False
 
-try:
-    from torch._C._distributed_c10d import ProcessGroupUCC
-
+if _UCC_AVAILABLE:
     ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupUCC"]
-except ImportError:
-    _UCC_AVAILABLE = False
 
-try:
-    from torch._C._distributed_c10d import ProcessGroupXCCL
-
+if _XCCL_AVAILABLE:
     ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupXCCL"]
-except ImportError:
-    _XCCL_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
@@ -1327,7 +1314,8 @@ def _get_default_store() -> Store:
 def _update_default_pg(pg) -> None:
     _world.default_pg = pg
     rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
-    torch._C._distributed_c10d._set_global_rank(rank)
+
+    _c10d._set_global_rank(rank)
 
 
 def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@@ -1964,7 +1952,7 @@ def _new_process_group_helper(
     if device_id:
         pg.bound_device_id = device_id
 
-    backend_class: torch._C._distributed_c10d.Backend
+    backend_class: _c10d.Backend
     for device, backend_str in backend_config.get_device_backend_map().items():
         # Use the group name as prefix in the default store, such that
         # a single store can be reused by multiple groups.
@@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group):
     if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
         backend = get_backend(group)
         if backend == Backend.NCCL:
-            hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
+            from torch.distributed._distributed_c10d import _hash_tensors
+
+            hash = _hash_tensors([byte_tensor])
             logger.warning(
                 "_object_to_tensor size: %s hash value: %s",
                 byte_tensor.numel(),
@@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group):
     if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
         backend = get_backend(group)
         if backend == Backend.NCCL:
-            hash = torch._C._distributed_c10d._hash_tensors([tensor])
+            from torch.distributed._distributed_c10d import _hash_tensors
+
+            hash = _hash_tensors([tensor])
             logger.warning(
                 "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
             )
@@ -4971,7 +4963,7 @@ def monitored_barrier(
 def _create_process_group_wrapper(
-    wrapped_pg: torch._C._distributed_c10d.Backend,
+    wrapped_pg: _c10d.Backend,
     store_prefix: str,
     store: Store,
     rank: int,

View File

@@ -14,7 +14,7 @@ TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
 
 @contextmanager
 def _worker_server(socket_path: str) -> Generator[None, None, None]:
-    from torch._C._distributed_c10d import _WorkerServer
+    from torch.distributed._distributed_c10d import _WorkerServer
 
     server = _WorkerServer(socket_path)
     try:

View File

@@ -37,7 +37,6 @@ if is_available():
     import numbers
 
     import torch.distributed.autograd as dist_autograd
-    from torch._C._distributed_c10d import Store
     from torch._C._distributed_rpc import (  # noqa: F401
         _cleanup_python_rpc_handler,
         _DEFAULT_INIT_METHOD,
@@ -70,6 +69,7 @@ if is_available():
         RpcBackendOptions,
         WorkerInfo,
     )
+    from torch.distributed._distributed_c10d import Store
 
     if _is_tensorpipe_available:
         from torch._C._distributed_rpc import (  # noqa: F401

View File

@@ -8,8 +8,10 @@ from typing import Optional
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
-from torch._C._distributed_c10d import _resolve_process_group
 from torch._logging import warning_once
+
+# Import from centralized fallback module - no conditional imports needed
+from torch.distributed._distributed_c10d import _resolve_process_group
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
 from torch.distributed.distributed_c10d import (
     _get_group_size_by_name,

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 
 import torch.distributed as dist
-from torch._C._distributed_c10d import FakeProcessGroup
+from torch.distributed._distributed_c10d import FakeProcessGroup
 
 
 class FakeStore(dist.Store):