mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fix RRef type annotations (#104876)
Test Plan: Sandcastle Reviewed By: H-Huang Differential Revision: D47334579 Pull Request resolved: https://github.com/pytorch/pytorch/pull/104876 Approved by: https://github.com/H-Huang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c0a278d6f0
commit
15ea0a00cb
@ -7,7 +7,7 @@ import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Generic, TypeVar, Set, Any
|
||||
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.futures import Future
|
||||
@ -178,7 +178,7 @@ def _wait_all():
|
||||
|
||||
|
||||
@_require_initialized
|
||||
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
|
||||
r"""
|
||||
This is similar to torch.distributed.all_gather(), but is using RPC. It
|
||||
picks the worker with the smallest name (alphabetic order) as the leader.
|
||||
@ -431,7 +431,7 @@ def _to_worker_info(to):
|
||||
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
|
||||
|
||||
|
||||
def _rref_typeof_on_owner(rref, blocking=True):
|
||||
def _rref_typeof_on_owner(rref, blocking: bool = True):
|
||||
rref_type = type(rref.local_value())
|
||||
if blocking:
|
||||
return rref_type
|
||||
@ -444,7 +444,7 @@ def _rref_typeof_on_owner(rref, blocking=True):
|
||||
return future
|
||||
|
||||
|
||||
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
|
||||
def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
|
||||
fut = rpc_async(
|
||||
rref.owner(),
|
||||
_rref_typeof_on_owner,
|
||||
@ -461,21 +461,25 @@ T = TypeVar("T")
|
||||
GenericWithOneTypeVar = Generic[T]
|
||||
|
||||
|
||||
try:
|
||||
# Combine the implementation class and the type class.
|
||||
class RRef(PyRRef, Generic[T]):
|
||||
pass
|
||||
except TypeError:
|
||||
# TypeError: metaclass conflict: the metaclass of a derived class
|
||||
# must be a (non-strict) subclass of the metaclasses of all its bases
|
||||
# Mypy doesn't understand __class__ (mypy bug #4177)
|
||||
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
|
||||
if TYPE_CHECKING:
|
||||
class RRef(PyRRef[T], Generic[T]):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
# Combine the implementation class and the type class.
|
||||
class RRef(PyRRef, Generic[T]):
|
||||
pass
|
||||
except TypeError:
|
||||
# TypeError: metaclass conflict: the metaclass of a derived class
|
||||
# must be a (non-strict) subclass of the metaclasses of all its bases
|
||||
# Mypy doesn't understand __class__ (mypy bug #4177)
|
||||
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
|
||||
pass
|
||||
|
||||
# Combine the implementation class and the type class.
|
||||
# Types for classes expecting a certain generic parameter (mypy bug #7791)
|
||||
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
|
||||
pass
|
||||
# Combine the implementation class and the type class.
|
||||
# Types for classes expecting a certain generic parameter (mypy bug #7791)
|
||||
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
|
||||
pass
|
||||
|
||||
|
||||
# Install docstrings from `PyRRef` to `RRef`.
|
||||
@ -673,7 +677,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
return rref
|
||||
|
||||
|
||||
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
|
||||
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
|
||||
if not callable(func):
|
||||
raise TypeError("function should be callable.")
|
||||
|
||||
@ -736,7 +740,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
|
||||
|
||||
|
||||
@_require_initialized
|
||||
def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
|
||||
r"""
|
||||
Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
|
||||
messages are sent and received in parallel to execution of Python code. This
|
||||
|
Reference in New Issue
Block a user