[RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594)

Summary:
Original: D81957844 and D81957923
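
The net effect is that `import torch.distributed` and the c10d wrapper modules are expected to import cleanly even on builds where a particular backend was not compiled in, with availability exposed through flags rather than import failures. A rough sketch of the intended usage (flag names taken from the diff below; the snippet is illustrative, not part of this change):

```python
import torch.distributed as dist

if dist.is_available():
    # These flags are re-exported from torch.distributed._distributed_c10d and
    # reflect whether each backend was compiled into this build.
    from torch.distributed.distributed_c10d import _GLOO_AVAILABLE, _NCCL_AVAILABLE

    print("NCCL built:", _NCCL_AVAILABLE, "| Gloo built:", _GLOO_AVAILABLE)
    # The public helpers remain the supported way to query backend availability.
    print("NCCL available:", dist.is_nccl_available())
    print("Gloo available:", dist.is_gloo_available())
```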

https://github.com/pytorch/pytorch/pull/162142 is also patched in.

#buildall

Test Plan:
Sandcastle and OSS CI

Rollback Plan:

Reviewed By: H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162594
Approved by: https://github.com/H-Huang, https://github.com/dcci
Author: Edward Yang
Date: 2025-09-22 21:12:14 +00:00
Committed by: PyTorch MergeBot
parent 4027e97791
commit 09cb34c1dc
52 changed files with 766 additions and 446 deletions

torch/distributed/distributed_c10d.py

@@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
import torch.distributed._distributed_c10d as _c10d
from torch._C import _DistStoreError as DistStoreError
from torch._C._distributed_c10d import (
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags
_DistributedBackendOptions,
_GLOO_AVAILABLE,
_MPI_AVAILABLE,
_NCCL_AVAILABLE,
_ProcessGroupWrapper,
_register_process_group,
_resolve_process_group,
_UCC_AVAILABLE,
_unregister_all_process_groups,
_unregister_process_group,
_XCCL_AVAILABLE,
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
@@ -37,6 +45,11 @@ from torch._C._distributed_c10d import (
get_debug_level,
PrefixStore,
ProcessGroup,
ProcessGroupGloo,
ProcessGroupMPI,
ProcessGroupNCCL,
ProcessGroupUCC,
ProcessGroupXCCL,
ReduceOp,
ReduceOptions,
ReduceScatterOptions,
@@ -44,7 +57,6 @@ from torch._C._distributed_c10d import (
Store,
Work,
)
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.monitor import _WaitCounter
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._typing_utils import not_none
@@ -131,17 +143,11 @@ __all__ = [
"split_group",
]
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
# Change __module__ of all imported types from torch._C._distributed_c10d that are public
# Change __module__ of all imported types from the distributed wrapper that are public
def _export_c_types() -> None:
_public_types_to_change_module = [
AllreduceCoalescedOptions,
@@ -167,45 +173,26 @@ def _export_c_types() -> None:
_export_c_types()
try:
from torch._C._distributed_c10d import ProcessGroupMPI
# Add process groups to __all__ and set their module based on availability
if _MPI_AVAILABLE:
ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupMPI"]
except ImportError:
_MPI_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
if _NCCL_AVAILABLE:
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL"]
except ImportError:
_NCCL_AVAILABLE = False
try:
from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
if _GLOO_AVAILABLE:
ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupGloo"]
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
if _UCC_AVAILABLE:
ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupUCC"]
except ImportError:
_UCC_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupXCCL
if _XCCL_AVAILABLE:
ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupXCCL"]
except ImportError:
_XCCL_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -1327,7 +1314,8 @@ def _get_default_store() -> Store:
def _update_default_pg(pg) -> None:
_world.default_pg = pg
rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
torch._C._distributed_c10d._set_global_rank(rank)
_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@@ -1964,7 +1952,7 @@ def _new_process_group_helper(
if device_id:
pg.bound_device_id = device_id
backend_class: torch._C._distributed_c10d.Backend
backend_class: _c10d.Backend
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
@@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([byte_tensor])
logger.warning(
"_object_to_tensor size: %s hash value: %s",
byte_tensor.numel(),
@@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([tensor])
logger.warning(
"_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
)
@@ -4971,7 +4963,7 @@ def monitored_barrier(
def _create_process_group_wrapper(
wrapped_pg: torch._C._distributed_c10d.Backend,
wrapped_pg: _c10d.Backend,
store_prefix: str,
store: Store,
rank: int,