Files
pytorch/torch/distributed/fsdp/_runtime_utils.py
Nikita Shulga 634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e. when a default argument is set to None but its type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00

1563 lines
64 KiB
Python

import functools
import logging
from enum import auto, Enum
from itertools import chain
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed import get_backend, get_world_size
from torch.distributed._tensor import DeviceMesh
from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
_assert_in_training_states,
_FSDPState,
_get_module_fsdp_state,
_get_sharding_strategy,
_is_composable,
TrainingState,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp._utils import _no_dispatch_record_stream
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.fsdp.flat_param import (
_HandlesKey,
FlatParameter,
FlatParamHandle,
HandleShardingStrategy,
HandleTrainingState,
RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
)
from torch.distributed.utils import (
_apply_to_tensors,
_cast_forward_inputs,
_p_assert,
_to_kwargs,
)
from torch.utils._pytree import tree_flatten
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Names of the FSDP state attributes that must hold exactly one common value
# across all FSDP states in a module tree; this is validated while sharing
# state during lazy initialization.
# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
    "_use_full_prec_in_eval",
)
class _PrefetchMode(Enum):
    """Tags a handle prefetch request with the pass it is issued from."""

    # Passed to `_prefetch_handles()` by the pre-backward hook and by the
    # post-backward reshard.
    BACKWARD = auto()
    # Passed to `_prefetch_handles()` by the pre-forward unshard.
    FORWARD = auto()
def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Collects every root FSDP state under ``module`` along with its module.

    Returns a tuple containing:
    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
    ``module`` without any duplicates and following the ``module.modules()``
    traversal order (which is assumed to be depth-first).
    2. A corresponding list of the root modules owning the states in the first
    list.

    Unlike :func:`_get_fsdp_states_with_modules`, each candidate is passed
    through :func:`_is_fsdp_root`, which forces a lazy initialization so that
    the root can be determined even if lazy initialization has not run yet.
    """
    root_states: List[_FSDPState] = []
    root_modules: List[nn.Module] = []
    seen_states: Set[_FSDPState] = set()
    # NOTE: This function assumes that `module.modules()` proceeds top-down.
    for candidate in module.modules():
        candidate_state = _get_module_fsdp_state(candidate)
        # Skip modules without FSDP state and states that were already seen
        if candidate_state is None or candidate_state in seen_states:
            continue
        if not _is_fsdp_root(candidate_state, candidate):
            continue
        seen_states.add(candidate_state)
        root_states.append(candidate_state)
        root_modules.append(candidate)
    return root_states, root_modules
def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """
    Returns only the list of root states; see
    :func:`_get_fsdp_root_states_with_modules`.
    """
    states_and_modules = _get_fsdp_root_states_with_modules(module)
    return states_and_modules[0]
def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Tells whether ``state`` is the state of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same. For
    the non-wrapper code path, ``state`` should be ``module`` 's state.
    """
    # Lazy initialization is what decides (and records) which state is the root
    _lazy_init(state, module)
    is_root = state._is_root
    assert is_root is not None  # mypy
    return is_root
@no_type_check
def _validate_and_get_hybrid_shard_state(
    root_module: nn.Module,
) -> Optional[default_hooks.DefaultState]:
    """
    Precondition: ``root_module`` is a ``FullyShardedDataParallel`` instance.

    This checks that all instances using a hybrid sharding strategy have the
    same intra- and inter-node process groups.

    Args:
        root_module (nn.Module): Module whose tree is searched for FSDP states.

    Returns:
        Optional[DefaultState]: One of the instances' inter-node state (does
        not matter which since they will share the same one), or ``None`` if
        no instance uses a hybrid sharding strategy.

    Raises:
        AssertionError: If only one of the intra-/inter-node process group
            collections is populated.
        ValueError: If the hybrid-sharded instances do not all share one
            intra-node process group and one inter-node process group.
    """
    intra_node_pgs = set()
    inter_node_pgs = set()
    inter_node_states = set()
    for fsdp_module in traversal_utils._get_fsdp_states(root_module):
        # TODO: Change this to handle's sharding strategy if we deprecate
        # `ShardingStrategy` internally.
        # https://github.com/pytorch/pytorch/issues/90857
        if fsdp_module.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            intra_node_pgs.add(fsdp_module.process_group)
            inter_node_pgs.add(fsdp_module._inter_node_pg)
            inter_node_states.add(fsdp_module._inter_node_state)
    if len(intra_node_pgs) == 0 and len(inter_node_pgs) == 0:
        # No instances use a hybrid sharding strategy
        return None
    # NOTE(review): both sets are populated together in the loop above, so
    # they are either both empty or both non-empty and the two
    # `AssertionError` branches below cannot trigger as written -- confirm
    # whether a check for a `None` process group was intended instead.
    error_prefix = "At least one instance uses a hybrid sharding strategy but has no "
    if len(intra_node_pgs) > 0 and len(inter_node_pgs) == 0:
        raise AssertionError(error_prefix + "inter-node process group set")
    if len(intra_node_pgs) == 0 and len(inter_node_pgs) > 0:
        raise AssertionError(error_prefix + "intra-node process group set")
    error_prefix = "Some instances use a hybrid sharding strategy, but "
    if len(intra_node_pgs) != 1:
        raise ValueError(error_prefix + "intra-node process groups do not match")
    if len(inter_node_pgs) != 1:
        raise ValueError(error_prefix + "inter-node process groups do not match")
    # Any element works since the states are expected to be shared
    return next(iter(inter_node_states))
