Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869)
Part of #123062. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869. Approved by: https://github.com/fegin. ghstack dependencies: #128868.
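To make the mechanical changes below easier to read, here is a small, self-contained sketch of the two transformations the UFMT pass (usort + black) applies throughout this file. The `configure` helper is a hypothetical stand-in, not part of torch.distributed.rpc; the commented before/after lines are taken from the diff itself.

# Sketch only: `configure` is a hypothetical placeholder so the snippet runs;
# the real diff applies the same re-wrapping to calls such as
# _enable_rpc_profiler and _invoke_remote_builtin.
from typing import Any, Dict


def configure(*options: Any) -> Dict[str, Any]:
    return {"options": options}


# Hand-wrapped call style on the removed side of the hunks:
#   ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
# black re-wraps it to the added side's style:
settings = configure(
    "should_profile", "qualified_name", "func", "rpc_type", "dst_worker_info"
)

# usort's import ordering, visible in the first hunks:
#   before: from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
#   after:  from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar

if __name__ == "__main__":
    print(settings)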
Committed by: PyTorch MergeBot
Parent: cec31050b4
Commit: 3b798df853
@@ -1,6 +1,4 @@
 # mypy: allow-untyped-defs
-__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
-           "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
 
 import collections
 import contextlib
@@ -8,17 +6,10 @@ import functools
 import inspect
 import logging
 import threading
-from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
+from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar
 
 import torch
-from torch.futures import Future
-
 from torch._C._distributed_rpc import (
-    PyRRef,
-    RemoteProfilerManager,
-    WorkerInfo,
-    TensorPipeAgent,
-    get_rpc_timeout,
     _cleanup_python_rpc_handler,
     _delete_all_user_and_unforked_owner_rrefs,
     _destroy_rref_context,
@@ -32,18 +23,36 @@ from torch._C._distributed_rpc import (
     _is_current_rpc_agent_set,
     _reset_current_rpc_agent,
     _set_and_start_rpc_agent,
+    get_rpc_timeout,
+    PyRRef,
+    RemoteProfilerManager,
+    TensorPipeAgent,
+    WorkerInfo,
 )
-
-from .internal import (
-    PythonUDF,
-    RPCExecMode,
-    _internal_rpc_pickler,
-    _build_rpc_profiling_key,
-)
-
-from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
+from torch.futures import Future
+
 from ._utils import _group_membership_management, _update_group_membership
+from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
+from .internal import (
+    _build_rpc_profiling_key,
+    _internal_rpc_pickler,
+    PythonUDF,
+    RPCExecMode,
+)
+
+
+__all__ = [
+    "shutdown",
+    "get_worker_info",
+    "remote",
+    "rpc_sync",
+    "rpc_async",
+    "RRef",
+    "AllGatherStates",
+    "method_factory",
+    "new_method",
+]
 
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +68,7 @@ logger = logging.getLogger(__name__)
 _ignore_rref_leak = True
 _default_pickler = _internal_rpc_pickler
 
+
 @contextlib.contextmanager
 def _use_rpc_pickler(rpc_pickler):
     r"""
@@ -107,7 +117,9 @@ class AllGatherStates:
 _ALL_WORKER_NAMES: Set[Any] = set()
 _all_gather_dict_lock = threading.RLock()
 _all_gather_sequence_id: Dict[str, int] = {}
-_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
+_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
+    AllGatherStates
+)
 
 
 def _init_rpc_states(agent):
@@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map):
     states.gathered_objects = objects_map
     states.proceed_signal.set()
 
+
 _thread_local_var = threading.local()
 
 
@@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
                 follower_name,
                 _broadcast_to_followers,
                 args=(sequence_id, states.gathered_objects),
-                timeout=rpc_timeout
+                timeout=rpc_timeout,
             )
             worker_name_to_response_future_dict[follower_name] = fut
 
@@ -283,9 +296,7 @@ def _barrier(worker_names):
     try:
         _all_gather(None, set(worker_names))
     except RuntimeError as ex:
-        logger.error(
-            "Failed to complete barrier, got error %s", ex
-        )
+        logger.error("Failed to complete barrier, got error %s", ex)
 
 
 @_require_initialized
@@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
                     all_worker_infos = agent.get_worker_infos()
                     for worker in all_worker_infos:
                         if worker.name != my_name:
-                            rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
+                            rpc_sync(
+                                worker.name,
+                                _update_group_membership,
+                                args=(my_worker_info, [], {}, False),
+                            )
                     agent.join(shutdown=True, timeout=timeout)
         finally:
             # In case of errors, continue to complete the local shutdown.
@@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True):
         return future
 
 
-def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
-    fut = rpc_async(
-        rref.owner(),
-        _rref_typeof_on_owner,
-        args=(rref,),
-        timeout=timeout
-    )
+def _rref_typeof_on_user(
+    rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
+):
+    fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
     if blocking:
         return fut.wait()
     else:
@@ -463,13 +475,16 @@ GenericWithOneTypeVar = Generic[T]
 
 
 if TYPE_CHECKING:
+
     class RRef(PyRRef[T], Generic[T]):
         pass
+
 else:
     try:
         # Combine the implementation class and the type class.
         class RRef(PyRRef, Generic[T]):
             pass
+
     except TypeError:
         # TypeError: metaclass conflict: the metaclass of a derived class
         # must be a (non-strict) subclass of the metaclasses of all its bases
@@ -517,7 +532,9 @@ for method_name, method in inspect.getmembers(PyRRef):
     assert docstring is not None, "RRef user-facing methods should all have docstrings."
 
     # Do surgery on pybind11 generated docstrings.
-    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")
+    docstring = docstring.replace(
+        "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
+    )
 
     # Attach user-facing RRef method with modified docstring.
     new_method = method_factory(method_name, docstring)
@@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
     dst_worker_info = _to_worker_info(to)
     should_profile = _get_should_profile()
 
-    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
+    ctx_manager = _enable_rpc_profiler(
+        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
+    )
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
                 func = wrapped
 
         if qualified_name is not None:
-            rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
+            rref = _invoke_remote_builtin(
+                dst_worker_info, qualified_name, timeout, *args, **kwargs
+            )
         elif isinstance(func, torch.jit.ScriptFunction):
             rref = _invoke_remote_torchscript(
                 dst_worker_info.name,
@@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
                 PythonUDF(func, args, kwargs)
             )
             rref = _invoke_remote_python_udf(
-                dst_worker_info,
-                pickled_python_udf,
-                tensors,
-                timeout,
-                is_async_exec
+                dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
             )
         # attach profiling information
         if should_profile:
@@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
     return rref
 
 
-def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
+def _invoke_rpc(
+    to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
+):
     if not callable(func):
         raise TypeError("function should be callable.")
 
@@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
 
     should_profile = _get_should_profile()
 
-    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
+    ctx_manager = _enable_rpc_profiler(
+        should_profile, qualified_name, func, rpc_type, dst_worker_info
+    )
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
 
         if qualified_name is not None:
             fut = _invoke_rpc_builtin(
-                dst_worker_info,
-                qualified_name,
-                rpc_timeout,
-                *args,
-                **kwargs
+                dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
             )
         elif isinstance(func, torch.jit.ScriptFunction):
             fut = _invoke_rpc_torchscript(
@@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
                 args,
                 kwargs,
                 rpc_timeout,
-                is_async_exec
+                is_async_exec,
             )
         else:
             (pickled_python_udf, tensors) = _default_pickler.serialize(
                 PythonUDF(func, args, kwargs)
             )
             fut = _invoke_rpc_python_udf(
-                dst_worker_info,
-                pickled_python_udf,
-                tensors,
-                rpc_timeout,
-                is_async_exec
+                dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
             )
         if should_profile:
             assert torch.autograd._profiler_enabled()
@@ -915,12 +928,15 @@ def _get_should_profile():
     # Kineto profiler.
     ActiveProfilerType = torch._C._profiler.ActiveProfilerType
     return (
-        torch.autograd._profiler_enabled() and
-        torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
+        torch.autograd._profiler_enabled()
+        and torch._C._autograd._profiler_type()
+        == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
     )
 
 
-def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
+def _enable_rpc_profiler(
+    should_profile, qualified_name, func, rpc_type, dst_worker_info
+):
     ctx_manager = contextlib.nullcontext()
 
     if should_profile: