Add type annotations to torch._C._distributed_rpc module. (#46624)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46624

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24761656

Pulled By: xuzhao9

fbshipit-source-id: b55aee5dd2b97f573a50e5bbfddde7d984943fec
This commit is contained in:
Xu Zhao
2020-11-06 00:47:23 -08:00
committed by Facebook GitHub Bot
parent 73a3e70b24
commit eaa993a2e0
13 changed files with 321 additions and 51 deletions

View File

@ -4,11 +4,11 @@ import functools
import inspect
import logging
import threading
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Set, Any
import torch
from . import (
from torch._C._distributed_rpc import (
PyRRef,
RemoteProfilerManager,
WorkerInfo,
@ -99,10 +99,10 @@ class AllGatherStates(object):
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initiaizing RPC layer.
_ALL_WORKER_NAMES = None
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id = 0
_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
def _init_rpc_states(agent):
@ -379,16 +379,18 @@ GenericWithOneTypeVar = Generic[T]
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, GenericWithOneTypeVar):
class RRef(PyRRef, Generic[T]):
pass
except TypeError as exc:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
# Mypy doesn't understand __class__ (mypy bug #4177)
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore
pass
# Combine the implementation class and the type class.
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
# Types for classes expecting a certain generic parameter (mypy bug #7791)
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore
pass
@ -564,7 +566,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
dst_worker_info.name,
)
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
# Mypy doesn't support re-def of a variable not in the same block (#1174)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
with ctx_manager as rf:
args = args if args else ()
@ -639,7 +642,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
dst_worker_info.name,
)
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
# Mypy doesn't support re-def of a variable not in the same block (#1174)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
with ctx_manager as rf:
args = args if args else ()