@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.

    Args:
        state (_FSDPState): State to lazily initialize.
        root_module (nn.Module): Module owning ``state``.

    Returns:
        _FSDPState: ``state`` itself, both when initialization runs and when
        it was already done.

    Raises:
        RuntimeError: If the state's device handle reports that the device is
            unavailable.
    """
    if state._is_root is not None:
        # no-op: already lazily initialized. Return the state (rather than an
        # implicit `None`) so that both paths honor the declared return type.
        return state
    if not state._device_handle.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Verifies, for *lazy initialization*, that each ``FlatParameter`` found in
    ``module`` 's tree lives where it is expected: on CPU when its handle
    offloads parameters and on ``state.compute_device`` otherwise.

    Raises:
        RuntimeError: If any ``FlatParameter`` is on an unexpected device.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        param_device = handle.flat_param.device
        if handle._offload_params:
            # Offloaded parameters must have stayed on CPU
            if param_device != cpu_device:
                raise RuntimeError(
                    "An FSDP-managed module with parameter CPU offloading enabled "
                    f"has parameters on {param_device}. Make sure to "
                    f"not move the module from CPU when offloading parameters."
                )
        elif param_device != state.compute_device:
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{param_device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
def _init_device_mesh(
    root_state: _FSDPState,
) -> Optional[DeviceMesh]:
    """
    Builds a 1D ``DeviceMesh`` over the ranks of ``root_state`` 's process
    group, or returns ``None`` when the process group is not the default
    group, the backend is ``"fake"``, or no compute device is set.
    """
    # We are testing 1D DeviceMesh where dist.get_world_size(pg) == dist.get_world_size() for now.
    # TODO: Address cases when dist.get_world_size(pg) != dist.get_world_size(). This would capture
    # what 1D DeviceMesh currently would not work for:
    # 1) HSDP Hybrid Sharding, 2) 2D FSDP + TP, 3) dist.new_group() cannot be expressed in 1D DeviceMesh.
    if (
        root_state.process_group != dist.distributed_c10d._get_default_group()
        or get_backend() == "fake"
        or not root_state.compute_device
    ):
        return None
    return DeviceMesh(
        root_state.compute_device.type,
        torch.arange(get_world_size(root_state.process_group)),
        _validate_mesh=False,
    )
@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Propagates the root's shared data structures to every FSDP state found in
    ``root_module`` 's module tree and initializes the handle attributes in
    the same pass, so that the states only need to be looped over once.

    Raises:
        RuntimeError: If a ``FlatParameter`` itself carries
            ``_in_backward_optimizers``.
        ValueError: If an attribute in ``HOMOGENEOUS_ATTR_NAMES`` does not
            have exactly one value across all states.
    """
    for root_handle in root_state._handles:
        root_handle.init_flat_param_attributes()
    inter_node_state = _validate_and_get_hybrid_shard_state(root_module)
    observed_values: Dict[str, Set[Any]] = {
        name: set() for name in HOMOGENEOUS_ATTR_NAMES
    }
    root_state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
    root_state._device_mesh = _init_device_mesh(root_state)
    # Update _has_optim_in_backward for each handle.
    for handle in root_state._all_handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_in_backward_optimizers"):
            raise RuntimeError(
                "FSDP optimizer in backward only supported with use_orig_params=True!"
            )
        orig_params = flat_param._params
        handle._has_optim_in_backward = orig_params is not None and any(
            hasattr(orig_param, "_in_backward_optimizers")
            for orig_param in orig_params
        )
    hybrid_strategies = (
        HandleShardingStrategy.HYBRID_SHARD,
        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    )
    # Attributes that each non-root state takes over from the root by reference
    shared_attr_names = (
        "_unshard_stream",
        "_post_backward_stream",
        "_pre_unshard_stream",
        "_default_stream",
        "_exec_order_data",
        "_free_event_queue",
        "_handles_prefetched",
        "_needs_pre_backward_unshard",
        "_device_mesh",
    )
    for fsdp_state in root_state._all_fsdp_states:
        for name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, name),
                f"FSDP state missing attribute {name}",
            )
            observed_values[name].add(getattr(fsdp_state, name))
        if fsdp_state is root_state:
            continue
        if _get_sharding_strategy(fsdp_state._handles) in hybrid_strategies:
            # Share the all-reduce state across FSDP units. This is not strictly necessary
            # as each one already uses the same process group, but can slightly save memory
            # since other FSDP units allreduce state can be garbage collected.
            assert inter_node_state is not None, (
                "`_validate_and_get_hybrid_shard_state()` should have returned "
                "a valid inter-node state if there exists an FSDP instance "
                "using a hybrid sharding strategy"
            )
            fsdp_state._inter_node_state = inter_node_state
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        _p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
        for shared_attr_name in shared_attr_names:
            setattr(fsdp_state, shared_attr_name, getattr(root_state, shared_attr_name))
        for handle in fsdp_state._handles:
            handle.init_flat_param_attributes()
    for name, values in observed_values.items():
        if len(values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {name} but got {values}"
            )
@no_type_check
def _init_streams(
    state: _FSDPState,
) -> None:
    """
    Initializes CUDA streams for overlapping communication, computation, and
    data transfers. The streams should be shared across FSDP instances.

    The streams are stored on ``state`` in place; nothing is returned.

    Args:
        state (_FSDPState): Root FSDP state whose device handle is available.
    """
    assert state._is_root
    assert state._device_handle.is_available()
    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves.
    state._unshard_stream = state._device_handle.Stream()
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation.
    state._post_backward_stream = state._device_handle.Stream()
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast).
    state._pre_unshard_stream = state._device_handle.Stream()
    # Default stream for computation
    state._default_stream = state._device_handle.current_stream()
@no_type_check
def _unshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
) -> None:
    """
    Unshards every handle in ``handles``. Handles that are in
    :meth:`summon_full_params` and use mixed precision are forced to full
    precision.

    Postcondition: Each handle's ``FlatParameter`` 's data is the padded
    unsharded flat parameter on the compute device.
    """
    if not handles:
        return
    with state._device_handle.stream(pre_unshard_stream):
        # Run the pre-unshard for every handle before inspecting the results
        pre_unshard_results = [handle.pre_unshard() for handle in handles]
    if any(pre_unshard_results):
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        free_event = state._free_event_queue.dequeue_if_needed()
        if free_event:
            with torch.profiler.record_function(
                "FullyShardedDataParallel.rate_limiter"
            ):
                free_event.synchronize()
    with state._device_handle.stream(unshard_stream):
        for handle in handles:
            handle.unshard()
            handle.post_unshard()
@no_type_check
def _reshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    free_unsharded_flat_params: List[bool],
):
    """
    Reshards every handle in ``handles``. ``free_unsharded_flat_params`` must
    be as long as ``handles``; its i-th element says whether the i-th handle
    should free its padded unsharded flat parameter.
    """
    if not handles:
        return
    num_handles = len(handles)
    num_flags = len(free_unsharded_flat_params)
    _p_assert(
        num_handles == num_flags,
        "Expects both lists to have equal length but got "
        f"{num_handles} and {num_flags}",
    )
    for handle, should_free in zip(handles, free_unsharded_flat_params):
        handle.reshard(should_free)
        if state.limit_all_gathers and should_free:
            # Record the free so that later all-gathers can be rate limited
            event = state._device_handle.Event()
            event.record()
            state._free_event_queue.enqueue(event)
        handle.post_reshard()
    # Since we prefetch entire handles keys at a time, conservatively mark
    # the entire key as no longer prefetched once we free at least one
    if any(free_unsharded_flat_params):
        state._handles_prefetched.pop(tuple(handles), None)
