Make distributed modules importable even when backend not built (#159889)

This PR is greatly simplified now that it is stacked on top of a PR that always builds with distributed enabled. We only need to stub functions that may not be defined because a particular backend is not enabled.

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159889
Approved by: https://github.com/wconstab
ghstack dependencies: #160449
Edward Z. Yang
2025-09-02 23:34:49 -04:00
committed by PyTorch MergeBot
parent 90b08643c3
commit 4ae57d448c
21 changed files with 619 additions and 224 deletions
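The stubbing approach the commit message describes works roughly as follows: since the distributed extension now always builds, only backend-specific symbols can be missing, so the `torch.distributed._distributed_c10d` wrapper re-exports what exists and substitutes stubs plus availability flags for what does not. A minimal, illustrative sketch of that pattern (not the actual contents of the wrapper module):

```python
# Illustrative sketch only; the real wrapper is torch/distributed/_distributed_c10d.py
# and its exact contents differ. Idea: importing this module never fails, and callers
# branch on availability flags instead of catching ImportError themselves.
try:
    from torch._C._distributed_c10d import ProcessGroupNCCL

    _NCCL_AVAILABLE = True
except ImportError:
    _NCCL_AVAILABLE = False

    class ProcessGroupNCCL:  # stub: importable everywhere, unusable without NCCL
        def __init__(self, *args, **kwargs) -> None:
            raise RuntimeError("ProcessGroupNCCL requires a PyTorch build with NCCL enabled")
```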


@@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
import torch.distributed._distributed_c10d as _c10d
from torch._C import _DistStoreError as DistStoreError
from torch._C._distributed_c10d import (
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags
_DistributedBackendOptions,
_GLOO_AVAILABLE,
_MPI_AVAILABLE,
_NCCL_AVAILABLE,
_ProcessGroupWrapper,
_register_process_group,
_resolve_process_group,
_UCC_AVAILABLE,
_unregister_all_process_groups,
_unregister_process_group,
_XCCL_AVAILABLE,
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
@@ -37,6 +45,11 @@ from torch._C._distributed_c10d import (
get_debug_level,
PrefixStore,
ProcessGroup,
ProcessGroupGloo,
ProcessGroupMPI,
ProcessGroupNCCL,
ProcessGroupUCC,
ProcessGroupXCCL,
ReduceOp,
ReduceOptions,
ReduceScatterOptions,
@@ -44,7 +57,6 @@ from torch._C._distributed_c10d import (
Store,
Work,
)
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.monitor import _WaitCounter
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._typing_utils import not_none
@@ -131,17 +143,11 @@ __all__ = [
"split_group",
]
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
# Change __module__ of all imported types from torch._C._distributed_c10d that are public
# Change __module__ of all imported types from the distributed wrapper that are public
def _export_c_types() -> None:
_public_types_to_change_module = [
AllreduceCoalescedOptions,
@@ -167,45 +173,26 @@ def _export_c_types() -> None:
_export_c_types()
try:
from torch._C._distributed_c10d import ProcessGroupMPI
# Add process groups to __all__ and set their module based on availability
if _MPI_AVAILABLE:
ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupMPI"]
except ImportError:
_MPI_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
if _NCCL_AVAILABLE:
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL"]
except ImportError:
_NCCL_AVAILABLE = False
try:
from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
if _GLOO_AVAILABLE:
ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupGloo"]
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
if _UCC_AVAILABLE:
ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupUCC"]
except ImportError:
_UCC_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupXCCL
if _XCCL_AVAILABLE:
ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupXCCL"]
except ImportError:
_XCCL_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -1325,7 +1312,8 @@ def _get_default_store() -> Store:
def _update_default_pg(pg) -> None:
_world.default_pg = pg
rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
torch._C._distributed_c10d._set_global_rank(rank)
_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@@ -1957,7 +1945,7 @@ def _new_process_group_helper(
if device_id:
pg.bound_device_id = device_id
backend_class: torch._C._distributed_c10d.Backend
backend_class: _c10d.Backend
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
@@ -3072,7 +3060,9 @@ def _object_to_tensor(obj, device, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([byte_tensor])
logger.warning(
"_object_to_tensor size: %s hash value: %s",
byte_tensor.numel(),
@@ -3087,7 +3077,9 @@ def _tensor_to_object(tensor, tensor_size, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([tensor])
logger.warning(
"_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
)
@@ -4964,7 +4956,7 @@ def monitored_barrier(
def _create_process_group_wrapper(
wrapped_pg: torch._C._distributed_c10d.Backend,
wrapped_pg: _c10d.Backend,
store_prefix: str,
store: Store,
rank: int,
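Downstream, the user-visible effect is that importing the c10d layer no longer depends on which backends were compiled in; availability is queried at runtime. An illustrative usage sketch, not part of this diff:

```python
# Illustrative usage; not part of the diff above.
import torch.distributed as dist
import torch.distributed.distributed_c10d  # noqa: F401  # imports cleanly even without NCCL built

# Choose a backend by querying availability rather than catching ImportError.
backend = "nccl" if dist.is_nccl_available() else "gloo"
print("using backend:", backend)
```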