This is the result of applying the ruff `UP035` check: `Callable` is now imported from `collections.abc` instead of `typing`. This PR is a follow-up to #164054. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104 Approved by: https://github.com/Skylion007
# mypy: allow-untyped-defs
import collections
import itertools
import os
import warnings
from collections.abc import Callable, Generator, Iterable, Iterator
from typing import Any, no_type_check, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _FSDPState,
    _get_module_fsdp_state,
    _is_fsdp_flattened,
    _named_parameters_with_duplicates,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    _FSDP_USE_FULL_PREC_IN_EVAL,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]


# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = [
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
]
NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
)
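
# Illustrative note (sketch, not executed here): for the hybrid strategies above,
# a user-supplied ``process_group`` follows ``HybridShardProcessGroupType``, i.e.
# an ``(intra_node_pg, inter_node_pg)`` tuple. For example, on 2 nodes with 8
# GPUs each, sharding would happen within a node (ranks [0..7] or [8..15]) while
# gradients are additionally all-reduced across nodes; see the hybrid-shard
# helpers below for how these groups are constructed when not passed in.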


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.


@no_type_check
def _init_process_group_state(
    state: _FSDPState,
    process_group: ProcessGroupType,
    sharding_strategy: ShardingStrategy,
    policy: Optional[_Policy],
    device_mesh: Optional[DeviceMesh] = None,
) -> _FSDPState:
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Cannot pass both process_group and device_mesh at the "
            "same time. Please just pass only one of them."
        )
    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
    if is_hybrid_strategy:
        if process_group is None and policy is None and device_mesh is None:
            # Raise an error here: since this is manual wrapping with no process
            # group passed in, there is no way to ensure all wrapped FSDP
            # instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} "
                "requires explicit specification of process group or device_mesh."
            )
        else:
            state = _init_process_group_state_for_hybrid_shard(
                state, process_group, device_mesh
            )
    else:
        if device_mesh:
            state._device_mesh = device_mesh
            state.process_group = device_mesh.get_group(mesh_dim=0)
        else:
            state.process_group = (
                process_group if process_group is not None else _get_default_group()
            )

    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    data_parallel_world_size = state.world_size
    if is_hybrid_strategy:
        data_parallel_world_size *= state._inter_node_pg.size()
    state._gradient_predivide_factor = (
        default_hooks.DefaultState._get_gradient_predivide_factor(
            data_parallel_world_size
        )
    )
    state._gradient_postdivide_factor = (
        data_parallel_world_size / state._gradient_predivide_factor
    )
    return state
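
# Worked example (illustrative; the exact split is computed by
# ``default_hooks.DefaultState._get_gradient_predivide_factor``): gradients are
# divided by ``_gradient_predivide_factor`` before reduction and by
# ``_gradient_postdivide_factor`` afterwards, and by construction the two
# factors multiply to ``data_parallel_world_size``. For a world size of 16 this
# would typically be a 4 / 4 split, which still averages gradients overall while
# reducing the risk of overflow/underflow in low-precision reductions.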


@no_type_check
def _init_process_group_state_for_hybrid_shard(
    state: _FSDPState,
    process_group: ProcessGroupType,
    device_mesh: DeviceMesh,
) -> _FSDPState:
    if device_mesh:
        if _is_valid_hybrid_shard_device_mesh(device_mesh):
            state._device_mesh = device_mesh
            # We currently only allow _inter_node_pg to be the outermost dimension, and the
            # process_group(intra_node) to be the innermost dimension.
            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
            state.process_group = device_mesh.get_group(mesh_dim=1)
        else:
            raise ValueError(
                f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}"
            )
    elif process_group is None:
        default_group = _get_default_group()
        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
            default_group, state._device_handle.device_count()
        )
        # we shard across intra-node
        state.process_group = intra_node_group
        # save _inter_node_pg to allreduce across.
        state._inter_node_pg = inter_node_group
    else:
        # Check type and assign state.process_group and state._inter_node_pg.
        if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in the intra-node group and the
            # inter-node group, in that order, as documented.
            state.process_group, state._inter_node_pg = process_group
        else:
            raise ValueError(
                "Expected process_group to be passed in as either None or "
                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
            )
    # Create state for allreduce
    state._inter_node_state = _get_default_comm_hook_state(
        process_group=state._inter_node_pg,
    )
    return state


@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
    return (
        isinstance(process_group, tuple)
        and len(process_group) == 2
        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
    )


@no_type_check
def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2


@no_type_check
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
    """
    Return a process group across the current node.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return an intra-node subgroup across
    [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank.
    For example, rank 3 would get [0, 1, ..., 7].
    """
    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
    return intra_node_subgroup


@no_type_check
def _init_inter_node_process_group(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> dist.ProcessGroup:
    """
    Return an inter-node process group where each contained rank has the same local rank.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth
    depending on the process's rank. For example, rank 1 would get [1, 9], rank 5
    would get [5, 13].
    """
    # the inter-node pg that is returned
    inter_node_pg = None
    sharding_backend = dist.get_backend(global_process_group)
    world_size = dist.get_world_size(global_process_group)
    # Assuming fully homogeneous setup
    num_nodes = world_size // num_devices_per_node
    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
    for local_rank in range(num_devices_per_node):
        ranks_for_inter_group = [
            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
        ]
        # every rank always needs to call dist.new_group
        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
        if local_rank == my_local_rank:
            inter_node_pg = grp

    assert inter_node_pg is not None, (
        f"{my_local_rank} expected to assign inter-node pg, but did not"
    )
    return inter_node_pg


def _init_intra_and_inter_node_groups(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> tuple[dist.ProcessGroup, dist.ProcessGroup]:
    """
    Initialize intra and inter-node process groups and return the ones corresponding to this process's rank.

    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
    ``_HYBRID_SHARD_ZERO2`` in FSDP.
    This function assumes each node has an equal number of CUDA-enabled devices.

    Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
    """
    return (
        _init_intra_node_process_group(num_devices_per_node),
        _init_inter_node_process_group(global_process_group, num_devices_per_node),
    )
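
# Illustrative example (following the docstrings above): with 2 nodes of 8
# devices each (global ranks 0-15), rank 5 would receive the intra-node group
# [0, 1, ..., 7] and the inter-node group [5, 13]; hybrid sharding shards over
# the former and all-reduces gradients over the latter.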


@no_type_check
def _init_ignored_module_states(
    state: _FSDPState,
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]],
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> _FSDPState:
    if ignored_modules is not None and ignored_states is not None:
        raise ValueError(
            "Cannot pass both ignored_modules and ignored_states at the "
            "same time. Please just pass ignored_states."
        )
    ignored_parameters = None
    passed_as_ignored_states = ignored_states is not None
    if passed_as_ignored_states:
        ignored_states_list = list(ignored_states)
        _check_ignored_states(ignored_states_list, True)
    else:
        ignored_states_list = []
        _check_ignored_states(
            list(ignored_modules) if ignored_modules is not None else [], False
        )
    if len(ignored_states_list) > 0:
        if isinstance(ignored_states_list[0], nn.Parameter):
            ignored_parameters = ignored_states_list
        else:
            ignored_modules = ignored_states_list
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
    state._ignored_params = _get_ignored_params(
        module,
        state._ignored_modules,
        ignored_parameters,
    )
    state._ignored_buffer_names = _get_ignored_buffer_names(
        module,
        state._ignored_modules,
    )
    # TODO: FSDP's contract for buffers is not well-defined. They are
    # implicitly ignored for most functionality since they are not sharded;
    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
    # precision). We should formalize this contract and decide if we need to
    # compute and store `_ignored_buffers`.
    return state


def _check_ignored_states(
    ignored_states: list[Any], passed_as_ignored_states: bool
) -> None:
    """
    Check that the ignored states are uniformly parameters or uniformly modules.

    We may remove this check in the future if we permit mixing.
    """
    if len(ignored_states) == 0:
        return
    if passed_as_ignored_states:
        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
        if not all_params and not all_modules:
            # Sort for consistent ordering for unit test regex matching
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_states expects all nn.Parameter or all nn.Module list "
                f"elements but got types {sorted_types}"
            )
    else:
        if not all(isinstance(state, nn.Module) for state in ignored_states):
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_modules expects nn.Module list elements but got "
                f"types {sorted_types}"
            )
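
# Illustrative example (hypothetical submodule names): for a model with
# submodules ``lin1`` and ``lin2``, ``ignored_states=[lin1, lin2]`` (all
# modules) or ``ignored_states=[lin1.weight]`` (all parameters) passes this
# check, while a mixed list such as ``[lin1, lin2.weight]`` raises the
# ``ValueError`` above.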


@no_type_check
def _init_device_handle(
    state: _FSDPState,
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
    """
    Determine the device handle used for initializing FSDP.

    If a device is specified by ``device_id``, then this returns the device
    handle corresponding to that device type. Otherwise, if the module is
    already on a non-CPU device, then the device type is that non-CPU device
    type. If the module is on CPU or meta, then the device type is the current
    accelerator device. See :ref:`Accelerators<accelerators>` for details.

    This method should be called after the ignored parameters have been
    determined, as the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules with different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        determined_device = determined_device or torch._C._get_accelerator()
        if determined_device.type == "cpu":
            raise RuntimeError(
                "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
            )

    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state
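
# Illustrative example (assuming a CUDA build): ``device_id=3`` yields the
# handle for the "cuda" device type; a module whose parameters are all on
# "meta" with no ``device_id`` falls back to the current accelerator; and if no
# non-CPU accelerator is detected, the ``RuntimeError`` above is raised.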


@no_type_check
def _init_buffer_state(
    state: _FSDPState,
    module: nn.Module,
) -> _FSDPState:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from clean fully-qualified buffer name (starting from
    # `module`) to its original dtype for restoring that dtype during model
    # checkpointing when buffer mixed precision is enabled. The names should
    # be clean since the casting happens in a `summon_full_params()` context.
    _buffer_name_to_orig_dtype: dict[str, torch.dtype] = {}
    for buffer_name, buffer in module.named_buffers():
        buffer_name = clean_tensor_name(buffer_name)
        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state


@no_type_check
def _init_core_state(
    state: _FSDPState,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the strategy to `NO_SHARD` for world size of 1 since they are
    # currently functionally equivalent. This may change if/when we integrate
    # FSDP with MoE.
    if state.world_size == 1:
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            warnings.warn(
                "FSDP is switching to use `NO_SHARD` instead of "
                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
                "the world size is 1."
            )
        sharding_strategy = ShardingStrategy.NO_SHARD
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        warnings.warn(
            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
            "please use `DistributedDataParallel` instead.",
            FutureWarning,
            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
            # level 3 is from the true caller
            stacklevel=3,
        )
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    if mixed_precision is not None:
        torch._C._log_api_usage_once(
            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
        )
    state._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = exec_order_utils._ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state._unshard_event = None
    # Mapping from fully sharded module to the handles it is responsible to
    # unshard and reshard (see [Note: Fully Sharded Module])
    _fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {}
    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handle`
    _handle: Optional[FlatParamHandle] = None
    state._handle = _handle
    params: list[FlatParameter] = []
    state.params = params
    return state


@no_type_check
def _init_runtime_state(
    state: _FSDPState,
) -> _FSDPState:
    _root_pre_forward_handles: list[RemovableHandle] = []
    state._root_pre_forward_handles = _root_pre_forward_handles
    _pre_forward_handles: list[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: list[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    state._sync_gradients = True
    state._comm_hook = None
    state._comm_hook_state = None
    # Used to prevent running the pre-backward hook multiple times
    return state


@no_type_check
def _init_prefetching_state(
    state: _FSDPState,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _FSDPState:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


@no_type_check
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
    # TODO: we need to add additional check once we support FSDP + PiPPy.
    # This check is currently sufficient, since we only support FSDP + TP.
    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    # If the root mesh is not the same as `device_mesh`, then `device_mesh` was
    # sliced out from the root mesh.
    if device_mesh and root_mesh != state._device_mesh:
        state._fsdp_extension = DTensorExtensions(state._device_handle)
    else:
        # We need to explicitly set _fsdp_extension to None.
        # Otherwise, we will run into an infinite recursion when getting the attribute.
        state._fsdp_extension = None
    return state


@no_type_check
def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
    state._state_dict_type = StateDictType.FULL_STATE_DICT
    state_dict_config: StateDictConfig = FullStateDictConfig()
    state._optim_state_dict_config = FullOptimStateDictConfig()
    state._state_dict_config = state_dict_config
    unshard_params_ctx: dict[nn.Module, Generator] = {}
    state._unshard_params_ctx = unshard_params_ctx

    return state


def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None:
    """
    Verify that the parameters are accepted by FSDP. The only restriction now
    is that the parameter cannot be a scalar tensor (param.shape == []).
    """
    for param in params:
        if len(param.shape) == 0:
            param_name = ""
            for name, param_ in module.named_parameters():
                if param is param_:
                    param_name = name
                    break
            assert param_name
            raise ValueError(
                "FSDP doesn't support scalar parameters. "
                f"Change {param_name} to a 1D tensor with numel equal to 1."
            )
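
# Illustrative example: ``nn.Parameter(torch.tensor(1.0))`` has shape ``[]``
# and is rejected by the check above; storing the same value as
# ``nn.Parameter(torch.tensor([1.0]))`` (shape ``[1]``) is accepted.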


@no_type_check
def _init_param_handle_from_module(
    state: _FSDPState,
    fully_sharded_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _FSDPState:
    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
    device_from_device_id = _get_device_from_device_id(
        device_id, state.rank, state._device_handle
    )
    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
        fully_sharded_module, state._ignored_params, state._ignored_modules
    )
    # Materialize the module if needed
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        _materialize_with_param_init_fn(
            fully_sharded_module, param_init_fn, state._ignored_modules
        )
    elif is_meta_module:
        _materialize_meta_module(
            fully_sharded_module,
            device_id,
            state._ignored_modules,
            state._device_handle,
        )
    elif is_torchdistX_deferred_init:
        deferred_init.materialize_module(
            fully_sharded_module,
            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
            and submodule not in state._ignored_modules,
        )

    ignored_buffers = {
        buffer
        for ignored_module in state._ignored_modules
        for buffer in ignored_module.buffers()
    }

    _move_module_to_device(
        fully_sharded_module,
        state._ignored_params,
        ignored_buffers,
        device_from_device_id,
    )
    state.compute_device = _get_compute_device(
        fully_sharded_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
        state._device_handle,
    )

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    _verify_managed_params(fully_sharded_module, managed_params)
    if sync_module_states:
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            _sync_module_params_and_buffers(
                fully_sharded_module, managed_params, state._inter_node_pg
            )
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
    return state


@no_type_check
def _init_param_handle_from_params(
    state: _FSDPState,
    params: list[nn.Parameter],
    fully_sharded_module: nn.Module,
):
    if len(params) == 0:
        return
    handle = FlatParamHandle(
        params,
        fully_sharded_module,
        state.compute_device,
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
        state.process_group,
        state._use_orig_params,
        fsdp_extension=state._fsdp_extension,
    )
    handle.shard()
    assert not state._handle
    state.params.append(handle.flat_param)
    state._handle = handle
    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)


def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> set[nn.Module]:
    """
    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.

    Return the modules contained in their module
    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
    already-computed ignored modules are included.

    ``_ignored_modules`` represents the argument passed by the user to FSDP.
    """
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = (
            set(_ignored_modules) if _ignored_modules is not None else set()
        )
    except TypeError as e:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if _get_module_fsdp_state(module):
            # TODO: We may relax this by taking the FSDP instance's wrapped
            # module to provide more flexibility to the user.
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Treat modules that cannot compose with `fully_shard` as ignored modules,
    # meaning that their subtrees are ignored
    for module in root_module.modules():
        if not traversal_utils._composable(module):
            ignored_root_modules.add(module)
    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {root_module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_modules")
            ignored_modules.update(optional_fsdp_state._ignored_modules)
    return ignored_modules


def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: set[torch.nn.Module],
    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> set[torch.nn.Parameter]:
    """
    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.

    :class:`FlatParameter` s are excluded from the result.
    """
    all_ignored_params: set[torch.nn.Parameter] = set()

    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    }

    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Always include nested FSDP modules' ignored parameters
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_params")
            all_ignored_params.update(optional_fsdp_state._ignored_params)

    return all_ignored_params


def _get_ignored_buffer_names(
    root_module: torch.nn.Module,
    ignored_modules: set[torch.nn.Module],
) -> set[str]:
    """Return the cleaned buffer FQNs in ``ignored_modules``."""
    all_ignored_buffer_names: set[str] = set()

    buffers_in_ignored_modules = {
        buffer for m in ignored_modules for buffer in m.buffers()
    }

    all_ignored_buffer_names.update(
        {
            clean_tensor_name(buffer_name)
            for buffer_name, buffer in root_module.named_buffers()
            if buffer in buffers_in_ignored_modules
        }
    )

    # Always include nested FSDP modules' ignored buffer names
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)

    return all_ignored_buffer_names


def _get_buffer_names(root_module: nn.Module) -> set[str]:
    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    }


def _check_single_device_module(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> None:
    """
    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.

    Thus, after this method, the
    module must be either fully on the CPU or fully on a non-CPU device.
    """
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    # We allow module to be partially on CPU and partially on GPU if device_id is not
    # None, since the device_id arg will result in the CPU portion being moved to
    # GPU. This is useful in cases where part of the module may be parallelized
    # by another algorithm and may already be on GPU. We'd like to enforce device_id
    # to not be None, otherwise we'd flatten parameters in a mixed module which is
    # not supported.
    if len(devices) == 2 and torch.device("cpu") in devices:
        if device_id is None:
            raise RuntimeError(
                "To support a module with both CPU and GPU params, "
                "please pass in device_id argument."
            )
    elif len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )
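
# Illustrative example: parameters split across {cpu, cuda:0} are tolerated
# only when ``device_id`` is given (the CPU portion is moved later in
# initialization), whereas parameters split across {cuda:0, cuda:1} always
# raise the error above.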


def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
    device_handle: _FSDPDeviceHandle,
) -> Optional[torch.device]:
    """
    Return a ``torch.device`` for the specified ``device_id``.

    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device.type != "cpu" and device.index is None:
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {device_handle.current_device()}. "
            f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device(device_handle.current_device())
    return device
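
# Illustrative example (assuming a CUDA build): an integer ``device_id`` such
# as ``1`` becomes ``torch.device(1)``, i.e. ``cuda:1``; an index-less device
# such as ``torch.device("cuda")`` triggers the warning above and falls back to
# the device handle's current device.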


def _need_to_materialize_module(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    ignored_modules: set[nn.Module],
) -> tuple[bool, bool]:
    """
    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

    At most one of the returned bools can
    be ``True``. If either is ``True``, then ``module`` needs to be
    materialized.
    """
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    # TODO: We need to establish a contract for FSDP and buffers. For now, we
    # skip checking for meta buffers from ignored modules. We should consider
    # refactoring the initialization holistically to avoid so many traversals.
    for submodule in module.modules():
        if submodule in ignored_modules:
            continue
        for buf in submodule.buffers(recurse=False):
            is_meta_module |= buf.is_meta
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    return is_meta_module, is_torchdistX_deferred_init
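
# Illustrative example: a module constructed under ``with torch.device("meta"):``
# returns ``(True, False)``; a module created via torchdistX deferred
# initialization (when torchdistX is installed) returns ``(False, True)``; a
# regular CPU or GPU module returns ``(False, False)``.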


def _materialize_with_param_init_fn(
    root_module: nn.Module,
    param_init_fn: Callable[[nn.Module], None],
    ignored_modules: set[nn.Module],
) -> None:
    if not callable(param_init_fn):
        raise ValueError(
            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
        )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    for module in modules_to_materialize:
        param_init_fn(module)


def _materialize_meta_module(
    root_module: nn.Module,
    device_from_device_id: Optional[torch.device],
    ignored_modules: set[nn.Module],
    device_handle: _FSDPDeviceHandle,
):
    # Run default meta device initialization
    materialization_device = device_from_device_id or torch.device(
        device_handle.current_device()
    )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    module = None
    try:
        # Assume that each module's `reset_parameters()` only initializes its
        # own parameters and not those of its children
        with torch.no_grad():
            for module in modules_to_materialize:
                # As a contract to the user, only call `reset_parameters()` if
                # the module has directly managed parameters/buffers
                module_state_iter = itertools.chain(
                    module.parameters(recurse=False), module.buffers(recurse=False)
                )
                has_module_states = len(list(module_state_iter)) > 0
                if has_module_states:
                    module.to_empty(device=materialization_device, recurse=False)
                    module.reset_parameters()  # type: ignore[operator]
    except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of "
            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
        )
        raise e


def _get_modules_to_materialize(
    root_module: nn.Module, ignored_modules: set[nn.Module]
) -> list[nn.Module]:
    # Run BFS to collect the modules to materialize via `reset_parameters()`,
    # stopping at any module with FSDP already applied or at ignored modules.
    modules_to_materialize: list[nn.Module] = []
    queue = collections.deque([root_module])
    visited_modules: set[nn.Module] = {root_module}
    while queue:
        module = queue.popleft()
        modules_to_materialize.append(module)
        for child_module in module.children():
            if (
                child_module not in visited_modules
                and _get_module_fsdp_state(child_module) is None
                and child_module not in ignored_modules
            ):
                visited_modules.add(child_module)
                queue.append(child_module)
    return modules_to_materialize


def _move_module_to_device(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    ignored_buffers: set[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move ``module`` depending on ``device_from_device_id`` and its current device.

    This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
      ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
      ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: collections.deque[nn.Module] = collections.deque()
        queue.append(module)
        params: list[nn.Parameter] = []
        buffers: list[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            # NOTE: We include a check to only move parameters/buffers that are
            # on CPU device. If they are on a CUDA device different from the
            # one specified by `device_id`, then this does NOT move them. This
            # is so that we can raise an error in `_get_compute_device()`.
            params.extend(
                param
                for param in curr_module.parameters(recurse=False)
                if param.device == cpu_device
            )
            buffers.extend(
                buffer
                for buffer in curr_module.buffers(recurse=False)
                if buffer.device == cpu_device
            )
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        params_to_move = [p for p in params if p not in ignored_params]
        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
        return
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device == cpu_device:
        _warn_cpu_init()


def _move_states_to_device(
    params: list[nn.Parameter],
    buffers: list[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move states to the specified device.

    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
        _warn_cpu_init()


def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )


def _get_compute_device(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
    device_handle: _FSDPDeviceHandle,
) -> torch.device:
    """
    Determine and return this FSDP instance's compute device.

    If the module is already on a non-CPU device, then the compute device is that non-CPU
    device. If the module is on CPU, then the compute device is the current
    device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be meta device. For now, the compute device is
    always a CUDA or CUDA-like device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type != "cpu":
        compute_device = param.device  # Determined by model param placement
    else:
        compute_device = torch.device(device_handle.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device


# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: list[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
    been set.
    """
    module_states: list[torch.Tensor] = []
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            detached_buffer = buffer.detach()
            if is_traceable_wrapper_subclass(detached_buffer):
                # NOTE: Here we assume no nested subclasses, at most one level of subclass
                # in both model's buffers and params
                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
                module_states.extend(inner_buffers)
            else:
                module_states.append(detached_buffer)

    for param in params:
        detached_param = param.detach()
        if is_traceable_wrapper_subclass(detached_param):
            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
            inner_params = [getattr(detached_param, attr) for attr in attrs]
            module_states.extend(inner_params)
        else:
            module_states.append(detached_param)

    _check_module_states_for_sync_module_states(module_states)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )
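
# Illustrative note: with ``sync_module_states=True``, rank 0's parameter and
# buffer values are broadcast in buckets capped by
# ``PARAM_BROADCAST_BUCKET_SIZE`` so that all ranks start from identical
# weights; the ``FSDP_SYNCED`` attribute keeps nested FSDP instances from
# broadcasting the same buffer more than once.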


def _check_module_states_for_sync_module_states(
    module_states: list[torch.Tensor],
) -> None:
    if module_states and any(
        tensor.device == torch.device("cpu") for tensor in module_states
    ):
        raise ValueError(
            "The module has CPU parameters or buffers when `sync_module_states=True`, "
            "which requires them to be on GPU. Please specify the `device_id` argument "
            "or move the module to GPU before passing it to FSDP."
        )


def _get_orig_params(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Return an iterator over the original parameters in ``module``.

    The iterator does not return
    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
    present due to nested FSDP wrapping), or any original parameters already
    flattened (only relevant when ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass


def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: set[nn.Parameter],
) -> None:
    """
    Check that original parameters in ``fsdp_module`` have been flattened.

    The flattened parameters are made
    invisible to ``named_parameters()`` for the module hierarchy rooted at
    ``fsdp_module``. This should be called as a sanity check after flattening
    the wrapped module's parameters.
    """
    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )


def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)