mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Make torch importable if compiled without TensorPipe (#154382)
By delaying the import/hiding it behind `torch.distributed.rpc.is_tensorpipe_avaiable()` check Fixes https://github.com/pytorch/pytorch/issues/154300 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154382 Approved by: https://github.com/Skylion007 ghstack dependencies: #154325
This commit is contained in:
committed by
PyTorch MergeBot
parent
f472ea63bb
commit
5075df6fee
@ -30,6 +30,10 @@ if is_available() and not torch._C._rpc_init():
|
||||
|
||||
|
||||
if is_available():
|
||||
_is_tensorpipe_available = hasattr(
|
||||
torch._C._distributed_rpc, "_TensorPipeRpcBackendOptionsBase"
|
||||
)
|
||||
|
||||
import numbers
|
||||
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
@ -37,7 +41,6 @@ if is_available():
|
||||
from torch._C._distributed_rpc import ( # noqa: F401
|
||||
_cleanup_python_rpc_handler,
|
||||
_DEFAULT_INIT_METHOD,
|
||||
_DEFAULT_NUM_WORKER_THREADS,
|
||||
_DEFAULT_RPC_TIMEOUT_SEC,
|
||||
_delete_all_user_and_unforked_owner_rrefs,
|
||||
_destroy_rref_context,
|
||||
@ -58,7 +61,6 @@ if is_available():
|
||||
_set_and_start_rpc_agent,
|
||||
_set_profiler_node_id,
|
||||
_set_rpc_timeout,
|
||||
_TensorPipeRpcBackendOptionsBase,
|
||||
_UNSET_RPC_TIMEOUT,
|
||||
enable_gil_profiling,
|
||||
get_rpc_timeout,
|
||||
@ -66,10 +68,16 @@ if is_available():
|
||||
RemoteProfilerManager,
|
||||
RpcAgent,
|
||||
RpcBackendOptions,
|
||||
TensorPipeAgent,
|
||||
WorkerInfo,
|
||||
)
|
||||
|
||||
if _is_tensorpipe_available:
|
||||
from torch._C._distributed_rpc import ( # noqa: F401
|
||||
_DEFAULT_NUM_WORKER_THREADS,
|
||||
_TensorPipeRpcBackendOptionsBase,
|
||||
TensorPipeAgent,
|
||||
)
|
||||
|
||||
from . import api, backend_registry, functions
|
||||
from .api import * # noqa: F401,F403
|
||||
from .backend_registry import BackendType
|
||||
|
@ -3,8 +3,6 @@ import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import cast
|
||||
|
||||
from . import api, TensorPipeAgent
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,6 +38,8 @@ def _group_membership_management(store, name, is_join):
|
||||
|
||||
|
||||
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
|
||||
from . import api, TensorPipeAgent
|
||||
|
||||
agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
|
||||
ret = agent._update_group_membership(
|
||||
worker_info, my_devices, reverse_device_map, is_join
|
||||
|
@ -27,7 +27,6 @@ from torch._C._distributed_rpc import (
|
||||
get_rpc_timeout,
|
||||
PyRRef,
|
||||
RemoteProfilerManager,
|
||||
TensorPipeAgent,
|
||||
WorkerInfo,
|
||||
)
|
||||
from torch.futures import Future
|
||||
@ -371,6 +370,8 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
|
||||
if graceful:
|
||||
try:
|
||||
agent = _get_current_rpc_agent()
|
||||
from torch._C._distributed_rpc import TensorPipeAgent
|
||||
|
||||
if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
|
||||
_wait_all_workers(timeout)
|
||||
_delete_all_user_and_unforked_owner_rrefs()
|
||||
|
@ -2,9 +2,8 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
|
||||
|
||||
from . import constants as rpc_contants
|
||||
from . import _is_tensorpipe_available, constants as rpc_contants
|
||||
|
||||
|
||||
DeviceType = Union[int, str, torch.device]
|
||||
@ -43,6 +42,12 @@ def _to_device_list(devices: list[DeviceType]) -> list[torch.device]:
|
||||
return list(map(_to_device, devices))
|
||||
|
||||
|
||||
if _is_tensorpipe_available: # type: ignore[has-type]
|
||||
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
|
||||
else:
|
||||
_TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
||||
r"""
|
||||
The backend options for
|
||||
|
Reference in New Issue
Block a user