# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup

logger = logging.getLogger(__name__)


def _refresh_per_optimizer_state() -> Dict[str, Any]:
    return {"stage": OptState.READY, "found_inf_per_device": {}}


def _is_supported_device(tensor: torch.Tensor) -> bool:
    return tensor.is_cuda or tensor.device.type in (
        "xla",
        "cpu",
        "hpu",
        torch._C._get_privateuse1_backend_name(),
    )


class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
    """
    Lazily serves copies of a tensor to the requested device. This class extends
    _MultiDeviceReplicator to allow "cpu" as a device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert _is_supported_device(master_tensor)
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
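
    # The replication behavior comes from the parent class: ``get(device)`` lazily
    # caches one copy of ``self.master`` per device in ``self._per_device_tensors``
    # and is what ``scale`` and ``_unscale_grads_`` below rely on. A minimal sketch
    # of that inherited behavior (an assumption about the parent implementation,
    # not a verbatim copy of it):
    #
    #     def get(self, device: torch.device) -> torch.Tensor:
    #         if device not in self._per_device_tensors:
    #             self._per_device_tensors[device] = self.master.to(
    #                 device=device, non_blocking=True, copy=True
    #             )
    #         return self._per_device_tensors[device]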


class ShardedGradScaler(GradScaler):
    """
    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
    functionality from GradScaler:

    * Supports PyTorch DDP and FSDP implementations
    * Supports CPU-offloaded tensors (as used in fully sharded data parallel [FSDP])
    * Supports the custom mixed-precision loss dtype (fp16, bf16) that FSDP returns
    * Syncs the inf/NaN status of scaled gradient tensors across ranks, for tensors
      placed on any torch.device

    Example::

        # Creates a ShardedGradScaler once at the beginning of training.
        scaler = ShardedGradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
            process group for sharding
    """

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        backoff_factor: float = 0.5,
        growth_factor: float = 2.0,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Optional[ProcessGroup] = dist.group.WORLD,
    ) -> None:
        super().__init__(
            device,
            init_scale=init_scale,
            backoff_factor=backoff_factor,
            growth_factor=growth_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        if self._enabled:
            self.process_group = process_group
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
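
    # Construction sketch (illustrative, not from the upstream docs): the scaler is
    # created once per training job, optionally with the process group used by the
    # sharded model, e.g.
    #
    #     scaler = ShardedGradScaler()                       # defaults to dist.group.WORLD
    #     scaler = ShardedGradScaler(process_group=fsdp_pg)  # ``fsdp_pg`` is a hypothetical
    #                                                        # ProcessGroup created by the caller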

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        if not self._enabled:
            return outputs

        if isinstance(outputs, torch.Tensor):
            assert _is_supported_device(outputs)
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            scaled_output = outputs * self._scale.to(
                device=outputs.device, non_blocking=True
            )
            # Here we ensure the return dtype is the same as the outputs dtype.
            # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
            # format (fp16, bf16) and so the scaled loss should be of the same dtype.
            return scaled_output.type(outputs.dtype)

        stash: List[_GeneralMultiDeviceReplicator] = []

        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
            if isinstance(val, torch.Tensor):
                assert _is_supported_device(val)
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                scaled_val = val * stash[0].get(val.device)
                # Here we ensure the return dtype is the same as the outputs dtype.
                # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                return scaled_val.type(val.dtype)
            if isinstance(val, abc.Iterable):
                iterator = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterator)
                return iterator
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
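
    # Illustrative sketch (not from the upstream docs): because ``scale`` accepts
    # iterables of tensors, several losses can be scaled in one call; ``loss_a`` and
    # ``loss_b`` are placeholder tensors produced by the caller:
    #
    #     scaled_a, scaled_b = scaler.scale((loss_a, loss_b))  # tuple in, tuple out
    #     scaled_a.backward(retain_graph=True)  # retain_graph only if the losses share a graph
    #     scaled_b.backward()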

    def _foreach_non_finite_check_and_unscale_cpu_(
        self,
        grads: Sequence[torch.Tensor],
        found_inf: torch.Tensor,
        inv_scale: torch.Tensor,
    ) -> None:
        if len(grads) == 0:
            return
        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."

        for grad in grads:
            if grad.device.type != "cpu":
                logger.error(
                    "tensor device is %s but was expected to be ``cpu``",
                    grad.device,
                )
                raise ValueError(
                    "Gradients were found on a non-CPU device when"
                    " expected to be on CPU."
                )
            if (
                torch.isinf(grad).any().item() is True
                or torch.isnan(grad).any().item() is True
            ):
                found_inf.data = torch.tensor([1.0])
                break
            else:
                grad.data *= inv_scale.item()

    def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool = True,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be thousands of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # mypy struggles with defaultdict type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            # coalesce is not supported in torch.float16
                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
                            param.grad = param_grad_fp32.type(torch.float16)
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    if grads[0].device.type == "cpu":
                        self._foreach_non_finite_check_and_unscale_cpu_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
                    else:
                        torch._amp_foreach_non_finite_check_and_unscale_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
        # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
        # ranks may have no (non-zero sized) parameter shards, necessitating the
        # initialization of `per_device_found_inf._per_device_tensors` here
        if not per_device_found_inf._per_device_tensors:
            assert self._scale is not None
            per_device_found_inf.get(self._scale.device)
        return per_device_found_inf._per_device_tensors
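
    # For reference, the mapping returned above typically looks like
    # ``{device("cuda:0"): tensor([0.], device="cuda:0")}``, i.e. one 1-element flag
    # tensor per device that held gradients, set to 1.0 when an inf/NaN was found
    # (an illustrative shape, not an exhaustive specification).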

    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device
        )

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, True
        )
        optimizer_state["stage"] = OptState.UNSCALED

        # Synchronize the detected inf across the ranks
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        works = []
        found_inf_on_cpus = []
        found_inf_on_devices = []

        for found_inf in optimizer_state["found_inf_per_device"].values():
            if self._device != "cpu" and found_inf.device.type == "cpu":
                found_inf_on_cpus.append(found_inf)
                found_inf_on_device = found_inf.to(self._device)
                found_inf_on_devices.append(found_inf_on_device)
                works.append(
                    dist.all_reduce(
                        found_inf_on_device, async_op=True, group=self.process_group
                    )
                )
            else:
                works.append(
                    dist.all_reduce(found_inf, async_op=True, group=self.process_group)
                )
        for work in works:
            work.wait()
        if found_inf_on_cpus:
            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)
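
    # Usage sketch (the standard GradScaler pattern, not specific to this file):
    # call ``unscale_`` explicitly only when raw gradients are needed between
    # ``backward()`` and ``step()``, e.g. for gradient clipping; ``model``,
    # ``optimizer`` and ``loss`` are placeholders supplied by the caller:
    #
    #     scaler.scale(loss).backward()
    #     scaler.unscale_(optimizer)
    #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    #     scaler.step(optimizer)
    #     scaler.update()
    #
    # (With FSDP, the wrapped module's own ``clip_grad_norm_`` is typically used
    # instead so that clipping accounts for sharded gradients.)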

    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
        """
        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
        """
        assert self._scale is not None and self._growth_tracker is not None

        if found_inf.item() >= 1.0:
            self._scale *= self._backoff_factor
            self._growth_tracker.fill_(0)
        else:
            successful = self._growth_tracker + 1
            if successful == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker.fill_(0)
            else:
                self._growth_tracker = successful
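
    # Worked example of the rule above, assuming the documented defaults
    # (init_scale=2.**16, backoff_factor=0.5, growth_factor=2.0, growth_interval=2000):
    # an iteration with an inf/NaN gradient halves the scale (65536 -> 32768) and
    # zeroes the growth tracker; 2000 consecutive clean iterations then double it
    # (32768 -> 65536) and zero the tracker again.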

    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """

        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore[var-annotated]

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
                    torch.FloatTensor with requires_grad=False."
                assert new_scale.device.type == self._device, reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            if _scale.device.type == "cpu":
                self._amp_update_scale_cpu_(found_inf_combined)
            else:
                torch._amp_update_scale_(
                    self._scale,  # type: ignore[arg-type]
                    self._growth_tracker,  # type: ignore[arg-type]
                    found_inf_combined,
                    self._growth_factor,  # type: ignore[arg-type]
                    self._backoff_factor,  # type: ignore[arg-type]
                    self._growth_interval,  # type: ignore[arg-type]
                )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)