mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.

Note that `.pyi` stub files do not need `from __future__ import annotations`, so this PR does not violate issue #117449.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129419
Approved by: https://github.com/ezyang
ghstack dependencies: #129375, #129376
		
			
				
	
	
		
			189 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			189 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# mypy: allow-untyped-defs
 | 
						|
# mypy: disable-error-code="type-arg"
 | 
						|
from datetime import timedelta
 | 
						|
from typing import Any, Generic, overload, TypeVar
 | 
						|
 | 
						|
import torch
 | 
						|
from torch._C import Future
 | 
						|
from torch._C._autograd import ProfilerEvent
 | 
						|
from torch._C._distributed_c10d import Store
 | 
						|
from torch._C._profiler import ProfilerConfig
 | 
						|
 | 
						|
# This module is defined in torch/csrc/distributed/rpc/init.cpp
# The names below are only type declarations; their runtime values come from
# that native module and are not visible in this stub.

# Default init-method string.
# NOTE(review): presumably the default for ``RpcBackendOptions.init_method``;
# confirm in init.cpp.
_DEFAULT_INIT_METHOD: str
# Default worker-thread count.
# NOTE(review): presumably the default for ``num_worker_threads``; confirm.
_DEFAULT_NUM_WORKER_THREADS: int
# Sentinel timeout value.
# NOTE(review): presumably means "caller did not pass a timeout"; confirm.
_UNSET_RPC_TIMEOUT: float
# Default RPC timeout (seconds, per the ``_SEC`` suffix; confirm).
_DEFAULT_RPC_TIMEOUT_SEC: float

# Type parameter of ``PyRRef``: the type of the value it refers to.
_T = TypeVar("_T")

# Stub for the native ``RpcBackendOptions`` binding: a timeout plus an
# init-method string, both optional in the constructor (defaults supplied
# natively).  Subclassed by the TensorPipe backend options in this module.
class RpcBackendOptions:
    # RPC timeout (units not declared here; presumably seconds).
    rpc_timeout: float
    # Init-method string.
    init_method: str
    def __init__(self, rpc_timeout: float = ..., init_method: str = ...) -> None: ...

# Stub for the native ``WorkerInfo`` binding: constructed from a worker name
# and an integer id; ``name`` and ``id`` are read-only properties (presumably
# echoing the constructor arguments).  Supports ``==`` comparison.
class WorkerInfo:
    def __init__(self, name: str, worker_id: int) -> None: ...
    # Worker name.
    @property
    def name(self) -> str: ...
    # Worker id (the constructor argument is called ``worker_id``).
    @property
    def id(self) -> int: ...
    def __eq__(self, other: object) -> bool: ...

# Stub for the native ``RpcAgent`` base binding (``TensorPipeAgent`` in this
# module subclasses it).  ``join``, ``sync`` and ``shutdown`` leave their
# return types undeclared.
class RpcAgent:
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...
    # Two overloads: the local worker (no argument) or a worker looked up by name.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    # Both diagnostics are declared as str -> str mappings.
    def get_debug_info(self) -> dict[str, str]: ...
    def get_metrics(self) -> dict[str, str]: ...

# Stub for the native ``PyRRef`` binding, generic in the referenced value's
# type ``_T``.  ``_``-prefixed members are private helpers of the binding.
class PyRRef(Generic[_T]):
    def __init__(self, value: _T, type_hint: Any = None) -> None: ...
    # Ownership queries.
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    # Value access; only ``to_here`` is typed as returning ``_T``.
    def to_here(self, timeout: float = ...) -> _T: ...
    def local_value(self) -> Any: ...
    # Each takes an optional ``timeout``; results are left typed as ``Any``.
    def rpc_sync(self, timeout: float = ...) -> Any: ...
    def rpc_async(self, timeout: float = ...) -> Any: ...
    def remote(self, timeout: float = ...) -> Any: ...
    # Serialization round-trips through a plain tuple.
    def _serialize(self) -> tuple: ...
    @staticmethod
    def _deserialize(tp: tuple) -> PyRRef: ...
    def _get_type(self) -> type[_T]: ...
    # Future accessors, all typed as ``Future[_T]``.
    def _get_future(self) -> Future[_T]: ...
    def _get_profiling_future(self) -> Future[_T]: ...
    def _set_profiling_future(self, profilingFuture: Future[_T]): ...

# Stub for the native TensorPipe backend-options base: extends
# ``RpcBackendOptions`` with a worker-thread count, per-worker device maps and
# a device list.
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    num_worker_threads: int
    # Outer key: a worker name (cf. ``_set_device_map(to: str, ...)``); inner
    # map pairs devices.  NOTE(review): mapping direction not visible here.
    device_maps: dict[str, dict[torch.device, torch.device]]
    devices: list[torch.device]
    def __init__(
        self,
        num_worker_threads: int,
        _transports: list | None,
        _channels: list | None,
        rpc_timeout: float = ...,
        init_method: str = ...,
        # The mutable ``{}``/``[]`` defaults presumably mirror the native
        # defaults; B006 is suppressed deliberately.
        device_maps: dict[str, dict[torch.device, torch.device]] = {},  # noqa: B006
        devices: list[torch.device] = [],  # noqa: B006
    ) -> None: ...
    def _set_device_map(
        self,
        to: str,
        device_map: dict[torch.device, torch.device],
    ): ...

