mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: hard to get right locally...I can build the docs but never quite match what it looks like live. the bullet point indentation was just an oversight. Removing `Returns:` formatting tabs because they take up a lot of space when rendered and add no clarity. Some functions in Pytorch [do use them](https://pytorch.org/docs/master/torch.html#torch.eye), but [many don't bother](https://pytorch.org/docs/master/torch.html#torch.is_tensor), so apparently some people shared my feelings (Not using them is in line with existing practice). Pull Request resolved: https://github.com/pytorch/pytorch/pull/33832 Differential Revision: D20135581 Pulled By: ngimel fbshipit-source-id: bc788a7e57b142f95c4fa5baf3fe01f94c45abd8
464 lines
20 KiB
Python
464 lines
20 KiB
Python
import torch
|
|
from collections import defaultdict
|
|
from torch._six import container_abcs
|
|
|
|
|
|
class _MultiDeviceReplicator(object):
|
|
"""
|
|
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
|
|
"""
|
|
def __init__(self, master_tensor):
|
|
assert master_tensor.is_cuda
|
|
self.master = master_tensor
|
|
self._per_device_tensors = {}
|
|
|
|
def get(self, device):
|
|
retval = self._per_device_tensors.get(device, None)
|
|
if retval is None:
|
|
retval = self.master.to(device=device, non_blocking=True, copy=True)
|
|
self._per_device_tensors[device] = retval
|
|
return retval
|
|
|
|
|
|
class GradScaler(object):
|
|
"""
|
|
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
|
|
conveniently.
|
|
|
|
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
|
|
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
|
|
* ``scaler.update()`` updates ``scaler``'s scale factor.
|
|
|
|
Typical use::
|
|
|
|
# Creates a GradScaler once at the beginning of training.
|
|
scaler = GradScaler()
|
|
|
|
for epoch in epochs:
|
|
for input, target in data:
|
|
optimizer.zero_grad()
|
|
output = model(input)
|
|
loss = loss_fn(output, target)
|
|
|
|
# Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
|
|
scaler.scale(loss).backward()
|
|
|
|
# scaler.step() first unscales the gradients of the optimizer's assigned params.
|
|
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
|
|
# otherwise, optimizer.step() is skipped.
|
|
scaler.step(optimizer)
|
|
|
|
# Updates the scale for next iteration.
|
|
scaler.update()
|
|
|
|
See the :ref:`Gradient Scaling Examples<gradient-scaling-examples>` for usage in more complex cases like
|
|
gradient clipping, gradient penalty, and multiple losses/optimizers.
|
|
|
|
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
|
|
a large scale factor should be used. However, ``torch.float16`` values can "overflow" (become inf or NaN) if
|
|
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
|
|
without incurring inf or NaN gradient values.
|
|
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
|
|
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
|
|
|
|
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
|
|
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
|
|
|
|
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
|
|
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
|
|
``growth_factor``.
|
|
|
|
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
|
|
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
|
|
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
|
|
|
|
Arguments:
|
|
init_scale (float, optional, default=2.**16): Initial scale factor.
|
|
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
|
|
:meth:`update` if no inf/NaN gradients occur for ``growth_factor`` consecutive iterations.
|
|
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
|
|
:meth:`update` if inf/NaN gradients occur in an iteration.
|
|
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
|
|
that must occur for the scale to be multiplied by ``growth_factor``.
|
|
enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply
|
|
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
|
|
"""
|
|
# Python 2 doesn't support enums.
|
|
READY = 0
|
|
UNSCALED = 1
|
|
STEPPED = 2
|
|
|
|
def __init__(self,
|
|
init_scale=2.**16,
|
|
growth_factor=2.0,
|
|
backoff_factor=0.5,
|
|
growth_interval=2000,
|
|
enabled=True):
|
|
self._enabled = enabled
|
|
if enabled:
|
|
assert growth_factor > 1.0, "The growth factor must be > 1.0."
|
|
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
|
|
|
|
self._init_scale = init_scale
|
|
# self._scale will be lazily initialized during the first call to scale()
|
|
self._scale = None
|
|
self._growth_factor = growth_factor
|
|
self._backoff_factor = backoff_factor
|
|
self._growth_interval = growth_interval
|
|
self._init_growth_tracker = 0
|
|
# self._growth_tracker will be lazily initialized during the first call to scale()
|
|
self._growth_tracker = None
|
|
READY = self.READY
|
|
self._per_optimizer_states = defaultdict(lambda: {"stage": READY, "found_inf_per_device": {}})
|
|
|
|
def _check_scale_growth_tracker(self, funcname):
|
|
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
|
|
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
|
|
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
|
|
|
|
def _lazy_init_scale_growth_tracker(self, dev):
|
|
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
|
|
self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
|
|
self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
|
|
|
|
def scale(self, outputs):
|
|
"""
|
|
Multiplies ('scales') a tensor or list of tensors by the scale factor.
|
|
|
|
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
|
|
unmodified.
|
|
|
|
Arguments:
|
|
outputs (Tensor or iterable of Tensors): Outputs to scale.
|
|
"""
|
|
if not self._enabled:
|
|
return outputs
|
|
|
|
# Short-circuit for the common case.
|
|
if isinstance(outputs, torch.Tensor):
|
|
assert outputs.is_cuda
|
|
if self._scale is None:
|
|
self._lazy_init_scale_growth_tracker(outputs.device)
|
|
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
|
|
|
|
# Invoke the more complex machinery only if we're treating multiple outputs.
|
|
stash = [None] # trick to hold a reference that can be overwritten at any level of the recursion below.
|
|
|
|
def apply_scale(val):
|
|
if isinstance(val, torch.Tensor):
|
|
assert val.is_cuda
|
|
if self._scale is None:
|
|
self._lazy_init_scale_growth_tracker(val.device)
|
|
if stash[0] is None:
|
|
stash[0] = _MultiDeviceReplicator(self._scale)
|
|
return val * stash[0].get(val.device)
|
|
elif isinstance(val, container_abcs.Iterable):
|
|
return type(val)(apply_scale(v) for v in val)
|
|
else:
|
|
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
|
|
|
|
return apply_scale(outputs)
|
|
|
|
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
|
|
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
|
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
|
|
|
for group in optimizer.param_groups:
|
|
for param in group["params"]:
|
|
if param.grad is not None:
|
|
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
|
raise ValueError("Attempting to unscale FP16 gradients.")
|
|
else:
|
|
torch._amp_non_finite_check_and_unscale_(param.grad,
|
|
per_device_found_inf.get(param.grad.device),
|
|
per_device_inv_scale.get(param.grad.device))
|
|
|
|
return per_device_found_inf._per_device_tensors
|
|
|
|
def unscale_(self, optimizer):
|
|
"""
|
|
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
|
|
|
:meth:`unscale_` is optional, serving cases where you need to
|
|
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
|
between the backward pass(es) and :meth:`step`.
|
|
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
|
|
|
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
|
|
|
...
|
|
scaler.scale(loss).backward()
|
|
scaler.unscale_(optimizer)
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
|
scaler.step(optimizer)
|
|
scaler.update()
|
|
|
|
Arguments:
|
|
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
|
|
|
.. note::
|
|
:meth:`unscale_` does not incur a CPU-GPU sync.
|
|
|
|
.. warning::
|
|
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
|
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
|
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
|
"""
|
|
if not self._enabled:
|
|
return
|
|
|
|
self._check_scale_growth_tracker("unscale_")
|
|
|
|
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
|
|
|
if optimizer_state["stage"] == self.UNSCALED:
|
|
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
|
|
elif optimizer_state["stage"] == self.STEPPED:
|
|
raise RuntimeError("unscale_() is being called after step().")
|
|
|
|
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
|
inv_scale = self._scale.double().reciprocal().float()
|
|
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
|
|
|
|
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
|
|
optimizer_state["stage"] = self.UNSCALED
|
|
|
|
def step(self, optimizer, *args, **kwargs):
|
|
"""
|
|
:meth:`step` carries out the following two operations:
|
|
|
|
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
|
|
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
|
|
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
|
|
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
|
|
|
|
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
|
|
|
|
Returns the return value of ``optimizer.step(*args, **kwargs)``.
|
|
|
|
Arguments:
|
|
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
|
|
args: Any arguments.
|
|
kwargs: Any keyword arguments.
|
|
|
|
.. warning::
|
|
Closure use is not currently supported.
|
|
"""
|
|
if (not self._enabled):
|
|
return optimizer.step(*args, **kwargs)
|
|
|
|
if "closure" in kwargs:
|
|
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
|
|
|
|
self._check_scale_growth_tracker("step")
|
|
|
|
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
|
|
|
if optimizer_state["stage"] == self.STEPPED:
|
|
raise RuntimeError("step() has already been called since the last update().")
|
|
|
|
retval = None
|
|
|
|
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
|
|
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
|
|
# The contract with custom optimizers is that their step() should accept an additional,
|
|
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
|
|
# it can query its own state, invoke unscale_ on itself, etc
|
|
retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
|
|
optimizer_state["stage"] == self.STEPPED
|
|
return retval
|
|
|
|
if optimizer_state["stage"] == self.READY:
|
|
self.unscale_(optimizer)
|
|
|
|
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
|
|
|
|
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
|
|
retval = optimizer.step(*args, **kwargs)
|
|
|
|
optimizer_state["stage"] == self.STEPPED
|
|
|
|
return retval
|
|
|
|
def update(self, new_scale=None):
|
|
"""
|
|
Updates the scale factor.
|
|
|
|
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
|
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
|
the scale is multiplied by ``growth_factor`` to increase it.
|
|
|
|
Passing ``new_scale`` sets the scale directly.
|
|
|
|
Arguments:
|
|
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
|
|
|
|
.. warning::
|
|
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
|
been invoked for all optimizers used this iteration.
|
|
"""
|
|
if not self._enabled:
|
|
return
|
|
|
|
self._check_scale_growth_tracker("update")
|
|
|
|
if new_scale is not None:
|
|
# Accept a new user-defined scale.
|
|
if isinstance(new_scale, float):
|
|
self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
|
|
else:
|
|
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
|
|
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
|
|
assert new_scale.numel() == 1, reason
|
|
assert new_scale.requires_grad is False, reason
|
|
self._scale = new_scale
|
|
else:
|
|
# Consume shared inf/nan data collected from optimizers to update the scale.
|
|
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
|
found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
|
|
for state in self._per_optimizer_states.values()
|
|
for found_inf in state["found_inf_per_device"].values()]
|
|
|
|
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
|
|
|
found_inf_combined = found_infs[0]
|
|
if len(found_infs) > 1:
|
|
for i in range(1, len(found_infs)):
|
|
found_inf_combined += found_infs[i]
|
|
|
|
self._scale = torch._amp_update_scale(self._growth_tracker,
|
|
self._scale,
|
|
found_inf_combined,
|
|
self._growth_factor,
|
|
self._backoff_factor,
|
|
self._growth_interval)
|
|
|
|
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
|
self._per_optimizer_states = defaultdict(lambda: {"stage": self.READY, "found_inf_per_device": {}})
|
|
|
|
def _get_scale_async(self):
|
|
return self._scale
|
|
|
|
def get_scale(self):
|
|
"""
|
|
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
|
|
|
|
.. warning::
|
|
:meth:`get_scale` incurs a CPU-GPU sync.
|
|
"""
|
|
if self._enabled:
|
|
return self._init_scale if self._scale is None else self._get_scale_async().item()
|
|
else:
|
|
return 1.0
|
|
|
|
def get_growth_factor(self):
|
|
r"""
|
|
Returns a Python float containing the scale growth factor.
|
|
"""
|
|
return self._growth_factor
|
|
|
|
def set_growth_factor(self, new_factor):
|
|
r"""
|
|
Arguments:
|
|
new_scale (float): Value to use as the new scale growth factor.
|
|
"""
|
|
self._growth_factor = new_factor
|
|
|
|
def get_backoff_factor(self):
|
|
r"""
|
|
Returns a Python float containing the scale backoff factor.
|
|
"""
|
|
return self._backoff_factor
|
|
|
|
def set_backoff_factor(self, new_factor):
|
|
r"""
|
|
Arguments:
|
|
new_scale (float): Value to use as the new scale backoff factor.
|
|
"""
|
|
self._backoff_factor = new_factor
|
|
|
|
def get_growth_interval(self):
|
|
r"""
|
|
Returns a Python int containing the growth interval.
|
|
"""
|
|
return self._growth_interval
|
|
|
|
def set_growth_interval(self, new_interval):
|
|
r"""
|
|
Arguments:
|
|
new_interval (int): Value to use as the new growth interval.
|
|
"""
|
|
self._growth_interval = new_interval
|
|
|
|
def _get_growth_tracker(self):
|
|
if self._enabled:
|
|
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
|
|
else:
|
|
return 0
|
|
|
|
def is_enabled(self):
|
|
r"""
|
|
Returns a bool indicating whether this instance is enabled.
|
|
"""
|
|
return self._enabled
|
|
|
|
def state_dict(self):
|
|
r"""
|
|
Returns the state of the scaler as a :class:`dict`. It contains five entries:
|
|
|
|
* ``"scale"`` - a Python float containing the current scale
|
|
* ``"growth_factor"`` - a Python float containing the current growth factor
|
|
* ``"backoff_factor"`` - a Python float containing the current backoff factor
|
|
* ``"growth_interval"`` - a Python int containing the current growth interval
|
|
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
|
|
|
|
If this instance is not enabled, returns an empty dict.
|
|
|
|
.. note::
|
|
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
|
|
should be called after :meth:`update`.
|
|
"""
|
|
return {"scale": self.get_scale(),
|
|
"growth_factor": self._growth_factor,
|
|
"backoff_factor": self._backoff_factor,
|
|
"growth_interval": self._growth_interval,
|
|
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
|
|
|
|
def load_state_dict(self, state_dict):
|
|
r"""
|
|
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
|
|
|
|
Arguments:
|
|
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
|
|
"""
|
|
if not self._enabled:
|
|
return
|
|
|
|
if len(state_dict) == 0:
|
|
raise RuntimeError("The source state dict is empty, possibly because it was saved "
|
|
"from a disabled instance of GradScaler.")
|
|
|
|
self._init_scale = state_dict["scale"]
|
|
if self._scale is not None:
|
|
self._scale.fill_(state_dict["scale"])
|
|
self._growth_factor = state_dict["growth_factor"]
|
|
self._backoff_factor = state_dict["backoff_factor"]
|
|
self._growth_interval = state_dict["growth_interval"]
|
|
self._init_growth_tracker = state_dict["_growth_tracker"]
|
|
if self._growth_tracker is not None:
|
|
self._growth_tracker.fill_(state_dict["_growth_tracker"])
|
|
|
|
def _check_inf_per_device(self, optimizer):
|
|
self._check_scale_growth_tracker("_check_inf_per_device")
|
|
|
|
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
|
|
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
|
|
|
|
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
|
|
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
|
|
|
|
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|
|
|
|
def _found_inf_per_device(self, optimizer):
|
|
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|