Fix RRef type annotations (#104876)

Test Plan: Sandcastle

Reviewed By: H-Huang

Differential Revision: D47334579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104876
Approved by: https://github.com/H-Huang
This commit is contained in:
Richard Barnes
2023-07-14 17:31:51 +00:00
committed by PyTorch MergeBot
parent c0a278d6f0
commit 15ea0a00cb
3 changed files with 31 additions and 25 deletions

View File

@ -7,7 +7,7 @@ import functools
import inspect
import logging
import threading
from typing import Dict, Generic, TypeVar, Set, Any
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
import torch
from torch.futures import Future
@ -178,7 +178,7 @@ def _wait_all():
@_require_initialized
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
r"""
This is similar to torch.distributed.all_gather(), but is using RPC. It
picks the worker with the smallest name (alphabetic order) as the leader.
@ -431,7 +431,7 @@ def _to_worker_info(to):
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
def _rref_typeof_on_owner(rref, blocking=True):
def _rref_typeof_on_owner(rref, blocking: bool = True):
rref_type = type(rref.local_value())
if blocking:
return rref_type
@ -444,7 +444,7 @@ def _rref_typeof_on_owner(rref, blocking=True):
return future
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
fut = rpc_async(
rref.owner(),
_rref_typeof_on_owner,
@ -461,21 +461,25 @@ T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, Generic[T]):
pass
except TypeError:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
# Mypy doesn't understand __class__ (mypy bug #4177)
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
if TYPE_CHECKING:
class RRef(PyRRef[T], Generic[T]):
pass
else:
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, Generic[T]):
pass
except TypeError:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
# Mypy doesn't understand __class__ (mypy bug #4177)
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
pass
# Combine the implementation class and the type class.
# Types for classes expecting a certain generic parameter (mypy bug #7791)
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
pass
# Combine the implementation class and the type class.
# Types for classes expecting a certain generic parameter (mypy bug #7791)
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
pass
# Install docstrings from `PyRRef` to `RRef`.
@ -673,7 +677,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
return rref
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
if not callable(func):
raise TypeError("function should be callable.")
@ -736,7 +740,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
r"""
Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
messages are sent and received in parallel to execution of Python code. This