def _unshard_grads(
    handles: List[FlatParamHandle],
) -> None:
    """Calls ``unshard_grad()`` on each handle, in list order."""
    for flat_param_handle in handles:
        flat_param_handle.unshard_grad()
def _reshard_grads(
    handles: List[FlatParamHandle],
) -> None:
    """Calls ``reshard_grad()`` on each handle, in list order."""
    for flat_param_handle in handles:
        flat_param_handle.reshard_grad()
@no_type_check
def _pre_forward(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    unshard_fn: Optional[Callable],
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    also converts forward ``args`` and ``kwargs`` to the given precision.

    Args:
        state (_FSDPState): FSDP state that owns ``handles``.
        handles (List[FlatParamHandle]): Handles giving the parameters used in
            the current forward.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.

    Returns:
        Tuple[Tuple[Any, ...], Dict[str, Any]]: The forward ``args`` and
        ``kwargs``, cast to the mixed precision parameter dtype when input
        casting applies and unchanged otherwise.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
        state.training_state = TrainingState.FORWARD_BACKWARD
        state._exec_order_data.record_pre_forward(handles, module.training)
        for handle in handles:
            handle._training_state = HandleTrainingState.FORWARD
        if unshard_fn is not None:
            unshard_fn()
        # Register post-backward hooks to reshard the parameters and reduce-scatter
        # their gradients. They must be re-registered every forward pass in case
        # the `grad_fn` is mutated.
        _register_post_backward_hooks(state, handles)
        # We have to reallocate the _cpu_grad if optimizer overlap
        # set the grad to None in the backward pass.
        for handle in handles:
            if handle._offload_params and handle.flat_param._cpu_grad is None:
                handle.flat_param._cpu_grad = torch.zeros_like(
                    handle.flat_param._local_shard, device=torch.device("cpu")
                ).pin_memory()
        # Only cast when the state has handles and none of them forces full
        # precision
        should_cast_forward_inputs = len(state._handles) > 0 and all(
            not handle._force_full_precision for handle in state._handles
        )
        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
            # Recursively convert args and kwargs to specified precision.
            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
        _register_post_backward_reshard_only_hooks(state, handles, args, kwargs)
        return args, kwargs
@no_type_check
def _pre_forward_unshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """
    Unshards the parameters of ``handles`` in the pre-forward and issues the
    forward prefetch. Does nothing for an empty ``handles``.
    """
    if not handles:
        return
    handles_key = tuple(handles)
    # If the handles have been prefetched, then there is no need to call
    # `_unshard()` again
    already_prefetched = state._handles_prefetched.get(handles_key, False)
    if not already_prefetched:
        _unshard(state, handles, state._unshard_stream, state._pre_unshard_stream)
    state._needs_pre_forward_unshard[handles_key] = False
    # Make the current stream wait for the unshard stream
    current_stream = state._device_handle.current_stream()
    current_stream.wait_stream(state._unshard_stream)
    _prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)
@no_type_check
def _post_forward(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    reshard_fn: Optional[Callable],
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        state (_FSDPState): FSDP state that owns ``handles``.
        handles (List[FlatParamHandle]): Handles giving the parameters used in
            the current forward.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Returns:
        Any: The value returned by the pre-backward hook registration for
        ``output``.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
    parameter.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
        state._exec_order_data.record_post_forward(handles)
        if reshard_fn is not None:
            reshard_fn()
        # Register pre-backward hooks to unshard the flat parameters for the
        # gradient computation (if needed)
        output = _register_pre_backward_hooks(state, module, output, handles)
        # The forward for these handles is over, so go back to idle
        state.training_state = TrainingState.IDLE
        for handle in handles:
            handle._training_state = HandleTrainingState.IDLE
        return output
@no_type_check
def _post_forward_reshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """
    Reshards the parameters of ``handles`` in the post-forward. Does nothing
    for an empty ``handles``.
    """
    if not handles:
        return
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
    # with the intention that they are immediately used for backward
    # computation (though this may not be true)
    is_non_root = not state._is_root
    should_free_flags = [
        is_non_root
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
        for handle in handles
    ]
    _reshard(state, handles, should_free_flags)
@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> Tuple[Any, Any]:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns directly.

    Args:
        state (_FSDPState): FSDP state of ``module``.
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then this method does not do anything
            beyond optionally casting the inputs on the composable path.
        args: Module forward ``args``.
        kwargs: Module forward ``kwargs``.

    Returns:
        Tuple[Any, Any]: The forward ``args`` and ``kwargs`` to use. For the
        root, these have been passed through ``_to_kwargs`` with
        ``state.compute_device`` and then through
        :func:`_root_cast_forward_input`.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
        _lazy_init(state, module)
        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
        if not state._is_root:
            # Always cast forward inputs in the root of this local FSDP unit for mixed
            # precision, as this is where mixed precision could be configed.
            # This is more useful for auto wrapping that is recommended in composable path.
            # For manual wrapping, cast forward inputs on each local FSDP unit root will
            # increase some overhead, so not turned on for model wrapper path right now where
            # manual wrapping is more broadly used.
            if _is_composable(state):
                return _root_cast_forward_input(state, module, args, kwargs)
            return args, kwargs
        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
        # are in full precision and if we should cast them back to lower precision, which happens when
        # exiting eval() mode and full precision in eval was configured.
        should_cast_buffers_to_full_prec = any(
            handle._force_full_precision for handle in state._handles
        )
        if should_cast_buffers_to_full_prec:
            _cast_buffers_to_dtype_and_device(
                buffers=dict(module.named_buffers()).values(),
                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
                device=state.compute_device,
            )
            # This flag is only set when we cast buffers to full precision, to avoid the
            # CPU overhead that can stem from retrieving all buffers and their types in the
            # following else branch.
            state._needs_buffer_dtype_restore_check = True
        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
            # Check if buffers are in full precision and we need to cast them
            # back down.
            (
                buffers,
                buffer_dtypes_for_computation,
            ) = _get_buffers_and_dtypes_for_computation(state, module)
            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
                if any(
                    buffer.dtype != buffer_dtype_for_computation
                    for buffer, buffer_dtype_for_computation in zip(
                        buffers, buffer_dtypes_for_computation
                    )
                ):
                    # Assume we have to cast everything if there is one mismatch
                    _cast_buffers_to_dtype_and_device(
                        buffers, buffer_dtypes_for_computation, state.compute_device
                    )
            # We don't have to check this again until we cast buffers to full precision again.
            state._needs_buffer_dtype_restore_check = False
        if state.forward_prefetch:
            handles_keys = []
            for fsdp_state in state._all_fsdp_states:
                # TODO: Forward prefetch assumes singleton handles key. For the
                # composable path, `_handles` may have more than one handle,
                # whereas for the wrapper path, it has at most one handle.
                handles_keys.extend((handle,) for handle in fsdp_state._handles)
            for handles_key in handles_keys:
                state._needs_pre_forward_unshard[handles_key] = True
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
            state._pre_unshard_stream,
        )
        _reset_flat_param_grad_info_if_needed(state._all_handles)
        # Prepares the forward inputs by moving them to ``compute_device``
        # TODO: Do not use the side stream for tensor copies for now; investigate
        # the perf with/without it.
        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
            args_tuple, kwargs_tuple = _to_kwargs(
                args, kwargs, state.compute_device, False
            )
        # Only the first entry of each result is used
        args = args_tuple[0]
        kwargs = kwargs_tuple[0]
        return _root_cast_forward_input(state, module, args, kwargs)
