Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make distributed modules importable even when backend not built (#159889)
This PR is greatly simplified now that it is stacked on top of a PR that always builds with distributed. We only need to stub functions that may not be defined because a particular backend was not enabled.

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449
Committed by: PyTorch MergeBot
Commit: ef3be6726f
Parent: 95ee0bfea9
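The mechanism this commit relies on, as seen in the new torch/distributed/_distributed_c10d.py and torch/distributed/_C_stubs.py below, is a try/except fallback import: backend-specific bindings are taken from torch._C._distributed_c10d when the backend was compiled in, and a pure-Python stub is substituted otherwise, so the module always imports. A minimal sketch of that pattern (names taken from the diff below; this is an illustration, not the full module):

# Sketch of the fallback-import pattern used by torch/distributed/_distributed_c10d.py.
_NCCL_AVAILABLE = False

try:
    # Present only when PyTorch was built with NCCL support.
    from torch._C._distributed_c10d import ProcessGroupNCCL

    _NCCL_AVAILABLE = True
except ImportError:
    # Fall back to a Python stub so torch.distributed stays importable.
    from torch.distributed._C_stubs import ProcessGroupNCCL

Downstream code then imports ProcessGroupNCCL (and checks _NCCL_AVAILABLE) from torch.distributed._distributed_c10d instead of reaching into torch._C directly.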
@@ -13,6 +13,8 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
 fi
 popd

+python -mpip install -r requirements.txt
+
 # enable debug asserts in serialization
 export TORCH_SERIALIZATION_DEBUG=1

test/distributed/tensor/test_fake.py  (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore


class TestFakeDTensor(TestCase):
    def test_fake_dtensor_operations(self):
        # Use FakeTensorMode to handle CUDA tensors without actual CUDA
        fake_mode = FakeTensorMode()
        world_size = 4

        fake_store = FakeStore()
        torch.distributed.init_process_group(
            "fake", store=fake_store, rank=0, world_size=world_size
        )
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            (2, world_size // 2),
        )

        # Create fake CUDA tensor using FakeTensorMode
        with fake_mode:
            x = torch.randn(1, 1, device="cuda")
            x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)])

            # Test basic DTensor operations
            self.assertIsInstance(x, DTensor)

            # Test sum operation
            r = x.sum(1)
            self.assertIsInstance(r, DTensor)


if __name__ == "__main__":
    run_tests()
@@ -7,7 +7,7 @@ import sys
 from dataclasses import dataclass
 from multiprocessing.context import SpawnProcess
 from typing import Any, Optional
-from unittest import skipUnless
+from unittest import skipIf, skipUnless
 from unittest.mock import mock_open, patch

 import torch
@@ -22,7 +22,7 @@ from torch.numa.binding import (
     AffinityMode,
     NumaOptions,
 )
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase


 @dataclass(frozen=True)
@@ -680,6 +680,7 @@ class NumaBindingTest(TestCase):
             set(range(0, 2)),
         )

+    @skipIf(IS_MACOS, "sched_getaffinity doesn't exist")
    def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
        self._add_mock_hardware(
            num_sockets=1,
@@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend):

 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...
+def _dump_nccl_trace_json(
+    includeCollectives: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...
+def _dump_nccl_trace(
+    includeCollectives: Optional[bool] = ...,
+    includeStackTraces: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...
torch/distributed/_C_stubs.py  (new file, 150 lines)
@@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
"""
Python stubs for backend-specific distributed components.

Since _C._distributed_c10d always exists now, this module only provides
stubs for backend-specific functionality that may not be available in all builds
(e.g., NCCL, UCC, MPI, Gloo, etc.).
"""

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

from torch._C._distributed_c10d import Store


if TYPE_CHECKING:
    from datetime import timedelta

    import torch


# Store classes
class HashStore(Store):
    """Stub HashStore for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        self._data = {}

    def set(self, key: str, value: str):
        self._data[key] = value

    def get(self, key: str) -> bytes:
        return self._data.get(key, "").encode()


# Backend-specific process group stubs
class ProcessGroupMPI:
    """Stub ProcessGroupMPI for non-MPI builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupNCCL:
    """Stub ProcessGroupNCCL for non-NCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupGloo:
    """Stub ProcessGroupGloo for non-Gloo builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupUCC:
    """Stub ProcessGroupUCC for non-UCC builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupXCCL:
    """Stub ProcessGroupXCCL for non-XCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class _ProcessGroupWrapper:
    """Stub _ProcessGroupWrapper for non-Gloo builds."""

    def __init__(self, process_group, *args, **kwargs):
        self._process_group = process_group

    def __getattr__(self, name):
        return getattr(self._process_group, name)


# NCCL-specific function stubs
_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None


def _hash_tensors(tensors):
    """Stub function to hash tensors - returns dummy hash."""
    return 0


def _dump_nccl_trace_json(
    includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None
) -> bytes:
    """Stub function that returns empty JSON trace."""
    return b"{}"


def _dump_nccl_trace(
    includeCollectives: Optional[bool] = None,
    includeStackTraces: Optional[bool] = None,
    onlyActive: Optional[bool] = None,
) -> bytes:
    """Stub function that returns empty pickle trace."""
    return b""


# NVSHMEM/SymmetricMemory stubs
def _is_nvshmem_available() -> bool:
    """Stub function that returns False indicating NVSHMEM is not available."""
    return False


def _nvshmemx_cumodule_init(module: int) -> None:
    """Stub function for NVSHMEM CU module initialization."""


class _SymmetricMemory:
    """Stub _SymmetricMemory class for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None):
        """Stub that returns a regular tensor."""
        return torch.empty(size, dtype=dtype, device=device)

    @classmethod
    def rendezvous(cls, tensor, group_name=None):
        """Stub that returns None."""
        return None

    @classmethod
    def set_group_info(cls, *args, **kwargs):
        """Stub that does nothing."""

    @classmethod
    def set_backend(cls, name):
        """Stub that does nothing."""

    @classmethod
    def get_backend(cls, device):
        """Stub that returns None."""
        return None

    @classmethod
    def has_multicast_support(cls, device_type, device_index):
        """Stub that returns False."""
        return False
@@ -30,132 +30,124 @@ DistNetworkError = torch._C._DistNetworkError
[The block below was previously nested under `if is_available():` and is now at module level; only lines whose text changed are marked with -/+.]
 DistStoreError = torch._C._DistStoreError
 QueueEmptyError = torch._C._DistQueueEmptyError

-if is_available():
-    from torch._C._distributed_c10d import (
+from torch.distributed._distributed_c10d import (
     _broadcast_coalesced,
     _compute_bucket_assignment_by_size,
     _ControlCollectives,
     _DEFAULT_FIRST_BUCKET_BYTES,
     _make_nccl_premul_sum,
     _register_builtin_comm_hook,
     _register_comm_hook,
     _StoreCollectives,
     _test_python_store,
     _verify_params_across_processes,
     Backend as _Backend,
     BuiltinCommHookType,
     DebugLevel,
     FileStore,
     get_debug_level,
     GradBucket,
     Logger,
     PrefixStore,
     ProcessGroup as ProcessGroup,
     Reducer,
     set_debug_level,
     set_debug_level_from_env,
     Store,
     TCPStore,
     Work as _Work,
 )


 class _DistributedPdb(pdb.Pdb):
     """
     Supports using PDB from inside a multiprocessing child process.

     Usage:
     _DistributedPdb().set_trace()
     """

     def interaction(self, *args, **kwargs):
         _stdin = sys.stdin
         try:
             sys.stdin = open("/dev/stdin")
             pdb.Pdb.interaction(self, *args, **kwargs)
         finally:
             sys.stdin = _stdin


 _breakpoint_cache: dict[int, typing.Any] = {}


 def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
     """
     Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
     done with the breakpoint before continuing.

     Args:
         rank (int): Which rank to break on. Default: ``0``
         skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
     """
     if skip > 0:
         key = hash(str(traceback.format_exc()))
         counter = _breakpoint_cache.get(key, 0) + 1
         _breakpoint_cache[key] = counter
         if counter <= skip:
             log.warning("Skip the breakpoint, counter=%d", counter)
             return

     # avoid having the default timeout (if short) interrupt your debug session
     if timeout_s is not None:
         for group in torch.distributed.distributed_c10d._pg_map:
             torch.distributed.distributed_c10d._set_pg_timeout(
                 timedelta(seconds=timeout_s), group
             )

     if get_rank() == rank:
         pdb = _DistributedPdb()
         pdb.message(
             "\n!!! ATTENTION !!!\n\n"
             f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
         )
         pdb.set_trace()
     # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
     # and hit the (default) CPU/CUDA implementation of barrier.
     meta_in_tls = torch._C._meta_in_tls_dispatch_include()
     guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
     torch._C._set_meta_in_tls_dispatch_include(False)
     try:
         barrier()
     finally:
         torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
         del guard


 if sys.platform != "win32":
-    from torch._C._distributed_c10d import HashStore
+    from torch.distributed._distributed_c10d import HashStore

 from .device_mesh import DeviceMesh, init_device_mesh

 # Variables prefixed with underscore are not auto imported
 # See the comment in `distributed_c10d.py` above `_backend` on why we expose
 # this.
 from .distributed_c10d import *  # noqa: F403
 from .distributed_c10d import (
     _all_gather_base,
     _coalescing_manager,
     _CoalescingManager,
     _create_process_group_wrapper,
     _get_process_group_name,
     _rank_not_in_group,
     _reduce_scatter_base,
     _time_estimator,
     get_node_local_rank,
 )
 from .remote_device import _remote_device
 from .rendezvous import (
     _create_store_from_options,
     register_rendezvous_handler,
     rendezvous,
 )

 set_debug_level_from_env()
-
-else:
-    # This stub is sufficient to get
-    #   python test/test_public_bindings.py -k test_correct_module_names
-    # working even when USE_DISTRIBUTED=0. Feel free to add more
-    # stubs as necessary.
-    # We cannot define stubs directly because they confuse pyre
-
-    class _ProcessGroupStub:
-        pass
-
-    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
@@ -10,7 +10,7 @@ from datetime import timedelta
 from typing import Protocol, Union

 import torch
-from torch._C._distributed_c10d import (
+from torch.distributed._distributed_c10d import (
     _current_process_group,
     _set_process_group,
     ProcessGroup,
torch/distributed/_distributed_c10d.py  (new file, 238 lines)
@@ -0,0 +1,238 @@
# mypy: disable-error-code="assignment"
# noqa: F401
"""
Centralized module for importing and re-exporting torch._C._distributed_c10d components.

IMPORTANT PATTERN:
Never access torch._C._distributed_c10d directly in code. Always import from and use
torch.distributed._distributed_c10d which is guaranteed to have all functions available.

Example:
    # WRONG: torch._C._distributed_c10d._set_global_rank(rank)
    # RIGHT:
    from torch.distributed._distributed_c10d import _set_global_rank
    _set_global_rank(rank)
"""

from typing import TYPE_CHECKING

# Import all core distributed components from the C extension
# NB: This list has to be spelled out because the _C module doesn't have __all__
from torch._C._distributed_c10d import (
    _allow_inflight_collective_as_graph_input,
    _broadcast_coalesced,
    _compute_bucket_assignment_by_size,
    _ControlCollectives,
    _current_process_group,
    _DEFAULT_FIRST_BUCKET_BYTES,
    _DEFAULT_PG_TIMEOUT,
    _DistributedBackendOptions,
    _make_nccl_premul_sum,
    _register_builtin_comm_hook,
    _register_comm_hook,
    _register_process_group,
    _register_work,
    _resolve_process_group,
    _set_allow_inflight_collective_as_graph_input,
    _set_global_rank,
    _set_process_group,
    _StoreCollectives,
    _test_python_store,
    _unregister_all_process_groups,
    _unregister_process_group,
    _verify_params_across_processes,
    _WorkerServer,
    AllgatherOptions,
    AllreduceCoalescedOptions,
    AllreduceOptions,
    AllToAllOptions,
    Backend,
    BarrierOptions,
    BroadcastOptions,
    BuiltinCommHookType,
    DebugLevel,
    FakeProcessGroup,
    FakeWork,
    FileStore,
    GatherOptions,
    get_debug_level,
    GradBucket,
    Logger,
    PrefixStore,
    ProcessGroup,
    ReduceOp,
    ReduceOptions,
    Reducer,
    ReduceScatterOptions,
    ScatterOptions,
    set_debug_level,
    set_debug_level_from_env,
    Store,
    TCPStore,
    Work,
)


# Backend-specific components that may not be available
_MPI_AVAILABLE = False
_NCCL_AVAILABLE = False
_GLOO_AVAILABLE = False
_UCC_AVAILABLE = False
_XCCL_AVAILABLE = False

# HashStore
try:
    from torch._C._distributed_c10d import HashStore
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import HashStore

# NVSHMEM/SymmetricMemory components
try:
    from torch._C._distributed_c10d import (
        _is_nvshmem_available,
        _nvshmemx_cumodule_init,
        _SymmetricMemory,
    )
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import (
            _is_nvshmem_available,
            _nvshmemx_cumodule_init,
            _SymmetricMemory,
        )

# MPI backend
try:
    from torch._C._distributed_c10d import ProcessGroupMPI

    _MPI_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupMPI

# NCCL backend
try:
    from torch._C._distributed_c10d import (
        _DEFAULT_PG_NCCL_TIMEOUT,
        _dump_nccl_trace,
        _dump_nccl_trace_json,
        _hash_tensors,
        ProcessGroupNCCL,
    )

    _NCCL_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import (
            _DEFAULT_PG_NCCL_TIMEOUT,
            _dump_nccl_trace,
            _dump_nccl_trace_json,
            _hash_tensors,
            ProcessGroupNCCL,
        )

# Gloo backend
try:
    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo

    _GLOO_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import _ProcessGroupWrapper, ProcessGroupGloo

# UCC backend
try:
    from torch._C._distributed_c10d import ProcessGroupUCC

    _UCC_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupUCC

# XCCL backend
try:
    from torch._C._distributed_c10d import ProcessGroupXCCL

    _XCCL_AVAILABLE = True
except ImportError:
    if not TYPE_CHECKING:
        from torch.distributed._C_stubs import ProcessGroupXCCL

# Provide backwards compatibility by making all symbols available at module level
__all__ = [
    # Basic components
    "_broadcast_coalesced",
    "_compute_bucket_assignment_by_size",
    "_ControlCollectives",
    "_DEFAULT_FIRST_BUCKET_BYTES",
    "_DEFAULT_PG_TIMEOUT",
    "_DEFAULT_PG_NCCL_TIMEOUT",
    "_make_nccl_premul_sum",
    "_register_builtin_comm_hook",
    "_register_comm_hook",
    "_StoreCollectives",
    "_test_python_store",
    "_verify_params_across_processes",
    "_allow_inflight_collective_as_graph_input",
    "_register_work",
    "_set_allow_inflight_collective_as_graph_input",
    "_is_nvshmem_available",
    "_nvshmemx_cumodule_init",
    "_SymmetricMemory",
    "_hash_tensors",
    "_set_global_rank",
    "_dump_nccl_trace",
    "_dump_nccl_trace_json",
    "Backend",
    "BuiltinCommHookType",
    "DebugLevel",
    "FakeProcessGroup",
    "FileStore",
    "get_debug_level",
    "GradBucket",
    "HashStore",
    "Logger",
    "PrefixStore",
    "ProcessGroup",
    "Reducer",
    "ReduceOp",
    "set_debug_level",
    "set_debug_level_from_env",
    "Store",
    "TCPStore",
    "Work",
    "FakeWork",
    # Additional distributed_c10d components
    "_DistributedBackendOptions",
    "_register_process_group",
    "_resolve_process_group",
    "_unregister_all_process_groups",
    "_unregister_process_group",
    "_current_process_group",
    "_set_process_group",
    "_WorkerServer",
    "AllgatherOptions",
    "AllreduceCoalescedOptions",
    "AllreduceOptions",
    "AllToAllOptions",
    "BarrierOptions",
    "BroadcastOptions",
    "GatherOptions",
    "ReduceOptions",
    "ReduceScatterOptions",
    "ScatterOptions",
    # Process group implementations
    "ProcessGroupMPI",
    "ProcessGroupNCCL",
    "ProcessGroupGloo",
    "ProcessGroupUCC",
    "ProcessGroupXCCL",
    "_ProcessGroupWrapper",
    # Availability flags
    "_MPI_AVAILABLE",
    "_NCCL_AVAILABLE",
    "_GLOO_AVAILABLE",
    "_UCC_AVAILABLE",
    "_XCCL_AVAILABLE",
]
@@ -7,6 +7,10 @@ from typing import Any, cast, Optional, TYPE_CHECKING, Union
 import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as c10d
+from torch.distributed._distributed_c10d import (
+    _allow_inflight_collective_as_graph_input,
+    _set_allow_inflight_collective_as_graph_input,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.fx.experimental.proxy_tensor import get_proxy_mode

@@ -853,15 +857,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
     will be registered in the work registry, and the wait_tensor() in compiled region called on
     the output tensor of the collective will wait on the correct work object.
     """
-    previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
+    previous = _allow_inflight_collective_as_graph_input()

     try:
-        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
+        _set_allow_inflight_collective_as_graph_input(value)
         yield
     finally:
-        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
-            previous
-        )
+        _set_allow_inflight_collective_as_graph_input(previous)


 def _make_all_gather_out_tensor(input, group_size):
@@ -4,7 +4,7 @@ import copy
 import torch
 import torch.distributed as dist
 import torch.distributed._shard.sharding_spec as shard_spec
-from torch._C._distributed_c10d import ProcessGroup
+from torch.distributed._distributed_c10d import ProcessGroup
 from torch.distributed._shard.metadata import ShardMetadata
 from torch.distributed._shard.sharding_spec._internals import (
     get_chunked_dim_size,
@@ -4,7 +4,7 @@ from typing import cast

 import torch
 import torch.distributed as dist
-from torch._C._distributed_c10d import ReduceOp
+from torch.distributed._distributed_c10d import ReduceOp
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
 from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
@@ -15,7 +15,12 @@ import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch._C._autograd import DeviceType
-from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
+from torch.distributed._distributed_c10d import (
+    _register_work,
+    _SymmetricMemory,
+    ProcessGroup,
+    Work as _Work,
+)


 _group_name_to_store: dict[str, c10d.Store] = {}
@@ -1488,7 +1493,7 @@ def _low_contention_all_gather(
         src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
         chunks[remote_rank].copy_(src_buf)
     symm_mem.barrier()
-    torch._C._distributed_c10d._register_work(output, Work())
+    _register_work(output, Work())
     return output


@@ -1536,7 +1541,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
         ret = ret.mean(dim=0)
     else:
         raise ValueError(f"reduce_op ({reduce_op}) is not supported")
-    torch._C._distributed_c10d._register_work(ret, Work())
+    _register_work(ret, Work())
     return ret


@@ -1571,7 +1576,7 @@ def _low_contention_reduce_scatter_with_workspace(
         ret = ret.mean(dim=0)
     else:
         raise ValueError(f"reduce_op ({reduce_op}) is not supported")
-    torch._C._distributed_c10d._register_work(ret, Work())
+    _register_work(ret, Work())
     return ret


@@ -1649,7 +1654,6 @@ from typing import overload, TYPE_CHECKING, Union


 if TYPE_CHECKING:
-    from torch._C._distributed_c10d import ProcessGroup
     from torch.types import _device, _dtype, _int


@@ -1727,8 +1731,6 @@ def rendezvous(
         group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
             participating processes. This can be either a group name or a process group object.
     """
-    from torch._C._distributed_c10d import ProcessGroup
-
     if isinstance(group, str):
         group_name = group
     elif isinstance(group, ProcessGroup):
@@ -1746,11 +1748,7 @@ def is_nvshmem_available() -> bool:

     Check if NVSHMEM is available in current build and on current system.
     """
-    try:
-        from torch._C._distributed_c10d import _is_nvshmem_available
-    except ImportError:
-        # Not all builds have NVSHMEM support.
-        return False
+    from torch.distributed._distributed_c10d import _is_nvshmem_available

     # Check if NVSHMEM is available on current system.
     return _is_nvshmem_available()
@@ -75,7 +75,7 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]:
     """
     import triton

-    from torch._C._distributed_c10d import _nvshmemx_cumodule_init
+    from torch.distributed._distributed_c10d import _nvshmemx_cumodule_init

     if lib_dir is not None:
         lib_path = os.path.join(lib_dir, "libnvshmem_device.bc")
@@ -2,7 +2,9 @@ import random
 from typing import Any

 import torch
-from torch._C._distributed_c10d import (
+
+# Import centralized distributed components
+from torch.distributed._distributed_c10d import (
     _resolve_process_group,
     FakeWork,
     ProcessGroup,
@@ -1,7 +1,11 @@
 from datetime import timedelta
 from typing import Optional

-from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+# Import from centralized fallback module - no ImportError handling needed
+from torch.distributed._distributed_c10d import (
+    _DEFAULT_PG_NCCL_TIMEOUT,
+    _DEFAULT_PG_TIMEOUT,
+)


 __all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
@@ -16,11 +20,4 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
 # Later, we could consider merging them back together at the c++ layer if we can align on a same value.
 # (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1).

-try:
-    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
-
-    default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
-except ImportError:
-    # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
-    # if anyone is actually trying to use nccl in this state, it should error.
-    default_pg_nccl_timeout = None
+default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
@@ -11,35 +11,14 @@ from itertools import chain, zip_longest
 from typing import Optional, TYPE_CHECKING, Union

 import torch
-from torch.distributed import is_available
 from torch.utils._typing_utils import not_none


 __all__ = ["init_device_mesh", "DeviceMesh"]


-if not is_available():
-    import sys
-
-    # We need to create the stubs when distributed is not available.
-    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
-    # since it would try to import ``torch.distributed.device_mesh`` or
-    # ``torch.distributed.init_device_mesh`` but cannot find them.
-
-    class _DeviceMeshStub:
-        pass
-
-    def _init_device_mesh_stub():
-        pass
-
-    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
-    sys.modules[
-        "torch.distributed.device_mesh"
-    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
-
-
-else:
-    from torch._C._distributed_c10d import Backend as C10dBackend
+if True:  # just to temporarily avoid reindentation
+    from torch.distributed._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _resolve_process_group,
@@ -526,15 +505,16 @@ else:
             # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
             # NOTE: This device selection would only work for homogeneous hardware.
             num_devices_per_host = device_handle.device_count()
-            if (
-                world_size > num_devices_per_host
-                and world_size % num_devices_per_host != 0
-            ):
-                raise RuntimeError(
-                    f"DeviceMesh only support homogeneous hardware, but found "
-                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
-                )
-            device_handle.set_device(get_rank() % num_devices_per_host)
+            if num_devices_per_host:
+                if (
+                    world_size > num_devices_per_host
+                    and world_size % num_devices_per_host != 0
+                ):
+                    raise RuntimeError(
+                        f"DeviceMesh only support homogeneous hardware, but found "
+                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+                    )
+                device_handle.set_device(get_rank() % num_devices_per_host)

             return _get_default_group()

@@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from typing_extensions import deprecated

 import torch
+import torch.distributed._distributed_c10d as _c10d
 from torch._C import _DistStoreError as DistStoreError
-from torch._C._distributed_c10d import (
+from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
+from torch.distributed._distributed_c10d import (  # Process group implementations; Availability flags
     _DistributedBackendOptions,
+    _GLOO_AVAILABLE,
+    _MPI_AVAILABLE,
+    _NCCL_AVAILABLE,
+    _ProcessGroupWrapper,
     _register_process_group,
     _resolve_process_group,
+    _UCC_AVAILABLE,
     _unregister_all_process_groups,
     _unregister_process_group,
+    _XCCL_AVAILABLE,
     AllgatherOptions,
     AllreduceCoalescedOptions,
     AllreduceOptions,
@@ -37,6 +45,11 @@ from torch._C._distributed_c10d import (
     get_debug_level,
     PrefixStore,
     ProcessGroup,
+    ProcessGroupGloo,
+    ProcessGroupMPI,
+    ProcessGroupNCCL,
+    ProcessGroupUCC,
+    ProcessGroupXCCL,
     ReduceOp,
     ReduceOptions,
     ReduceScatterOptions,
@@ -44,7 +57,6 @@ from torch._C._distributed_c10d import (
     Store,
     Work,
 )
-from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
 from torch.monitor import _WaitCounter
 from torch.overrides import handle_torch_function, has_torch_function
 from torch.utils._typing_utils import not_none
@@ -131,17 +143,11 @@ __all__ = [
     "split_group",
 ]

-_MPI_AVAILABLE = True
-_NCCL_AVAILABLE = True
-_GLOO_AVAILABLE = True
-_UCC_AVAILABLE = True
-_XCCL_AVAILABLE = True
-
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler


-# Change __module__ of all imported types from torch._C._distributed_c10d that are public
+# Change __module__ of all imported types from the distributed wrapper that are public
 def _export_c_types() -> None:
     _public_types_to_change_module = [
         AllreduceCoalescedOptions,
@@ -167,45 +173,26 @@ def _export_c_types() -> None:

 _export_c_types()

-try:
-    from torch._C._distributed_c10d import ProcessGroupMPI
-
+# Add process groups to __all__ and set their module based on availability
+if _MPI_AVAILABLE:
     ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupMPI"]
-except ImportError:
-    _MPI_AVAILABLE = False

-try:
-    from torch._C._distributed_c10d import ProcessGroupNCCL
-
+if _NCCL_AVAILABLE:
     ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupNCCL"]
-except ImportError:
-    _NCCL_AVAILABLE = False

-try:
-    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
-
+if _GLOO_AVAILABLE:
     ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupGloo"]
-except ImportError:
-    _GLOO_AVAILABLE = False

-try:
-    from torch._C._distributed_c10d import ProcessGroupUCC
-
+if _UCC_AVAILABLE:
     ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupUCC"]
-except ImportError:
-    _UCC_AVAILABLE = False

-try:
-    from torch._C._distributed_c10d import ProcessGroupXCCL
-
+if _XCCL_AVAILABLE:
     ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
     __all__ += ["ProcessGroupXCCL"]
-except ImportError:
-    _XCCL_AVAILABLE = False

 logger = logging.getLogger(__name__)

@@ -1327,7 +1314,8 @@ def _get_default_store() -> Store:
 def _update_default_pg(pg) -> None:
     _world.default_pg = pg
     rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
-    torch._C._distributed_c10d._set_global_rank(rank)
+
+    _c10d._set_global_rank(rank)


 def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@@ -1964,7 +1952,7 @@ def _new_process_group_helper(

     if device_id:
         pg.bound_device_id = device_id
-    backend_class: torch._C._distributed_c10d.Backend
+    backend_class: _c10d.Backend
     for device, backend_str in backend_config.get_device_backend_map().items():
         # Use the group name as prefix in the default store, such that
         # a single store can be reused by multiple groups.
@@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group):
     if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
         backend = get_backend(group)
         if backend == Backend.NCCL:
-            hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
+            from torch.distributed._distributed_c10d import _hash_tensors
+
+            hash = _hash_tensors([byte_tensor])
             logger.warning(
                 "_object_to_tensor size: %s hash value: %s",
                 byte_tensor.numel(),
@@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group):
     if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
         backend = get_backend(group)
         if backend == Backend.NCCL:
-            hash = torch._C._distributed_c10d._hash_tensors([tensor])
+            from torch.distributed._distributed_c10d import _hash_tensors
+
+            hash = _hash_tensors([tensor])
             logger.warning(
                 "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
             )
@@ -4971,7 +4963,7 @@ def monitored_barrier(


 def _create_process_group_wrapper(
-    wrapped_pg: torch._C._distributed_c10d.Backend,
+    wrapped_pg: _c10d.Backend,
     store_prefix: str,
     store: Store,
     rank: int,
@@ -14,7 +14,7 @@ TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"

 @contextmanager
 def _worker_server(socket_path: str) -> Generator[None, None, None]:
-    from torch._C._distributed_c10d import _WorkerServer
+    from torch.distributed._distributed_c10d import _WorkerServer

     server = _WorkerServer(socket_path)
     try:
@@ -37,7 +37,6 @@ if is_available():
     import numbers

     import torch.distributed.autograd as dist_autograd
-    from torch._C._distributed_c10d import Store
     from torch._C._distributed_rpc import (  # noqa: F401
         _cleanup_python_rpc_handler,
         _DEFAULT_INIT_METHOD,
@@ -70,6 +69,7 @@ if is_available():
         RpcBackendOptions,
         WorkerInfo,
     )
+    from torch.distributed._distributed_c10d import Store

     if _is_tensorpipe_available:
         from torch._C._distributed_rpc import (  # noqa: F401
@@ -8,8 +8,10 @@ from typing import Optional
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
-from torch._C._distributed_c10d import _resolve_process_group
 from torch._logging import warning_once

+# Import from centralized fallback module - no conditional imports needed
+from torch.distributed._distributed_c10d import _resolve_process_group
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
 from torch.distributed.distributed_c10d import (
     _get_group_size_by_name,
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs

 import torch.distributed as dist
-from torch._C._distributed_c10d import FakeProcessGroup
+from torch.distributed._distributed_c10d import FakeProcessGroup


 class FakeStore(dist.Store):