# mypy: allow-untyped-defs
import contextlib
import warnings
from collections.abc import Generator
from typing import cast

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state,
    _has_fsdp_params,
    _module_handle,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)
from torch.distributed.utils import _p_assert

from ._flat_param import FlatParamHandle


FLAT_PARAM = "_flat_param"


@torch.no_grad()
def _writeback_to_local_shard(
    handle: FlatParamHandle,
    writeback_grad: bool,
):
    """
    For the handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then writes back to the sharded gradient as
    well.

    Precondition: The handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """

    def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
        if handle.uses_sharded_strategy:
            # For sharded strategies, get the *unpadded* shard instead of
            # the *padded* shard to persist user changes to the padding
            # (though FSDP does not explicitly support this)
            shard, _ = FlatParamHandle._get_unpadded_shard(
                flat_param_or_grad,
                handle.rank,
                handle.world_size,
            )
            return shard
        # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
        # so we write it back directly
        return flat_param_or_grad

    param_shard = _get_shard(handle.flat_param)
    handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
    if writeback_grad:
        existing_grad = handle.sharded_grad
        if existing_grad is not None:
            assert handle.flat_param.grad is not None
            grad_shard = _get_shard(handle.flat_param.grad)
            existing_grad[: grad_shard.numel()].copy_(grad_shard)
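

# A hedged sketch of the sharding layout assumed by `_get_shard` above.
# `FlatParamHandle` shards with `torch.chunk`-style contiguous chunks and
# pads out the last rank, so the unpadded shard can be shorter than
# `_local_shard`, which is why the writeback slices with `[: shard.numel()]`.
# This helper is an illustration only, not FSDP code.
def _example_unpadded_shard(
    flat: torch.Tensor, rank: int, world_size: int
) -> torch.Tensor:
    chunks = torch.chunk(flat, world_size)
    # `torch.chunk` may return fewer than `world_size` chunks, in which
    # case the trailing ranks own empty (zero-numel) shards
    return chunks[rank] if rank < len(chunks) else flat.new_empty(0)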


def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)


def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handle = _module_handle(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param
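

# A hedged sketch (on a toy module, not FSDP code) of the visibility toggle
# implemented by the register/de-register pair above: removing an entry from
# `Module._parameters` hides it from `nn.Module` methods such as
# `named_parameters()`, and re-inserting the entry restores it.
def _example_parameter_visibility_toggle() -> None:
    lin = nn.Linear(2, 2)
    weight = lin._parameters.pop("weight")
    assert "weight" not in dict(lin.named_parameters())  # hidden
    lin._parameters["weight"] = weight
    assert "weight" in dict(lin.named_parameters())  # visible again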


@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter.
    After the context, re-registers the flattened parameter and restores
    the original parameters as ``Tensor`` views into the flattened
    parameter.
    """
    handle = _module_handle(state, module)
    if not handle:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handle.unflatten_as_params():
                yield
        finally:
            if not handle._use_orig_params:
                _register_flat_param(state, module)
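

# A hedged usage sketch: inside `_unflatten_as_params`, the wrapped module
# exposes the original (unflattened) parameters, e.g. for code that wants
# per-parameter names and shapes. `state` and `module` are assumed to come
# from an existing FSDP instance whose flat parameter is already unsharded,
# per the precondition in the docstring above.
def _example_unflatten_as_params(state: _FSDPState, module: nn.Module) -> None:
    with _unflatten_as_params(state, module):
        for name, param in module.named_parameters():
            print(name, tuple(param.shape))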


def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )
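

# A hedged illustration of the argument matrix enforced above; `state` is
# assumed to be a constructed FSDP state with `_use_orig_params=False`, and
# `pytest` is assumed available. Each call raises `NotImplementedError`.
def _example_invalid_unshard_args(state: _FSDPState) -> None:
    import pytest

    with pytest.raises(NotImplementedError):
        # gathering gradients requires `use_orig_params=True`
        _validate_unshard_params_args(
            state, writeback=False, rank0_only=False,
            offload_to_cpu=False, with_grads=True,
        )
    with pytest.raises(NotImplementedError):
        # rank 0 alone cannot persist writes that other ranks never see
        _validate_unshard_params_args(
            state, writeback=True, rank0_only=True,
            offload_to_cpu=False, with_grads=False,
        )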


@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    state._device_handle.synchronize()
    # If handles are shared by other module(s), the handle may be already unsharded.
    maybe_handle = _module_handle(state, module)
    handle = None
    if (
        maybe_handle
        and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ):
        handle = maybe_handle
    if not handle:
        yield
        return

    assert handle._training_state == HandleTrainingState.IDLE, (
        f"Expects the handle training to be IDLE but got {handle._training_state}"
    )

    handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _reset_flat_param_grad_info_if_needed(handle)
    free_unsharded_flat_param = handle.needs_unshard()
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = state._device_handle.current_stream()
    _unshard(state, handle, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handle)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameter early
        _reshard(state, handle, free_unsharded_flat_param)
        if with_grads:
            _reshard_grads(handle)
        try:
            yield
        finally:
            handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            if offload_to_cpu and handle.uses_sharded_strategy:
                stack.enter_context(handle.to_cpu())
                # NOTE: Since PyTorch enforces that a parameter and its
                # gradients need to match metadata (e.g. device), we must
                # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handle, with_grads)
                _reshard(state, handle, free_unsharded_flat_param)
                if with_grads:
                    _reshard_grads(handle)
                handle._training_state = HandleTrainingState.IDLE
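

# A hedged sketch of the `rank0_only` path above: non-zero ranks reshard
# (free) their unsharded storage immediately, so inside the context only
# rank 0 holds the full parameters. `module` and `state` are assumed to
# come from an existing FSDP instance in an initialized process group.
def _example_rank0_only_unshard(module: nn.Module, state: _FSDPState) -> None:
    with _unshard_fsdp_state_params(
        module, state, writeback=False, rank0_only=True,
        offload_to_cpu=False, with_grads=False,
    ):
        if state.rank == 0:
            ...  # full parameters are materialized here on rank 0 only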


@contextlib.contextmanager
def _unshard_params_for_summon(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE


@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards FSDP-managed parameters for all modules with FSDP applied
    in the module tree rooted at ``module``.
    """
    if not recurse:
        optional_state = _get_module_fsdp_state(module)
        if optional_state is None:
            with contextlib.nullcontext():
                yield
            return
        states_and_modules = ([optional_state], [module])
    else:
        states_and_modules = traversal_utils._get_fsdp_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for state, module in zip(*states_and_modules):
            stack.enter_context(
                _unshard_params_for_summon(
                    module=module,
                    state=state,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield
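

# A hedged usage sketch: `_unshard_params` is the machinery behind the
# public `FullyShardedDataParallel.summon_full_params` context manager,
# which is how users should reach it. The wrapped model is an assumption
# for illustration and requires an initialized process group.
def _example_summon_full_params(fsdp_model: nn.Module) -> None:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    with FSDP.summon_full_params(fsdp_model, writeback=True, rank0_only=False):
        # Parameters are unsharded inside the context; with `writeback=True`,
        # in-place edits persist to the local shards on exit.
        for name, param in fsdp_model.named_parameters():
            print(name, tuple(param.shape))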


def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)


def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _deregister_flat_param(state, module)
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)
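

# A hedged sketch of the round trip enabled by the two functions above for
# `use_orig_params=True`: temporarily swap the original parameters out so
# the `FlatParameter` is the registered parameter, then swap them back.
# `state` and `module` are assumed to come from an existing FSDP instance.
def _example_param_registration_round_trip(
    state: _FSDPState, module: nn.Module
) -> None:
    _deregister_orig_params(state, module)  # `FLAT_PARAM` is now visible
    try:
        ...  # operate on the registered flat parameter
    finally:
        _register_orig_params(state, module)  # original parameters restored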