@no_type_check
def _root_cast_forward_input(
    state: _FSDPState, module: torch.nn.Module, args, kwargs
) -> Tuple[Any, Any]:
    """
    Casts the forward ``args`` and ``kwargs`` to the mixed precision parameter
    dtype when root input casting applies, and returns them unchanged
    otherwise.

    Casting is skipped when ``module`` is not training while full precision in
    eval is configured, when any handle forces full precision, or when
    ``cast_root_forward_inputs`` is disabled.
    """
    if not module.training and state._use_full_prec_in_eval:
        return args, kwargs
    if any(handle._force_full_precision for handle in state._handles):
        return args, kwargs
    if not state.mixed_precision.cast_root_forward_inputs:
        return args, kwargs
    target_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
    args, kwargs = _cast_forward_inputs(target_dtype, *args, **kwargs)
    return args, kwargs
@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    _handles: List[FlatParamHandle],
    *unused: Any,
) -> Any:
    """
    Prepares ``_handles`` 's ``FlatParameter`` s for gradient computation.

    The hook body runs at most once per handles key until
    ``state._ran_pre_backward_hook`` is cleared for that key elsewhere. Every
    path returns ``None``.

    Args:
        state (_FSDPState): FSDP state that owns ``_handles``.
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
        _handles (List[FlatParamHandle]): Handles whose parameters took part
            in the module forward; may be empty.
        unused (Any): Extra positional arguments, which are ignored.
    """
    _handles_key = tuple(_handles)  # avoid shadowing `handles_key`
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if _handles_key and state._ran_pre_backward_hook.get(_handles_key, False):
        return
    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _reset_flat_param_grad_info_if_needed(state._all_handles)
        elif _handles_key:
            # The composable path additionally permits already being in
            # `FORWARD_BACKWARD` here
            allowed_states = [TrainingState.IDLE]
            if _is_composable(state):
                allowed_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, allowed_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not _handles_key:
            return
        for handle in _handles:
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        if state._needs_pre_backward_unshard[_handles_key]:
            # If the handles have been prefetched, then there is no need to
            # call `_unshard()` again
            if not state._handles_prefetched.get(_handles_key, False):
                _unshard(
                    state,
                    _handles,
                    state._unshard_stream,
                    state._pre_unshard_stream,
                )
            # Make the current stream wait for the unshard stream
            state._device_handle.current_stream().wait_stream(state._unshard_stream)
        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        state._needs_pre_backward_unshard[_handles_key] = False
        _prefetch_handles(state, _handles_key, _PrefetchMode.BACKWARD)
        for handle in _handles:
            handle.prepare_gradient_for_backward()
        # Mark the key so that a repeated hook call returns early above
        state._ran_pre_backward_hook[_handles_key] = True
