Make torch importable if compiled without TensorPipe (#154382)

By delaying the import/hiding it behind `torch.distributed.rpc.is_tensorpipe_available()` check
Fixes https://github.com/pytorch/pytorch/issues/154300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154382
Approved by: https://github.com/Skylion007
ghstack dependencies: #154325
This commit is contained in:
Nikita Shulga
2025-05-27 09:58:29 -07:00
committed by PyTorch MergeBot
parent f472ea63bb
commit 5075df6fee
4 changed files with 22 additions and 8 deletions

View File

@ -30,6 +30,10 @@ if is_available() and not torch._C._rpc_init():
if is_available():
_is_tensorpipe_available = hasattr(
torch._C._distributed_rpc, "_TensorPipeRpcBackendOptionsBase"
)
import numbers
import torch.distributed.autograd as dist_autograd
@ -37,7 +41,6 @@ if is_available():
from torch._C._distributed_rpc import ( # noqa: F401
_cleanup_python_rpc_handler,
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_WORKER_THREADS,
_DEFAULT_RPC_TIMEOUT_SEC,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
@ -58,7 +61,6 @@ if is_available():
_set_and_start_rpc_agent,
_set_profiler_node_id,
_set_rpc_timeout,
_TensorPipeRpcBackendOptionsBase,
_UNSET_RPC_TIMEOUT,
enable_gil_profiling,
get_rpc_timeout,
@ -66,10 +68,16 @@ if is_available():
RemoteProfilerManager,
RpcAgent,
RpcBackendOptions,
TensorPipeAgent,
WorkerInfo,
)
if _is_tensorpipe_available:
from torch._C._distributed_rpc import ( # noqa: F401
_DEFAULT_NUM_WORKER_THREADS,
_TensorPipeRpcBackendOptionsBase,
TensorPipeAgent,
)
from . import api, backend_registry, functions
from .api import * # noqa: F401,F403
from .backend_registry import BackendType

View File

@ -3,8 +3,6 @@ import logging
from contextlib import contextmanager
from typing import cast
from . import api, TensorPipeAgent
logger = logging.getLogger(__name__)
@ -40,6 +38,8 @@ def _group_membership_management(store, name, is_join):
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
from . import api, TensorPipeAgent
agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
ret = agent._update_group_membership(
worker_info, my_devices, reverse_device_map, is_join

View File

@ -27,7 +27,6 @@ from torch._C._distributed_rpc import (
get_rpc_timeout,
PyRRef,
RemoteProfilerManager,
TensorPipeAgent,
WorkerInfo,
)
from torch.futures import Future
@ -371,6 +370,8 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
if graceful:
try:
agent = _get_current_rpc_agent()
from torch._C._distributed_rpc import TensorPipeAgent
if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
_wait_all_workers(timeout)
_delete_all_user_and_unforked_owner_rrefs()

View File

@ -2,9 +2,8 @@
from typing import Optional, Union
import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_contants
from . import _is_tensorpipe_available, constants as rpc_contants
DeviceType = Union[int, str, torch.device]
@ -43,6 +42,12 @@ def _to_device_list(devices: list[DeviceType]) -> list[torch.device]:
return list(map(_to_device, devices))
if _is_tensorpipe_available: # type: ignore[has-type]
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
else:
_TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc]
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
r"""
The backend options for