# Stub for the native TensorPipe-based ``RpcAgent``.  Compared with the base
# class it adds a worker-id overload of ``get_worker_info`` plus
# group-membership, backend-option and store accessors.
class TensorPipeAgent(RpcAgent):
    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: int | None,
        opts: _TensorPipeRpcBackendOptionsBase,
        reverse_device_maps: dict[str, dict[torch.device, torch.device]],
        devices: list[torch.device],
    ) -> None: ...
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def shutdown(self): ...
    # Three overloads: no argument, a worker name, or a worker id.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    def _update_group_membership(
        self,
        worker_info: WorkerInfo,
        my_devices: list[torch.device],
        reverse_device_map: dict[str, dict[torch.device, torch.device]],
        is_join: bool,
    ): ...
    def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
    # Read-only properties.
    @property
    def is_static_group(self) -> bool: ...
    @property
    def store(self) -> Store: ...

# Native binding declared to return ``bool``; presumably reports whether a
# current RPC agent is installed (confirm in init.cpp).
def _is_current_rpc_agent_set() -> bool: ...
# Native binding declared to return the current ``RpcAgent``.
def _get_current_rpc_agent() -> RpcAgent: ...
# Native binding taking an ``RpcAgent``; per its name it installs and starts
# that agent (confirm in init.cpp).  Return type undeclared.
def _set_and_start_rpc_agent(agent: RpcAgent): ...
# Native binding with no parameters; presumably clears the current agent.
def _reset_current_rpc_agent(): ...
# Native binding; ``timeout`` is a ``datetime.timedelta`` whose default is
# supplied natively.
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
# Native binding; ``ignoreRRefLeak`` is a required flag.
def _destroy_rref_context(ignoreRRefLeak: bool): ...
# Native binding declared to return debug info as a str -> str mapping.
def _rref_context_get_debug_info() -> dict[str, str]: ...
# Native binding with no parameters and an undeclared return type.
def _cleanup_python_rpc_handler(): ...
# Native binding (see init.cpp).  Per its name, issues an RPC for the builtin
# op ``opName`` to ``dst``; operator arguments travel in *args/**kwargs.
# Return type is not declared in this stub.
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    *args: Any,
    **kwargs: Any,
): ...
# Native binding (see init.cpp).  Per its name, sends a pickled Python UDF
# plus its tensors to ``dst``.  Return type is not declared in this stub.
# NOTE(review): ``pickledPythonUDF`` widened from ``str`` to ``str | bytes``.
# The Python RPC layer appears to pass pickler output (``bytes``), which the
# native string parameter also accepts; confirm against init.cpp.
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str | bytes,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    isAsyncExecution: bool,
): ...
# Native binding.  Per its name, calls the TorchScript function
# ``qualifiedNameStr`` on the worker named ``dstWorkerName``; unlike the
# builtin variant, arguments arrive pre-packed as a tuple and a dict.
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: tuple,
    kwargsDict: dict,
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    isAsyncExecution: bool,
): ...
# Native binding; the ``remote`` counterpart of ``_invoke_rpc_builtin`` with
# the same parameter list.  Return type undeclared (presumably an RRef;
# confirm in init.cpp).
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    *args: Any,
    **kwargs: Any,
): ...
# Native binding; the ``remote`` counterpart of ``_invoke_rpc_python_udf``
# with the same parameter list.  Return type is not declared in this stub.
# NOTE(review): ``pickledPythonUDF`` widened from ``str`` to ``str | bytes``.
# The Python RPC layer appears to pass pickler output (``bytes``), which the
# native string parameter also accepts; confirm against init.cpp.
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str | bytes,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    isAsyncExecution: bool,
): ...
# Native binding; the ``remote`` counterpart of ``_invoke_rpc_torchscript``,
# except that operator arguments are forwarded through *args/**kwargs.
# ``dstWorkerName`` is the destination worker's *name*, so it is annotated
# ``str``, as in ``_invoke_rpc_torchscript``, not ``WorkerInfo``.
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,  # seconds, per the parameter name
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any,
): ...
# Native binding declared to return the RPC timeout as a ``float``.
def get_rpc_timeout() -> float: ...
# Native binding toggled by a single boolean ``flag``.
def enable_gil_profiling(flag: bool): ...
# Native binding; the timeout is given in seconds, per the parameter name.
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...

# Stub for the native ``RemoteProfilerManager``; only its static
# ``set_current_profiling_key`` entry point is declared here.
class RemoteProfilerManager:
    @staticmethod
    def set_current_profiling_key(key: str): ...

# Native binding configured by a ``ProfilerConfig``; return type undeclared.
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
# Native binding declared to return a triply nested list of ``ProfilerEvent``.
# NOTE(review): what each nesting level represents is not visible here.
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
# Native binding taking an integer ``default_node_id``.
def _set_profiler_node_id(default_node_id: int): ...
# Native binding with no parameters; paired with ``_disable_jit_rref_pickle``.
def _enable_jit_rref_pickle(): ...
# Native binding with no parameters; paired with ``_enable_jit_rref_pickle``.
def _disable_jit_rref_pickle(): ...