@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
):
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
    unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle whose ``FlatParameter`` gradient is
            processed.
        unused (Any): Extra positional arguments, which are ignored.

    Raises:
        RuntimeError: If the ``FlatParameter`` 's gradient itself requires
            gradient.
    """
    # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
    # Below logging of module names this post-bwd hook fires for can help debug certain
    # cases where hooks don't fire, such as under certain activation checkpoint configs.
    if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
        param_to_fqn = state._exec_order_data.param_to_fqn
        handle_params = handle.flat_param._params  # only populated for use_orig_params
        # Flatten the per-parameter FQN lists into one list
        param_fqns = [
            param
            for param_list in [param_to_fqn[p] for p in handle_params]
            for param in param_list
        ]
        log.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # For multiple applications of reentrant AC across submodules sharing
        # the same `FlatParameter`, the post-backward hook may run multiple
        # times in one backward, in which case we permit the state to already
        # be in `BACKWARD_POST`.
        _p_assert(
            handle._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST
        # Nothing to reduce if no gradient was produced
        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")
        _post_backward_reshard(state, handle)
        if not state._sync_gradients:
            # Without gradient synchronization, skip the reduction entirely
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return
        # Wait for all ops in the current stream (e.g. gradient
        # computation) to finish before reduce-scattering the gradient
        state._post_backward_stream.wait_stream(state._device_handle.current_stream())
        with state._device_handle.stream(state._post_backward_stream):
            # Keep a reference to the autograd-computed gradient for the
            # `record_stream` call further below
            autograd_computed_grad = flat_param.grad.data
            if state._exec_order_data.is_first_iter:  # only check once
                _check_comm_hook(
                    state._communication_hook, state._communication_hook_state
                )
            if (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
                # If we are forcing full precision but communicating grads
                # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
                and not handle._force_full_precision
            ):
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            if handle.uses_sharded_strategy:
                # We clear `.grad` to permit multiple backwards. This avoids a
                # race where the second backward pass computation precedes
                # ahead of the first backward pass reduction, which is possible
                # since the reduction is issued in a separate stream and is
                # async and would result in reducing the wrong gradient.
                unsharded_grad = flat_param.grad.data
                flat_param.grad = None
                chunks = list(unsharded_grad.chunk(state.world_size))
                # Pad so that the gradient length is `world_size` times the
                # first chunk's length
                numel_to_pad = (
                    state.world_size * chunks[0].numel() - unsharded_grad.numel()
                )
                padded_unsharded_grad = (
                    F.pad(unsharded_grad, [0, numel_to_pad])
                    if numel_to_pad > 0
                    else unsharded_grad
                )
                new_sharded_grad = torch.empty_like(chunks[0])  # padded
                state._communication_hook(
                    state._communication_hook_state,
                    padded_unsharded_grad,
                    new_sharded_grad,
                )
                if handle._sharding_strategy in (
                    HandleShardingStrategy.HYBRID_SHARD,
                    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
                ):
                    # Hybrid strategies additionally all-reduce the sharded
                    # gradient using the inter-node state
                    default_hooks.allreduce_hook(
                        state=state._inter_node_state,
                        grad=new_sharded_grad,
                    )
                _cast_grad_to_param_dtype(state, new_sharded_grad, flat_param)
                # Save the sharded gradient in `_saved_grad_shard` to support
                # gradient accumulation -- for multiple backwards, the gradient
                # reductions may happen in arbitrary order
                accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
                if accumulate_grad:
                    _check_grad_to_accumulate(
                        new_sharded_grad, flat_param._saved_grad_shard
                    )
                    flat_param._saved_grad_shard += new_sharded_grad
                else:
                    flat_param._saved_grad_shard = new_sharded_grad
                grad_to_offload = flat_param._saved_grad_shard
            else:
                state._communication_hook(
                    state._communication_hook_state, flat_param.grad
                )
                # For `NO_SHARD`, we can keep the low precision gradients by
                # simply omitting the cast altogether
                if not handle._keep_low_precision_grads:
                    _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
                grad_to_offload = flat_param.grad.data
            if handle._offload_params:
                # Offload the gradient to CPU to ensure parameters and
                # gradients are on the same device as required by the optimizer
                # TODO: Investigate why `NO_SHARD` breaks correctness when
                # using `non_blocking=True` here.
                # TODO (rohan-varma): When CPU offload and optimizer overlap,
                # non_blocking=True won't work since the copy may have not finished
                # before the optimizer step executes on CPU. If we want to use
                # non-blocking=True here, we'll have to synchronize before using
                # result on CPU.
                non_blocking = (
                    handle.uses_sharded_strategy and not handle._has_optim_in_backward
                )
                flat_param._cpu_grad.copy_(  # type: ignore[attr-defined]
                    grad_to_offload.detach(), non_blocking=non_blocking
                )  # synchronized in the post-backward callback
                # Since the gradient being offloaded may have been produced in
                # the computation stream and is being consumed here in the
                # post-backward stream, inform the caching allocator
                _no_dispatch_record_stream(
                    grad_to_offload.data,
                    state._post_backward_stream,
                )
            # Since the unsharded gradient is produced in the computation
            # stream and consumed in the post-backward stream, inform the
            # caching allocator (before it goes out of scope)
            _no_dispatch_record_stream(
                autograd_computed_grad, state._post_backward_stream
            )
            if handle._use_orig_params:
                # Since the handle's `FlatParameter` completed its gradient
                # computation, we should reset the gradient noneness mask
                handle._reset_is_grad_none()
                # Delay using sharded gradient views until after the
                # reduce-scatter instead of immediately after resharding
                handle._use_sharded_grad_views()
                if handle._has_optim_in_backward:
                    handle.prepare_gradient_for_optim()
                    for orig_param in handle.flat_param._params:
                        # checking grad for None also filters out params
                        # that don't belong to this rank
                        if orig_param.grad is not None and hasattr(
                            orig_param, "_in_backward_optimizers"
                        ):
                            # TODO (rohan-varma): For CPU offload, this unfortunately
                            # operates on CPU, because the parameters and gradients
                            # have already been offloaded. We should run this on
                            # GPU after refactoring.
                            for optim in orig_param._in_backward_optimizers:
                                optim.step()
                            # NOTE(review): this call sits after the loop over
                            # the optimizers, so it only reaches the last one
                            # iterated -- confirm that this is intended.
                            optim.zero_grad(set_to_none=True)
                    handle._reset_flat_param_grad_info_if_needed()
                    if handle._offload_params:
                        handle.flat_param._cpu_grad = None
def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Reshards ``handle`` in the post-backward and then issues the backward
    prefetch keyed on that single handle. Any extra positional arguments are
    accepted and ignored.
    """
    should_free = _should_free_in_backward(state, handle)
    _reshard(state, [handle], [should_free])
    # TODO: Post-backward prefetching does not support the multiple handles
    # per module case since the post-backward hook runs per handle, not per
    # group of handles.
    _prefetch_handles(state, (handle,), _PrefetchMode.BACKWARD)
