diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py
index dbb6371d633f..5331c0f8efe1 100644
--- a/torch/cuda/amp/grad_scaler.py
+++ b/torch/cuda/amp/grad_scaler.py
@@ -62,11 +62,11 @@ class GradScaler(object):
     ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

     * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
-      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
+      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

     * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
-      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
-      ``growth_factor``.
+      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
+      ``growth_factor``.

     The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
     value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
@@ -125,11 +125,11 @@ class GradScaler(object):
         """
         Multiplies ('scales') a tensor or list of tensors by the scale factor.

+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
         Arguments:
             outputs (Tensor or iterable of Tensors):  Outputs to scale.
-
-        Returns:
-            Scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified.
         """
         if not self._enabled:
             return outputs
@@ -234,14 +234,13 @@ class GradScaler(object):

         ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

+        Returns the return value of ``optimizer.step(*args, **kwargs)``.
+
         Arguments:
             optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
             args:  Any arguments.
             kwargs:  Any keyword arguments.

-        Returns:
-            The return value of ``optimizer.step(*args, **kwargs)``.
-
         .. warning::
             Closure use is not currently supported.
         """
@@ -342,8 +341,7 @@ class GradScaler(object):

     def get_scale(self):
         """
-        Returns:
-            A Python float containing the current scale, or 1.0 if scaling is disabled.
+        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

         .. warning::
             :meth:`get_scale` incurs a CPU-GPU sync.
@@ -355,8 +353,7 @@ class GradScaler(object):

     def get_growth_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale growth factor.
+        Returns a Python float containing the scale growth factor.
         """
         return self._growth_factor

@@ -369,8 +366,7 @@ class GradScaler(object):

     def get_backoff_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale backoff factor.
+        Returns a Python float containing the scale backoff factor.
         """
         return self._backoff_factor

@@ -383,8 +379,7 @@ class GradScaler(object):

     def get_growth_interval(self):
         r"""
-        Returns:
-            A Python int containing the growth interval.
+        Returns a Python int containing the growth interval.
         """
         return self._growth_interval

@@ -403,8 +398,7 @@ class GradScaler(object):

     def is_enabled(self):
         r"""
-        Returns:
-            A bool indicating whether this instance is enabled.
+        Returns a bool indicating whether this instance is enabled.
         """
         return self._enabled
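
For context, the behavior these docstrings describe follows the standard ``GradScaler`` training-loop pattern. A minimal sketch (the toy linear model, optimizer, and synthetic data below are placeholders, not part of this change; assumes a CUDA device is available):

    import torch

    device = "cuda"
    model = torch.nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(5):
        inputs = torch.randn(8, 10, device=device)
        targets = torch.randn(8, 1, device=device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        # scale() multiplies the loss by the current scale factor before backward().
        scaler.scale(loss).backward()
        # step() unscales gradients and skips optimizer.step() if infs/NaNs are found.
        scaler.step(optimizer)
        # update() multiplies the scale by backoff_factor after a skipped step, or by
        # growth_factor after growth_interval consecutive unskipped steps.
        scaler.update()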