mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-31 04:04:57 +08:00
fix docstring issues in torch.distributed (#113337)
Fixes #112643

Fixes all the issues listed.

### Error Count

| File | Count Before | Count Now |
| ---- | ---- | ---- |
| `torch/distributed/optim/named_optimizer.py` | 13 | 1 |
| `torch/distributed/nn/functional.py` | 7 | 1 |
| `torch/distributed/nn/api/remote_module.py` | 25 | 3 |
| `torch/distributed/algorithms/join.py` | 43 | 4 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113337
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e10dd2c78
commit
c0b57d4e3b
@ -10,28 +10,28 @@ __all__ = ['JoinHook', 'Joinable', 'Join']
|
||||
|
||||
class JoinHook:
|
||||
r"""
|
||||
This defines a join hook, which provides two entry points in the join
|
||||
context manager: a main hook, which is called repeatedly while there exists
|
||||
a non-joined process, and a post-hook, which is called once all processes
|
||||
have joined.
|
||||
This defines a join hook, which provides two entry points in the join context manager.
|
||||
|
||||
Entry points: a main hook, which is called repeatedly while there exists a non-joined
|
||||
process, and a post-hook, which is called once all processes have joined.
|
||||
|
||||
To implement a join hook for the generic join context manager, define a
|
||||
class that inherits from :class:`JoinHook` and override ``main_hook()`` and
|
||||
``post_hook()`` as appropriate.
|
||||
"""
|
||||
|
||||
def main_hook(self) -> None:
|
||||
r"""
|
||||
This hook is called repeatedly while there exists a non-joined process
|
||||
to shadow collective communications in one training iteration (i.e. in
|
||||
one forward pass, backward pass, and optimizer step).
|
||||
r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
|
||||
|
||||
A training iteration is one forward pass, backward pass, and optimizer step.
|
||||
"""
|
||||
...
|
||||
|
||||
def post_hook(self, is_last_joiner: bool) -> None:
|
||||
r"""
|
||||
This hook is called after all processes have joined. It is passed an
|
||||
additional ``bool`` argument ``is_last_joiner``, which indicates if the
|
||||
rank is one of the last to join.
|
||||
Call hook after all processes have joined.
|
||||
|
||||
It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
|
||||
|
||||
Arguments:
|
||||
is_last_joiner (bool): ``True`` if the rank is one of the last to
|
||||
@ -42,12 +42,15 @@ class JoinHook:
|
||||
|
||||
class Joinable(ABC):
|
||||
r"""
|
||||
This defines an abstract base class for joinable classes. A joinable class
|
||||
This defines an abstract base class for joinable classes.
|
||||
|
||||
A joinable class
|
||||
(inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
|
||||
which returns a :class:`JoinHook` instance, in addition to
|
||||
:meth:`join_device` and :meth:`join_process_group` that return device and
|
||||
process group information, respectively.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -56,7 +59,7 @@ class Joinable(ABC):
|
||||
@abstractmethod
|
||||
def join_hook(self, **kwargs) -> JoinHook:
|
||||
r"""
|
||||
Returns a :class:`JoinHook` instance for the given :class:`Joinable`.
|
||||
Return a :class:`JoinHook` instance for the given :class:`Joinable`.
|
||||
|
||||
Arguments:
|
||||
kwargs (dict): a :class:`dict` containing any keyword arguments
|
||||
@ -69,37 +72,28 @@ class Joinable(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def join_device(self) -> torch.device:
|
||||
r"""
|
||||
Returns the device from which to perform collective communications
|
||||
needed by the join context manager implementation itself.
|
||||
"""
|
||||
r"""Return the device from which to perform collective communications needed by the join context manager."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def join_process_group(self) -> Any:
|
||||
r"""
|
||||
Returns the process group for the collective communications needed by
|
||||
the join context manager itself.
|
||||
"""
|
||||
r"""Returns the process group for the collective communications needed by the join context manager itself."""
|
||||
...
|
||||
|
||||
|
||||
class _JoinConfig(NamedTuple):
|
||||
r"""
|
||||
This includes all fields needed from a :class:`Joinable` instance for the
|
||||
join context manager side.
|
||||
"""
|
||||
r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
|
||||
|
||||
enable: bool
|
||||
throw_on_early_termination: bool
|
||||
is_first_joinable: bool
|
||||
|
||||
@staticmethod
|
||||
def construct_disabled_join_config():
|
||||
r"""
|
||||
Returns a :class:`_JoinConfig` instance indicating that join-related
|
||||
logic should be disabled, e.g. if the caller is not in a join context
|
||||
manager.
|
||||
r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
|
||||
|
||||
For example, this applies if the caller is not in a join context manager.
|
||||
"""
|
||||
return _JoinConfig(
|
||||
enable=False,
|
||||
@ -111,8 +105,9 @@ class _JoinConfig(NamedTuple):
|
||||
|
||||
class Join:
|
||||
r"""
|
||||
This class defines the generic join context manager, which allows custom
|
||||
hooks to be called after a process joins. These hooks should shadow the
|
||||
This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
|
||||
|
||||
These hooks should shadow the
|
||||
collective communications of non-joined processes to prevent hanging and
|
||||
erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
|
||||
for details about the hook definition.
|
||||
@ -169,6 +164,7 @@ class Join:
|
||||
>>> optim.step()
|
||||
>>> # All ranks reach here without hanging/erroring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
joinables: List[Joinable],
|
||||
@ -186,9 +182,7 @@ class Join:
|
||||
self._extract_dist_info()
|
||||
|
||||
def _set_joinable_configs(self) -> None:
|
||||
r"""
|
||||
Sets the :class:`_JoinConfig` of each participating :class:`Joinable`.
|
||||
"""
|
||||
r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
|
||||
assert len(self._joinables) > 0
|
||||
is_first_joinable = True
|
||||
for joinable in self._joinables:
|
||||
@ -201,7 +195,8 @@ class Join:
|
||||
|
||||
def _extract_dist_info(self) -> None:
|
||||
r"""
|
||||
Extracts the process group and device information from the joinables.
|
||||
Extract the process group and device information from the joinables.
|
||||
|
||||
If there are multiple joinables, then the context manager uses the
|
||||
first specified device.
|
||||
|
||||
@ -236,8 +231,7 @@ class Join:
|
||||
traceback: Optional[TracebackType]
|
||||
):
|
||||
r"""
|
||||
Repeatedly runs the main hooks until all processes join; then, runs
|
||||
the post-hooks.
|
||||
Repeatedly run the main hooks until all processes join; then, run the post-hooks.
|
||||
|
||||
Raises:
|
||||
RuntimeError
|
||||
@ -283,19 +277,15 @@ class Join:
|
||||
join_hook.post_hook(is_last_joiner)
|
||||
|
||||
def _get_num_nonjoined_procs(self):
|
||||
r"""
|
||||
Returns the number of non-joined processes by shadowing an all-reduce
|
||||
in the non-joined processes.
|
||||
"""
|
||||
r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
|
||||
num_nonjoined_procs = torch.zeros(1, device=self._device)
|
||||
dist.all_reduce(num_nonjoined_procs, group=self._process_group)
|
||||
return num_nonjoined_procs.item()
|
||||
|
||||
def _notify_procs_to_terminate(self):
|
||||
r"""
|
||||
Schedules an all-reduce to notify non-joined processes to terminate
|
||||
and raises a ``RuntimeError`` indicating that the current process has
|
||||
exhausted its inputs.
|
||||
r"""Schedule an all-reduce to notify non-joined processes to terminate.
|
||||
|
||||
Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
|
||||
"""
|
||||
ones = torch.ones(1, device=self._device)
|
||||
dist.all_reduce(ones, group=self._process_group)
|
||||
@ -304,10 +294,10 @@ class Join:
|
||||
@staticmethod
|
||||
def notify_join_context(joinable: Joinable):
|
||||
r"""
|
||||
Notifies the join context manager that the calling process has not yet
|
||||
joined; then, if ``throw_on_early_termination=True``, checks if uneven
|
||||
inputs have been detected (i.e. if one process has already joined) and
|
||||
throws an exception if so.
|
||||
Notifies the join context manager that the calling process has not yet joined.
|
||||
|
||||
Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
|
||||
(i.e. if one process has already joined) and throws an exception if so.
|
||||
|
||||
This method should be called from a :class:`Joinable` object before
|
||||
its per-iteration collective communications. For example, this should
|
||||
|
||||
Reference in New Issue
Block a user