@no_type_check
def _should_free_in_backward(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> bool:
    """
    Decides whether FSDP should free the unsharded flat parameter of
    ``handle`` in the post-backward.

    A handle that does not use a sharded strategy never frees. Otherwise, FSDP
    frees when it is synchronizing gradients or when the handle's sharding
    strategy is one of ``RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES``.
    """
    if handle.uses_sharded_strategy:
        # If not syncing gradients, then we do not free for strategies that do
        # not reshard after forward as a *heuristic* to tradeoff higher memory
        # for higher throughput.
        return (
            state._sync_gradients
            or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
        )
    return False
@no_type_check
def _cast_grad_to_param_dtype(
    state: _FSDPState,
    sharded_grad: torch.Tensor,
    param: FlatParameter,
):
    """
    Casts ``sharded_grad`` in place to the dtype of ``param`` so that the
    optimizer step runs with the full parameter dtype. A real cast happens if

    1. parameters were in reduced precision during the forward, since then
       gradients would be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in
       reduced precision for communication.

    When a low precision communication hook is registered, the hook performs
    this cast instead, so this function leaves the gradient untouched.
    """
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    if _low_precision_hook_enabled(state) or sharded_grad.dtype == param.dtype:
        # Either the hook owns the cast or the dtypes already agree
        return
    # Keep a reference to the reduced precision data before swapping it out
    low_prec_grad_data = sharded_grad.data
    sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
    # Since for `NO_SHARD`, the gradient is produced in the computation
    # stream and consumed here in the post-backward stream, inform the
    # caching allocator; for the sharded strategies, the gradient is
    # produced in the post-backward stream, so this `record_stream()`
    # should be a no-op
    _no_dispatch_record_stream(
        low_prec_grad_data, state._device_handle.current_stream()
    )
def _check_comm_hook(
comm_hook: Any,
comm_hook_state: Any,
) -> None:
_p_assert(comm_hook is not None, "Communication hook should not be `None`")
_p_assert(
comm_hook_state is not None, "Communication hook state should not be `None`"
)
def _check_grad_to_accumulate(
new_sharded_grad: torch.Tensor,
accumulated_grad: torch.Tensor,
) -> None:
_p_assert(
accumulated_grad.shape == new_sharded_grad.shape,
"Shape mismatch when accumulating gradients: "
f"existing gradient shape={accumulated_grad.shape} "
f"new gradient shape={new_sharded_grad.shape}",
)
_p_assert(
accumulated_grad.device == new_sharded_grad.device,
"Device mismatch when accumulating gradients: "
f"existing gradient device={accumulated_grad.device} "
f"new gradient device={new_sharded_grad.device}",
)
@no_type_check
def _low_precision_hook_enabled(state: _FSDPState) -> bool:
    """
    Returns whether the communication hook registered on ``state`` is one of
    the hooks listed in ``LOW_PRECISION_HOOKS``.
    """
    registered_hook = state._communication_hook
    return registered_hook in LOW_PRECISION_HOOKS
@no_type_check
@torch.no_grad()
def _post_backward_final_callback(
    state: _FSDPState,
    module: nn.Module,
):
    """
    This waits for the post-backward to finish and performs some final cleanup.
    This runs at the end of the entire backward pass and should only be called
    on the root FSDP instance.

    Args:
        state (_FSDPState): FSDP state of the root instance; this is asserted
            via ``state._is_root``.
        module (nn.Module): Module passed alongside ``state``. It is not
            referenced in this function's body.
    """
    _p_assert(
        state._is_root,
        "The post-backward callback should only be called on the root FSDP instance",
    )
    root_state = state
    if root_state._sync_gradients:
        # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
        # since it currently runs in the post-backward stream. That can be
        # pushed to the next forward if run in a different stream
        state._device_handle.current_stream().wait_stream(
            root_state._post_backward_stream
        )
        if root_state.cpu_offload.offload_params:
            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
            # post-backward hooks to finish explicitly since CPU gradients do
            # not automatically synchronize with the GPU
            state._device_handle.current_stream().synchronize()
    # NOTE(review): presumably advances the execution-order bookkeeping to the
    # next iteration -- confirm against `_exec_order_data.next_iter()`
    root_state._exec_order_data.next_iter()
    # Reset every FSDP state tracked by the root: reshard anything still
    # unsharded, finalize the parameters, clear the per-backward bookkeeping
    # dicts, and return the state and each of its handles to idle
    for fsdp_state in state._all_fsdp_states:
        _catch_all_reshard(fsdp_state)
        _finalize_params(fsdp_state)
        fsdp_state._needs_pre_backward_unshard.clear()
        fsdp_state._ran_pre_backward_hook.clear()
        fsdp_state.training_state = TrainingState.IDLE
        for handle in fsdp_state._handles:
            handle._training_state = HandleTrainingState.IDLE
        fsdp_state._handles_prefetched.clear()
    # Reset for cases like one forward and multiple backwards
    root_state._post_backward_callback_queued = False
@no_type_check
def _catch_all_reshard(
    state: _FSDPState,
) -> None:
    """
    Reshards any parameters that the post-backward hook may have left
    unsharded. This can happen when a module's output is used in the forward
    pass, so its pre-backward hook runs (unsharding the parameter), but the
    post-backward hook does not run because the output was not used in the
    loss computation corresponding to this backward pass.
    """
    # Wrap with a try-except to provide a more informative traceback if an
    # error is raised
    try:
        handles_to_reshard: List[FlatParamHandle] = []
        free_unsharded_flat_params: List[bool] = []
        for handle in state._handles:
            flat_param = handle.flat_param
            # TODO: This already-resharded check is brittle:
            # https://github.com/pytorch/pytorch/issues/83956
            points_to_local_shard = (
                flat_param.data_ptr() == flat_param._local_shard.data_ptr()
            )
            # If FSDP skipped using sharded views, then the flat parameter
            # still points to the sharded data, so we need to reshard to
            # use sharded views
            if points_to_local_shard and not handle._skipped_use_sharded_views:
                continue
            free_unsharded_flat_params.append(_should_free_in_backward(state, handle))
            handles_to_reshard.append(handle)
        if handles_to_reshard:
            _reshard(state, handles_to_reshard, free_unsharded_flat_params)
    except Exception as e:
        # Print the context without raising here so that the original
        # exception is the one re-raised below
        _p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
        )
        raise e
@no_type_check
def _finalize_params(
    state: _FSDPState,
) -> None:
    """
    Finalizes each handle's flat parameter before the next iteration: removes
    the registered post-backward hook (if any) and, when the flat parameter
    requires gradient and gradients are being synchronized, prepares the
    gradient for the optimizer and resets ``_post_backward_called``.
    """
    for handle in state._handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_post_backward_hook_state"):
            hook_state = flat_param._post_backward_hook_state
            # The removable hook handle is stored last; one extra leading
            # entry is expected when the flat parameter requires gradient
            num_expected_entries = int(flat_param.requires_grad) + 1
            _p_assert(
                len(hook_state) == num_expected_entries,
                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
            )
            hook_state[-1].remove()
            delattr(flat_param, "_post_backward_hook_state")
        if not flat_param.requires_grad:
            continue
        if not state._sync_gradients:
            # Preserve the gradient accumulation state if not synchronizing
            # gradients: `.grad` remains the unsharded gradient from prior
            # `no_sync()` iterations, and `_saved_grad_shard` remains the
            # sharded gradient from the last synchronized iteration
            continue
        # Skip call to prepare_gradient_for_optim if we've already run the
        # optimizer in backward pass.
        if not handle._has_optim_in_backward:
            handle.prepare_gradient_for_optim()
        _p_assert(
            hasattr(flat_param, "_post_backward_called"),
            "Expects `_post_backward_called` to be set on the `FlatParameter`",
        )
        flat_param._post_backward_called = False
@no_type_check
def _prefetch_handles(
    state: _FSDPState,
    current_handles_key: _HandlesKey,
    prefetch_mode: _PrefetchMode,
) -> None:
    """
    Prefetches the next handles if needed (without synchronization). An empty
    handles key cannot prefetch.

    Args:
        state (_FSDPState): FSDP state providing the unshard streams and the
            ``_handles_prefetched`` bookkeeping.
        current_handles_key (_HandlesKey): Handles key of the current module,
            used to look up which handles keys to prefetch next.
        prefetch_mode (_PrefetchMode): Selects which handle training state to
            emulate while unsharding (backward or forward).

    Raises:
        ValueError: If there is a handle to prefetch and ``prefetch_mode`` is
            neither ``_PrefetchMode.BACKWARD`` nor ``_PrefetchMode.FORWARD``.
    """
    if not current_handles_key:
        return
    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
    for handles_key in handles_to_prefetch:
        # Temporarily emulate the training state while calling `_unshard` to
        # ensure the correct `as_params` for `_use_unsharded_views()`
        prev_training_states: List[HandleTrainingState] = []
        for handle in handles_key:
            # Remember the real state so that it can be restored after the
            # unshard below
            prev_training_states.append(handle._training_state)
            if prefetch_mode == _PrefetchMode.BACKWARD:
                handle._training_state = HandleTrainingState.BACKWARD_PRE
            elif prefetch_mode == _PrefetchMode.FORWARD:
                handle._training_state = HandleTrainingState.FORWARD
            else:
                raise ValueError(
                    f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}"
                )
        # Prefetch the next set of handles without synchronizing to allow
        # the sync to happen as late as possible to maximize overlap
        _unshard(state, handles_key, state._unshard_stream, state._pre_unshard_stream)
        # Undo the emulated training states
        for handle, prev_training_state in zip(handles_key, prev_training_states):
            handle._training_state = prev_training_state
        # Record the prefetch; `_get_handles_to_prefetch()` skips keys that are
        # already marked in `_handles_prefetched`
        state._handles_prefetched[handles_key] = True
