Revert "Make distributed modules importable even when backend not built (#159889)"

This reverts commit 4ae57d448c0a7d37e4cfd5c27d977fad2cef4051.

Reverted https://github.com/pytorch/pytorch/pull/159889 on behalf of https://github.com/jeanschmidt due to failing internal tests (probably type checks); see D81588399 ([comment](https://github.com/pytorch/pytorch/pull/159889#issuecomment-3253651785))
This commit is contained in:
PyTorch MergeBot
2025-09-04 13:13:52 +00:00
parent 040d00af04
commit 34aa78274d
21 changed files with 221 additions and 616 deletions

View File

@ -19,21 +19,13 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
import torch.distributed._distributed_c10d as _c10d
from torch._C import _DistStoreError as DistStoreError
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags
from torch._C._distributed_c10d import (
_DistributedBackendOptions,
_GLOO_AVAILABLE,
_MPI_AVAILABLE,
_NCCL_AVAILABLE,
_ProcessGroupWrapper,
_register_process_group,
_resolve_process_group,
_UCC_AVAILABLE,
_unregister_all_process_groups,
_unregister_process_group,
_XCCL_AVAILABLE,
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
@ -45,11 +37,6 @@ from torch.distributed._distributed_c10d import ( # Process group implementatio
get_debug_level,
PrefixStore,
ProcessGroup,
ProcessGroupGloo,
ProcessGroupMPI,
ProcessGroupNCCL,
ProcessGroupUCC,
ProcessGroupXCCL,
ReduceOp,
ReduceOptions,
ReduceScatterOptions,
@ -57,6 +44,7 @@ from torch.distributed._distributed_c10d import ( # Process group implementatio
Store,
Work,
)
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.monitor import _WaitCounter
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._typing_utils import not_none
@ -143,11 +131,17 @@ __all__ = [
"split_group",
]
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
# Change __module__ of all imported types from the distributed wrapper that are public
# Change __module__ of all imported types from torch._C._distributed_c10d that are public
def _export_c_types() -> None:
_public_types_to_change_module = [
AllreduceCoalescedOptions,
@ -173,26 +167,45 @@ def _export_c_types() -> None:
_export_c_types()
# Add process groups to __all__ and set their module based on availability
if _MPI_AVAILABLE:
try:
from torch._C._distributed_c10d import ProcessGroupMPI
ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupMPI"]
except ImportError:
_MPI_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
if _NCCL_AVAILABLE:
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL"]
except ImportError:
_NCCL_AVAILABLE = False
try:
from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
if _GLOO_AVAILABLE:
ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupGloo"]
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
if _UCC_AVAILABLE:
ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupUCC"]
except ImportError:
_UCC_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupXCCL
if _XCCL_AVAILABLE:
ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupXCCL"]
except ImportError:
_XCCL_AVAILABLE = False
logger = logging.getLogger(__name__)
@ -1314,8 +1327,7 @@ def _get_default_store() -> Store:
def _update_default_pg(pg) -> None:
_world.default_pg = pg
rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
_c10d._set_global_rank(rank)
torch._C._distributed_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@ -1947,7 +1959,7 @@ def _new_process_group_helper(
if device_id:
pg.bound_device_id = device_id
backend_class: _c10d.Backend
backend_class: torch._C._distributed_c10d.Backend
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
@ -3062,9 +3074,7 @@ def _object_to_tensor(obj, device, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([byte_tensor])
hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
logger.warning(
"_object_to_tensor size: %s hash value: %s",
byte_tensor.numel(),
@ -3079,9 +3089,7 @@ def _tensor_to_object(tensor, tensor_size, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([tensor])
hash = torch._C._distributed_c10d._hash_tensors([tensor])
logger.warning(
"_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
)
@ -4958,7 +4966,7 @@ def monitored_barrier(
def _create_process_group_wrapper(
wrapped_pg: _c10d.Backend,
wrapped_pg: torch._C._distributed_c10d.Backend,
store_prefix: str,
store: Store,
rank: int,