fix docstring issues in torch.distributed (#113337)

Fixes #112643

Fixes all the issues listed

### Error Count

|File | Count Before | Count now|
|---- | ---- | ---- |
|`torch/distributed/optim/named_optimizer.py` | 13 | 1|
|`torch/distributed/nn/functional.py` | 7 | 1|
|`torch/distributed/nn/api/remote_module.py` | 25 | 3|
|`torch/distributed/algorithms/join.py` | 43 | 4|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113337
Approved by: https://github.com/ezyang
This commit is contained in:
ChanBong
2023-11-13 19:37:21 +00:00
committed by PyTorch MergeBot
parent 5e10dd2c78
commit c0b57d4e3b
4 changed files with 69 additions and 96 deletions

View File

@ -10,28 +10,28 @@ __all__ = ['JoinHook', 'Joinable', 'Join']
class JoinHook:
r"""
This defines a join hook, which provides two entry points in the join
context manager: a main hook, which is called repeatedly while there exists
a non-joined process, and a post-hook, which is called once all processes
have joined.
This defines a join hook, which provides two entry points in the join context manager.
Entry points : a main hook, which is called repeatedly while there exists a non-joined
process, and a post-hook, which is called once all processes have joined.
To implement a join hook for the generic join context manager, define a
class that inherits from :class:`JoinHook` and override ``main_hook()`` and
``post_hook()`` as appropriate.
"""
def main_hook(self) -> None:
r"""
This hook is called repeatedly while there exists a non-joined process
to shadow collective communications in one training iteration (i.e. in
one forward pass, backward pass, and optimizer step).
r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
Training iteration i.e., in one forward pass, backward pass, and optimizer step.
"""
...
def post_hook(self, is_last_joiner: bool) -> None:
r"""
This hook is called after all processes have joined. It is passed an
additional ``bool`` argument ``is_last_joiner``, which indicates if the
rank is one of the last to join.
Call hook after all processes have joined.
It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
Arguments:
is_last_joiner (bool): ``True`` if the rank is one of the last to
@ -42,12 +42,15 @@ class JoinHook:
class Joinable(ABC):
r"""
This defines an abstract base class for joinable classes. A joinable class
This defines an abstract base class for joinable classes.
A joinable class
(inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
which returns a :class:`JoinHook` instance, in addition to
:meth:`join_device` and :meth:`join_process_group` that return device and
process group information, respectively.
"""
@abstractmethod
def __init__(self):
super().__init__()
@ -56,7 +59,7 @@ class Joinable(ABC):
@abstractmethod
def join_hook(self, **kwargs) -> JoinHook:
r"""
Returns a :class:`JoinHook` instance for the given :class:`Joinable`.
Return a :class:`JoinHook` instance for the given :class:`Joinable`.
Arguments:
kwargs (dict): a :class:`dict` containing any keyword arguments
@ -69,37 +72,28 @@ class Joinable(ABC):
@property
@abstractmethod
def join_device(self) -> torch.device:
r"""
Returns the device from which to perform collective communications
needed by the join context manager implementation itself.
"""
r"""Return the device from which to perform collective communications needed by the join context manager."""
...
@property
@abstractmethod
def join_process_group(self) -> Any:
r"""
Returns the process group for the collective communications needed by
the join context manager itself.
"""
r"""Returns the process group for the collective communications needed by the join context manager itself."""
...
class _JoinConfig(NamedTuple):
r"""
This includes all fields needed from a :class:`Joinable` instance for the
join context manager side.
"""
r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
enable: bool
throw_on_early_termination: bool
is_first_joinable: bool
@staticmethod
def construct_disabled_join_config():
r"""
Returns a :class:`_JoinConfig` instance indicating that join-related
logic should be disabled, e.g. if the caller is not in a join context
manager.
r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
e.g. if the caller is not in a join context manager.
"""
return _JoinConfig(
enable=False,
@ -111,8 +105,9 @@ class _JoinConfig(NamedTuple):
class Join:
r"""
This class defines the generic join context manager, which allows custom
hooks to be called after a process joins. These hooks should shadow the
This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
These hooks should shadow the
collective communications of non-joined processes to prevent hanging and
erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
for details about the hook definition.
@ -169,6 +164,7 @@ class Join:
>>> optim.step()
>>> # All ranks reach here without hanging/erroring
"""
def __init__(
self,
joinables: List[Joinable],
@ -186,9 +182,7 @@ class Join:
self._extract_dist_info()
def _set_joinable_configs(self) -> None:
r"""
Sets the :class:`_JoinConfig` of each participating :class:`Joinable`.
"""
r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
assert len(self._joinables) > 0
is_first_joinable = True
for joinable in self._joinables:
@ -201,7 +195,8 @@ class Join:
def _extract_dist_info(self) -> None:
r"""
Extracts the process group and device information from the joinables.
Extract the process group and device information from the joinables.
If there are multiple joinables, then the context manager uses the
first specified device.
@ -236,8 +231,7 @@ class Join:
traceback: Optional[TracebackType]
):
r"""
Repeatedly runs the main hooks until all processes join; then, runs
the post-hooks.
Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
Raises:
RuntimeError
@ -283,19 +277,15 @@ class Join:
join_hook.post_hook(is_last_joiner)
def _get_num_nonjoined_procs(self):
r"""
Returns the number of non-joined processes by shadowing an all-reduce
in the non-joined processes.
"""
r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
num_nonjoined_procs = torch.zeros(1, device=self._device)
dist.all_reduce(num_nonjoined_procs, group=self._process_group)
return num_nonjoined_procs.item()
def _notify_procs_to_terminate(self):
r"""
Schedules an all-reduce to notify non-joined processes to terminate
and raises a ``RuntimeError`` indicating that the current process has
exhausted its inputs.
r"""Schedule an all-reduce to notify non-joined processes to terminate.
Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
"""
ones = torch.ones(1, device=self._device)
dist.all_reduce(ones, group=self._process_group)
@ -304,10 +294,10 @@ class Join:
@staticmethod
def notify_join_context(joinable: Joinable):
r"""
Notifies the join context manager that the calling process has not yet
joined; then, if ``throw_on_early_termination=True``, checks if uneven
inputs have been detected (i.e. if one process has already joined) and
throws an exception if so.
Notifies the join context manager that the calling process has not yet joined.
Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
(i.e. if one process has already joined) and throws an exception if so.
This method should be called from a :class:`Joinable` object before
its per-iteration collective communications. For example, this should