@no_type_check
def _get_handles_to_prefetch(
    state: _FSDPState,
    current_handles_key: _HandlesKey,
) -> List[_HandlesKey]:
    """
    Computes the handles keys to prefetch for the next module(s), given that
    ``current_handles_key`` represents the current module. "Prefetching" means
    running the unshard logic early (without synchronization), and which
    modules count as "next" depends on the recorded execution order and the
    current training state.

    Only keys that still need an unshard and have not already been prefetched
    are returned; if prefetching does not apply, the result is an empty list.
    """
    training_state = _get_training_state(current_handles_key)
    valid_training_states = (
        HandleTrainingState.BACKWARD_PRE,
        HandleTrainingState.BACKWARD_POST,
        HandleTrainingState.FORWARD,
    )
    _p_assert(
        training_state in valid_training_states,
        f"Prefetching is only supported in {valid_training_states} but "
        f"currently in {training_state}",
    )
    eod = state._exec_order_data
    # Backward prefetching applies only when the configured backward prefetch
    # mode matches the phase that the current handles are in
    in_backward_pre = (
        training_state == HandleTrainingState.BACKWARD_PRE
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
    )
    in_backward_post = (
        training_state == HandleTrainingState.BACKWARD_POST
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
    )
    if in_backward_pre or in_backward_post:
        candidate_keys = eod.get_handles_to_backward_prefetch(current_handles_key)
        needs_unshard = state._needs_pre_backward_unshard
    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
        candidate_keys = eod.get_handles_to_forward_prefetch(current_handles_key)
        needs_unshard = state._needs_pre_forward_unshard
    else:
        return []
    return [
        candidate_key
        for candidate_key in candidate_keys
        if needs_unshard.get(candidate_key, False)
        and not state._handles_prefetched.get(candidate_key, False)
    ]
def _get_training_state(
    handles_key: _HandlesKey,
) -> HandleTrainingState:
    """
    Returns the single training state shared by the handles in
    ``handles_key``, asserting that the key is non-empty and that all of its
    handles agree on the training state.
    """
    _p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
    training_states = set()
    for handle in handles_key:
        training_states.add(handle._training_state)
    _p_assert(
        len(training_states) == 1,
        f"Expects uniform training state but got {training_states}",
    )
    return next(iter(training_states))
@no_type_check
def _register_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a pre-forward hook on ``module``, first removing any pre-forward
    hooks previously tracked in ``state._pre_forward_handles``.
    """
    tracked_handles = state._pre_forward_handles
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    module_param_handles = state._fully_sharded_module_to_handles.get(module, [])
    hook = functools.partial(
        _pre_forward,
        state,
        module_param_handles,
        functools.partial(_pre_forward_unshard, state, module_param_handles),
    )
    removable_handle = module.register_forward_pre_hook(
        hook, prepend=True, with_kwargs=True
    )
    tracked_handles.append(removable_handle)
@no_type_check
def _register_post_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a post-forward hook on ``module``, first removing any
    post-forward hooks previously tracked in ``state._post_forward_handles``.
    Even if the module has no handles, we should register the hook since it
    will register the module's pre-backward hook.
    """
    tracked_handles = state._post_forward_handles
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    module_param_handles = state._fully_sharded_module_to_handles.get(module, [])
    hook = functools.partial(
        _post_forward,
        state,
        module_param_handles,
        functools.partial(_post_forward_reshard, state, module_param_handles),
    )
    removable_handle = module.register_forward_hook(hook)
    tracked_handles.append(removable_handle)
@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers the root pre-forward hook on ``module``, which should be the
    local FSDP root. Any root pre-forward hooks previously tracked in
    ``state._root_pre_forward_handles`` are removed first.

    NOTE: For the current composable FSDP design, we have each application of
    ``fully_shard()`` to a module to indicate that that module is the local
    FSDP root. We may remove this assumption in the future, in which case we
    will need to register this root pre-forward hook on any candidate module
    that may be the local FSDP root.
    """
    tracked_handles = state._root_pre_forward_handles
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    root_hook = functools.partial(_root_pre_forward, state)
    removable_handle = module.register_forward_pre_hook(
        root_hook, prepend=True, with_kwargs=True
    )
    tracked_handles.append(removable_handle)
@no_type_check
def _register_pre_backward_hooks(
    state: _FSDPState,
    module: nn.Module,
    outputs: Any,
    handles: List[FlatParamHandle],
) -> Any:
    """
    Registers pre-backward hooks on the tensors that require gradients in the
    forward pass outputs ``outputs``, which were computed using the
    ``FlatParameter`` s of ``handles``.

    Args:
        state (_FSDPState): FSDP state whose ``_needs_pre_backward_unshard``
            and ``_ran_pre_backward_hook`` bookkeeping is updated.
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
        outputs (Any): Forward pass outputs; the tensors inside are visited
            via ``_apply_to_tensors``.
        handles (List[FlatParamHandle]): Handles whose ``FlatParameter`` s
            participated in the forward; ``tuple(handles)`` is the
            bookkeeping key.

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients. If gradient computation is disabled, ``outputs`` is
        returned as is without registering anything.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs
    if state._is_root:
        state._post_backward_callback_queued = False  # only defined on the root
    handles_key = tuple(handles)
    if handles_key:
        state._needs_pre_backward_unshard[handles_key] = False
        # Since these handles' `FlatParameter`s participated in a forward, we
        # conservatively assume that they will be used in the backward
        state._ran_pre_backward_hook[handles_key] = False

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        # Only tensors that require gradient can trigger a backward through
        # this module, so only those get the hook and mark the unshard need
        if t.requires_grad:
            t.register_hook(
                functools.partial(_pre_backward_hook, state, module, handles)
            )
            state._needs_pre_backward_unshard[handles_key] = True
        return t

    return _apply_to_tensors(_register_hook, outputs)
