Add type annotations to torch._C._distributed_rpc module. (#46624)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46624

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24761656

Pulled By: xuzhao9

fbshipit-source-id: b55aee5dd2b97f573a50e5bbfddde7d984943fec
Committed by: Facebook GitHub Bot
Parent: 73a3e70b24
Commit: eaa993a2e0
@@ -4,11 +4,11 @@ import functools
 import inspect
 import logging
 import threading
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Set, Any
 
 import torch
 
-from . import (
+from torch._C._distributed_rpc import (
     PyRRef,
     RemoteProfilerManager,
     WorkerInfo,
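The import above now resolves against the typed C-extension bindings that this PR annotates. For context, a C-extension module usually gets its annotations from a companion stub file rather than from inline hints, since there is no Python source to annotate. A minimal sketch of that pattern, using hypothetical names (this is not PyTorch's actual stub):

# _my_ext.pyi -- stub placed next to the compiled module _my_ext.so.
# Declares signatures only; bodies are always `...` in a stub.
from typing import Any, List

class WorkerHandle:
    name: str
    def owner_id(self) -> int: ...

def list_workers() -> List[WorkerHandle]: ...
def lookup(name: str, default: Any = ...) -> Any: ...

With such a stub on the search path, `from _my_ext import WorkerHandle` type-checks in callers exactly like a pure-Python import.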
@@ -99,10 +99,10 @@ class AllGatherStates(object):
 
 # States used by `def _all_gather()`.
 # `_ALL_WORKER_NAMES` is initialized on initiaizing RPC layer.
-_ALL_WORKER_NAMES = None
+_ALL_WORKER_NAMES: Set[Any] = set()
 _all_gather_dict_lock = threading.RLock()
 _all_gather_sequence_id = 0
-_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
+_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
 
 
 def _init_rpc_states(agent):
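Typing the module-level globals also removes the `None` sentinel: annotated as `Set[Any]` and initialized to an empty set, `_ALL_WORKER_NAMES` keeps a single non-Optional type across later reassignment, so uses of it need no None check. A minimal sketch of the same pattern, with hypothetical names:

import threading
from typing import Any, Set

# With `= None` the global would need Optional[Set[Any]]; an annotated
# empty set keeps one consistent type across every reassignment.
_WORKER_NAMES: Set[Any] = set()
_STATE_LOCK = threading.RLock()

def _set_workers(names: Set[Any]) -> None:
    global _WORKER_NAMES
    with _STATE_LOCK:
        _WORKER_NAMES = names  # type-checks: Set[Any] assigned to Set[Any]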
@@ -379,16 +379,18 @@ GenericWithOneTypeVar = Generic[T]
 
 try:
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar):
+    class RRef(PyRRef, Generic[T]):
         pass
 except TypeError as exc:
     # TypeError: metaclass conflict: the metaclass of a derived class
     # must be a (non-strict) subclass of the metaclasses of all its bases
-    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
+    # Mypy doesn't understand __class__ (mypy bug #4177)
+    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore
         pass
 
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
+    # Types for classes expecting a certain generic parameter (mypy bug #7791)
+    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore
         pass
 
 
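The try/except above exists because `PyRRef` is defined in the extension module, and its metaclass can differ from the one `Generic` uses: Python rejects a subclass whose bases carry unrelated metaclasses, and the standard fix is to derive a new metaclass from both. A self-contained illustration of the same workaround, with made-up classes standing in for `PyRRef` and the generic base:

class MetaA(type): ...
class MetaB(type): ...

class Impl(metaclass=MetaA): ...    # stands in for PyRRef
class Typed(metaclass=MetaB): ...   # stands in for the Generic base

try:
    class Combined(Impl, Typed):    # raises TypeError: metaclass conflict
        pass
except TypeError:
    # Same trick as RRefMeta above: a metaclass derived from both bases'
    # metaclasses satisfies the "(non-strict) subclass" requirement.
    class CombinedMeta(Impl.__class__, Typed.__class__):
        pass

    class Combined(Impl, Typed, metaclass=CombinedMeta):  # type: ignore[no-redef]
        pass

print(type(Combined).__name__)  # -> CombinedMeta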
@@ -564,7 +566,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -639,7 +642,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
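Both hunks above annotate the same pattern: `ctx_manager` starts out as a no-op context manager and is conditionally replaced with a profiler range, a re-definition across blocks that mypy rejects as an incompatible assignment (mypy issue #1174). The error-code-scoped ignore silences exactly that diagnostic and nothing else. A minimal sketch of the pattern, with a hypothetical helper name:

import contextlib
import torch

def _profiling_ctx(key: str, enabled: bool):
    # No-op context manager by default, profiler range when enabled.
    ctx_manager = contextlib.suppress()
    if enabled:
        # The reassignment changes the variable's type, which mypy flags;
        # `[assignment]` scopes the ignore narrower than a bare type: ignore.
        ctx_manager = torch.autograd.profiler.record_function(key)  # type: ignore[assignment]
    return ctx_manager

with _profiling_ctx("my_rpc_call", enabled=False):
    pass  # body runs the same way whether or not profiling is active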