mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More doctest refinements. (#83317)
Follow-up to #82797. Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 9c9f424817
commit b136f3f310
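For context before the diff (a minimal sketch, not code from this commit): xdoctest reads inline "# xdoctest: ..." comments inside docstrings, and the hunks below use +SKIP with a reason string for examples that need a multi-worker RPC setup and so cannot run on a single-process CI job. The function below is hypothetical and only illustrates the directive style:

    def _needs_two_workers():
        """
        A hypothetical function whose doctest cannot run standalone; the
        labeled +SKIP directive records why the example is skipped.

        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> import torch.distributed.rpc as rpc
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        """

Such examples can be checked locally with something like python -m xdoctest <module>; skipped examples are simply excluded from the run.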
torch/distributed/rpc/api.py

@@ -148,6 +148,7 @@ def _broadcast_to_followers(sequence_id, objects_map):

_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
@@ -157,10 +158,10 @@ def _wait_all():

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> # xdoctest: +SKIP
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>     fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
@@ -176,6 +177,7 @@ def _wait_all():

    finally:
        del _thread_local_var.future_list


@_require_initialized
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
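Taken together, the _wait_all hunks above and the rpc_async hunk further down show a thread-local future-collection pattern. A minimal self-contained sketch of that pattern (plain concurrent.futures stands in for RPC futures; the real code waits via torch.futures.wait_all and guards the append with a hasattr check):

    import contextlib
    import threading
    from concurrent.futures import ThreadPoolExecutor

    _thread_local = threading.local()

    @contextlib.contextmanager
    def wait_all_sketch():
        # Install a per-thread list that async calls append their futures to,
        # mirroring _thread_local_var.future_list in the diff.
        _thread_local.future_list = []
        try:
            yield
            # Block until everything issued inside the block has finished.
            for fut in _thread_local.future_list:
                fut.result()
        finally:
            # Mirrors the `del _thread_local_var.future_list` cleanup above.
            del _thread_local.future_list

    with ThreadPoolExecutor() as pool:
        with wait_all_sketch():
            fut = pool.submit(sum, [1, 2, 3])
            _thread_local.future_list.append(fut)  # rpc_async does this append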
@@ -285,6 +287,7 @@ def _barrier(worker_names):

            f"Failed to complete barrier, got error {ex}"
        )


@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
@@ -376,6 +379,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):

    else:
        _finalize_shutdown()


def _finalize_shutdown():
    try:
        # This raises a `TORCH_CHECK()` exception on RRef leak detected.
@@ -396,6 +400,7 @@ def _finalize_shutdown():

    _cleanup_python_rpc_handler()
    _reset_current_rpc_agent()


@_require_initialized
def get_worker_info(worker_name=None):
    r"""
@@ -453,7 +458,6 @@ def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):

    return fut


T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
@@ -669,6 +673,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):

    return rref


def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    if not callable(func):
        raise TypeError("function should be callable.")
@@ -900,6 +905,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):

        _thread_local_var.future_list.append(fut)
    return fut


def _get_should_profile():
    # Legacy profiler should be enabled. RPC profiling is not supported with
    # Kineto profiler.
@@ -909,6 +915,7 @@ def _get_should_profile():

        torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
    )


def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
    ctx_manager = contextlib.suppress()