def _register_post_backward_hooks(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """
    Registers post-backward hooks on the ``FlatParameter`` s'
    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object is the last function that finalizes the
    ``FlatParameter`` 's gradient, so it only runs once the entire gradient
    computation for that parameter has finished.

    The hook is registered only once, in the *first* forward that a
    ``FlatParameter`` participates in, which relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.

    NOTE: We follow this heuristic to prefer the *first* forward to target the
    parameter mixed precision case, where there are *separate*
    ``AccumulateGrad`` objects across the different forwards. (Without
    parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If
    we instead prefer the *last* forward, then the hook runs early.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    for handle in handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_post_backward_hook_state"):
            # Already registered in an earlier forward
            continue
        if not flat_param.requires_grad:
            continue
        # Get the `AccumulateGrad` object
        temp_flat_param = flat_param.expand_as(flat_param)
        _p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
        assert acc_grad is not None
        post_backward_hook = functools.partial(_post_backward_hook, state, handle)
        hook_handle = acc_grad.register_hook(post_backward_hook)
        # Store the `AccumulateGrad` object together with the removable handle
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
def _register_post_backward_reshard_only_hooks(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Registers post-backward hooks to reshard flat parameters that do not
    require gradient. We register these using multi-post-grad hooks on the
    input activations to ensure that all gradients that may depend on the
    parameters have been computed before resharding.

    Nothing is registered when gradient computation is disabled or when a flat
    parameter already has ``_post_backward_hook_state`` set, which mirrors the
    guards in ``_register_post_backward_hooks``. Without them, a repeated
    forward would overwrite the stored hook handle (leaving the earlier hook
    registered with no way to remove it), and a no-grad forward would leave
    stale hook state on the flat parameter.

    Args:
        state (_FSDPState): FSDP state bound into the registered hook.
        handles (List[FlatParamHandle]): Handles to consider; only those whose
            flat parameter does not require gradient get a hook.
        args (Tuple[Any, ...]): Positional forward inputs.
        kwargs (Dict[str, Any]): Keyword forward inputs.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
    # where each flat parameter requires gradient
    inp_tensors: Optional[List[torch.Tensor]] = None
    for handle in handles:
        flat_param = handle.flat_param
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or flat_param.requires_grad:
            continue
        if inp_tensors is None:
            args_list, _ = tree_flatten(args)
            kwargs_list, _ = tree_flatten(kwargs)
            inp_tensors = [
                obj
                for obj in chain(args_list, kwargs_list)
                if torch.is_tensor(obj) and obj.requires_grad
            ]
        assert inp_tensors is not None  # mypy
        hook_handle = register_multi_grad_hook(
            inp_tensors, functools.partial(_post_backward_reshard, state, handle)
        )
        # Single-element tuple: only the removable handle is stored here
        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]
@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Queues the post-backward final callback, which runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward. If the callback is already queued, this
    does nothing further.
    """
    _p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    if not state._post_backward_callback_queued:
        _assert_in_training_states(state, [TrainingState.IDLE])
        state._post_backward_callback_queued = True
        final_callback = functools.partial(
            _post_backward_final_callback, state, module
        )
        Variable._execution_engine.queue_callback(final_callback)
def _wait_for_computation_stream(
computation_stream: torch.Stream,
unshard_stream: torch.Stream,
pre_unshard_stream: torch.Stream,
):
"""
Has the unshard and pre-unshard streams wait for the computation stream.
For example, this should be called in the FSDP root's pre-forward to
respect optimizer step computation.
"""
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
# Having the pre-all-gather stream wait for the current stream even if we
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
def _reset_flat_param_grad_info_if_needed(
    handles: List[FlatParamHandle],
):
    """
    Clears the original parameters' gradients if needed, delegating to each
    handle that uses original parameters. This method's CPU overhead is
    minimal, so we may call it throughout FSDP methods, which serve as
    callsites to free the gradient memory earlier.
    """
    for handle in handles:
        if not handle._use_orig_params:
            continue
        handle._reset_flat_param_grad_info_if_needed()
@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Collects every buffer in the module tree rooted at ``root_module``
    together with a parallel list of buffer dtypes for computation. Each
    buffer dtype is either ``None`` if buffer mixed precision is not enabled
    or the buffer low precision dtype otherwise. Each buffer appears once.
    """
    _p_assert(state._is_root, "Expects the root to cast buffers")
    buffers: List[torch.Tensor] = []
    buffer_dtypes: List[Optional[torch.dtype]] = []
    visited_buffers: Set[torch.Tensor] = set()
    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
        root_module
    )
    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
    # instance's mixed precision setting for each buffer
    bottom_up_pairs = zip(reversed(fsdp_states), reversed(fsdp_modules))
    for fsdp_state, fsdp_module in bottom_up_pairs:
        for buffer in fsdp_module.buffers():
            if buffer not in visited_buffers:
                visited_buffers.add(buffer)
                buffers.append(buffer)
                buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
    return buffers, buffer_dtypes
@no_type_check
def _get_orig_buffer_dtypes(
    state: _FSDPState,
    buffer_names: List[str],
) -> List[torch.dtype]:
    """
    Looks up the original dtype of each buffer named in ``buffer_names`` from
    ``state._buffer_name_to_orig_dtype``, asserting that every name is
    present. The result preserves the order of ``buffer_names``.
    """

    def _lookup_orig_dtype(buffer_name: str) -> torch.dtype:
        # Assert first so that a missing name reports the available keys
        _p_assert(
            buffer_name in state._buffer_name_to_orig_dtype,
            f"{buffer_name} is missing from pre-computed dict on rank "
            f"{state.rank}, which only has keys "
            f"{state._buffer_name_to_orig_dtype.keys()}",
        )
        return state._buffer_name_to_orig_dtype[buffer_name]

    return [_lookup_orig_dtype(buffer_name) for buffer_name in buffer_names]
def _cast_buffers_to_dtype_and_device(
buffers: List[torch.Tensor],
buffer_dtypes: List[Optional[torch.dtype]],
device: torch.device,
) -> None:
"""
Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
corresponding buffer is only moved to ``device``.
"""
_p_assert(
buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
f"Expects `buffers` and `buffer_dtypes` to have the same length if "
f"`buffer_dtypes` is specified but got {len(buffers)} and "
f"{len(buffer_dtypes)}",
)
for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
if not torch.is_floating_point(buffer) or buffer_dtype is None:
buffer.data = buffer.to(device=device)
else:
buffer.data = buffer.to(device=device, dtype=buffer_dtype)