Formatting changes for gradient scaling (#33832)

Summary:
Hard to get right locally... I can build the docs, but they never quite match what the live site looks like. The bullet point indentation was just an oversight.

Removing the indented `Returns:` blocks because they take up a lot of space when rendered and add no clarity. Some functions in PyTorch [do use them](https://pytorch.org/docs/master/torch.html#torch.eye), but [many don't bother](https://pytorch.org/docs/master/torch.html#torch.is_tensor), so apparently others share my feelings (not using them is in line with existing practice).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33832

Differential Revision: D20135581

Pulled By: ngimel

fbshipit-source-id: bc788a7e57b142f95c4fa5baf3fe01f94c45abd8
Author: Michael Carilli
Date: 2020-02-28 11:37:04 -08:00
Committed by: Facebook GitHub Bot
Parent: 5dde8cd483
Commit: a726827ec8

torch/cuda/amp/grad_scaler.py

@@ -62,11 +62,11 @@ class GradScaler(object):
     ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
     * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
       themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
     * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
       If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
       ``growth_factor``.
     The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
     value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
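
The class docstring above describes the skip/backoff/growth behavior in prose. A minimal sketch of the loop it refers to, assuming a CUDA device and using placeholder model/optimizer names that are not part of this patch, might look like:

```python
import torch

device = "cuda"
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(5):
    inputs = torch.randn(8, 10, device=device)
    targets = torch.randn(8, 10, device=device)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()   # backward() runs on the scaled loss
    scaler.step(optimizer)          # skipped internally if the scaled grads contain infs/NaNs
    scaler.update()                 # the backoff_factor / growth_factor logic happens here
```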
@@ -125,11 +125,11 @@ class GradScaler(object):
         """
         Multiplies ('scales') a tensor or list of tensors by the scale factor.
+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
         Arguments:
             outputs (Tensor or iterable of Tensors):  Outputs to scale.
-        Returns:
-            Scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified.
         """
         if not self._enabled:
             return outputs
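
As a hedged illustration of the reworded ``scale`` docstring: a disabled scaler hands its input straight back (the ``if not self._enabled: return outputs`` branch visible in the hunk), while an enabled one returns the input multiplied by the current scale. ``init_scale`` and ``enabled`` are the standard constructor arguments, not something introduced by this patch:

```python
import torch

loss = torch.ones(1, device="cuda")

disabled = torch.cuda.amp.GradScaler(enabled=False)
assert disabled.scale(loss) is loss        # disabled: outputs come back unmodified

enabled = torch.cuda.amp.GradScaler(init_scale=2.0 ** 16)
scaled = enabled.scale(loss)               # loss * 65536.0, same device and dtype
print(scaled.item())
```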
@@ -234,14 +234,13 @@ class GradScaler(object):
         ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
+        Returns the return value of ``optimizer.step(*args, **kwargs)``.
         Arguments:
             optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
             args:  Any arguments.
             kwargs:  Any keyword arguments.
-        Returns:
-            The return value of ``optimizer.step(*args, **kwargs)``.
         .. warning::
             Closure use is not currently supported.
         """
@@ -342,8 +341,7 @@ class GradScaler(object):
     def get_scale(self):
         """
-        Returns:
-            A Python float containing the current scale, or 1.0 if scaling is disabled.
+        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
         .. warning::
             :meth:`get_scale` incurs a CPU-GPU sync.
@@ -355,8 +353,7 @@ class GradScaler(object):
     def get_growth_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale growth factor.
+        Returns a Python float containing the scale growth factor.
         """
         return self._growth_factor
@@ -369,8 +366,7 @@ class GradScaler(object):
     def get_backoff_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale backoff factor.
+        Returns a Python float containing the scale backoff factor.
         """
         return self._backoff_factor
@@ -383,8 +379,7 @@ class GradScaler(object):
     def get_growth_interval(self):
         r"""
-        Returns:
-            A Python int containing the growth interval.
+        Returns a Python int containing the growth interval.
         """
         return self._growth_interval
@@ -403,8 +398,7 @@ class GradScaler(object):
     def is_enabled(self):
         r"""
-        Returns:
-            A bool indicating whether this instance is enabled.
+        Returns a bool indicating whether this instance is enabled.
         """
         return self._enabled
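
For completeness, a quick, hedged look at the getters whose docstrings are reworded in the hunks above; the commented values are the constructor defaults, assuming CUDA is available:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

print(scaler.is_enabled())           # True (a bool), assuming CUDA is available
print(scaler.get_scale())            # current scale as a Python float; note the CPU-GPU sync warning
print(scaler.get_growth_factor())    # scale growth factor, 2.0 by default
print(scaler.get_backoff_factor())   # scale backoff factor, 0.5 by default
print(scaler.get_growth_interval())  # growth interval, an int (2000 by default)
```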