More doctest refinements. (#83317)

Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
This commit is contained in:
joncrall
2022-08-22 20:07:23 +00:00
committed by PyTorch MergeBot
parent 9c9f424817
commit b136f3f310
41 changed files with 310 additions and 162 deletions

View File

@ -148,6 +148,7 @@ def _broadcast_to_followers(sequence_id, objects_map):
_thread_local_var = threading.local()
@contextlib.contextmanager
def _wait_all():
r"""
@ -157,10 +158,10 @@ def _wait_all():
Example::
>>> # xdoctest: +SKIP("distributed")
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> # xdoctest: +SKIP
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> with rpc._wait_all():
>>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
@ -176,6 +177,7 @@ def _wait_all():
finally:
del _thread_local_var.future_list
@_require_initialized
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
r"""
@ -285,6 +287,7 @@ def _barrier(worker_names):
f"Failed to complete barrier, got error {ex}"
)
@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
r"""
@ -376,6 +379,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
else:
_finalize_shutdown()
def _finalize_shutdown():
try:
# This raises a `TORCH_CHECK()` exception on RRef leak detected.
@ -396,6 +400,7 @@ def _finalize_shutdown():
_cleanup_python_rpc_handler()
_reset_current_rpc_agent()
@_require_initialized
def get_worker_info(worker_name=None):
r"""
@ -453,7 +458,6 @@ def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
return fut
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
@ -669,6 +673,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
return rref
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
if not callable(func):
raise TypeError("function should be callable.")
@ -900,6 +905,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
_thread_local_var.future_list.append(fut)
return fut
def _get_should_profile():
# Legacy profiler should be enabled. RPC profiling is not supported with
# Kineto profiler.
@ -909,6 +915,7 @@ def _get_should_profile():
torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
ctx_manager = contextlib.suppress()