[BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869)

Part of #123062

- #123062
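
For context, `ufmt` chains `usort` (import sorting) and `black` (code formatting), which is where every hunk below comes from: imports get re-sorted alphabetically and long call sites get wrapped at black's default 88-character line limit. Below is a minimal sketch of that mechanical effect, assuming only that `black` is installed locally; the snippet and its sample line are illustrative, not part of this PR:

```python
# Illustrative sketch: apply black's default formatting to one long call,
# similar to the call sites re-wrapped in the diff below.
import black

src = (
    "ctx_manager = _enable_rpc_profiler("
    "should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)\n"
)

# black.Mode() defaults to an 88-character line limit, so the call gets
# exploded across multiple lines, matching the reformatted hunks in this diff.
print(black.format_str(src, mode=black.Mode()))
```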

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869
Approved by: https://github.com/fegin
ghstack dependencies: #128868
Author: Xuehai Pan, 2024-06-18 23:21:44 +08:00
Committed by: PyTorch MergeBot
commit 3b798df853 (parent cec31050b4)
41 changed files with 316 additions and 229 deletions


@@ -1,6 +1,4 @@
# mypy: allow-untyped-defs
__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
"rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
import collections
import contextlib
@@ -8,17 +6,10 @@ import functools
import inspect
import logging
import threading
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar
import torch
from torch.futures import Future
from torch._C._distributed_rpc import (
PyRRef,
RemoteProfilerManager,
WorkerInfo,
TensorPipeAgent,
get_rpc_timeout,
_cleanup_python_rpc_handler,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
@@ -32,18 +23,36 @@ from torch._C._distributed_rpc import (
_is_current_rpc_agent_set,
_reset_current_rpc_agent,
_set_and_start_rpc_agent,
get_rpc_timeout,
PyRRef,
RemoteProfilerManager,
TensorPipeAgent,
WorkerInfo,
)
from .internal import (
PythonUDF,
RPCExecMode,
_internal_rpc_pickler,
_build_rpc_profiling_key,
)
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
from torch.futures import Future
from ._utils import _group_membership_management, _update_group_membership
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
from .internal import (
_build_rpc_profiling_key,
_internal_rpc_pickler,
PythonUDF,
RPCExecMode,
)
__all__ = [
"shutdown",
"get_worker_info",
"remote",
"rpc_sync",
"rpc_async",
"RRef",
"AllGatherStates",
"method_factory",
"new_method",
]
logger = logging.getLogger(__name__)
@@ -59,6 +68,7 @@ logger = logging.getLogger(__name__)
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler
@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
r"""
@@ -107,7 +117,9 @@ class AllGatherStates:
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
AllGatherStates
)
def _init_rpc_states(agent):
@@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map):
states.gathered_objects = objects_map
states.proceed_signal.set()
_thread_local_var = threading.local()
@@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
follower_name,
_broadcast_to_followers,
args=(sequence_id, states.gathered_objects),
timeout=rpc_timeout
timeout=rpc_timeout,
)
worker_name_to_response_future_dict[follower_name] = fut
@@ -283,9 +296,7 @@ def _barrier(worker_names):
try:
_all_gather(None, set(worker_names))
except RuntimeError as ex:
logger.error(
"Failed to complete barrier, got error %s", ex
)
logger.error("Failed to complete barrier, got error %s", ex)
@_require_initialized
@@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
all_worker_infos = agent.get_worker_infos()
for worker in all_worker_infos:
if worker.name != my_name:
rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
rpc_sync(
worker.name,
_update_group_membership,
args=(my_worker_info, [], {}, False),
)
agent.join(shutdown=True, timeout=timeout)
finally:
# In case of errors, continue to complete the local shutdown.
@@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True):
return future
def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
fut = rpc_async(
rref.owner(),
_rref_typeof_on_owner,
args=(rref,),
timeout=timeout
)
def _rref_typeof_on_user(
rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
):
fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
if blocking:
return fut.wait()
else:
@@ -463,13 +475,16 @@ GenericWithOneTypeVar = Generic[T]
if TYPE_CHECKING:
class RRef(PyRRef[T], Generic[T]):
pass
else:
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, Generic[T]):
pass
except TypeError:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
@@ -517,7 +532,9 @@ for method_name, method in inspect.getmembers(PyRRef):
assert docstring is not None, "RRef user-facing methods should all have docstrings."
# Do surgery on pybind11 generated docstrings.
docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")
docstring = docstring.replace(
"torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
)
# Attach user-facing RRef method with modified docstring.
new_method = method_factory(method_name, docstring)
@@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
dst_worker_info = _to_worker_info(to)
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
ctx_manager = _enable_rpc_profiler(
should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
)
with ctx_manager as rf:
args = args if args else ()
@@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
func = wrapped
if qualified_name is not None:
rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
rref = _invoke_remote_builtin(
dst_worker_info, qualified_name, timeout, *args, **kwargs
)
elif isinstance(func, torch.jit.ScriptFunction):
rref = _invoke_remote_torchscript(
dst_worker_info.name,
@@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
PythonUDF(func, args, kwargs)
)
rref = _invoke_remote_python_udf(
dst_worker_info,
pickled_python_udf,
tensors,
timeout,
is_async_exec
dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
)
# attach profiling information
if should_profile:
@@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
return rref
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
def _invoke_rpc(
to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
):
if not callable(func):
raise TypeError("function should be callable.")
@@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
ctx_manager = _enable_rpc_profiler(
should_profile, qualified_name, func, rpc_type, dst_worker_info
)
with ctx_manager as rf:
args = args if args else ()
@@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
if qualified_name is not None:
fut = _invoke_rpc_builtin(
dst_worker_info,
qualified_name,
rpc_timeout,
*args,
**kwargs
dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
)
elif isinstance(func, torch.jit.ScriptFunction):
fut = _invoke_rpc_torchscript(
@@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float =
args,
kwargs,
rpc_timeout,
is_async_exec
is_async_exec,
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf(
dst_worker_info,
pickled_python_udf,
tensors,
rpc_timeout,
is_async_exec
dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
)
if should_profile:
assert torch.autograd._profiler_enabled()
@@ -915,12 +928,15 @@ def _get_should_profile():
# Kineto profiler.
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return (
torch.autograd._profiler_enabled() and
torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type()
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
def _enable_rpc_profiler(
should_profile, qualified_name, func, rpc_type, dst_worker_info
):
ctx_manager = contextlib.nullcontext()
if should_profile: