**Summary**
Much of ReplicateState's functionality is copied from FSDPState, so I fixed any remaining comments that incorrectly referred to FSDP instead of Replicate. In addition, instead of labeling wrapped modules FSDPModule or FSDPLinear, they are now labeled Replicate____. Finally, I removed some leftover code from the DDP implementation. I have included test cases to verify correctness.

**Test Case**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160133
Approved by: https://github.com/mori360
ghstack dependencies: #160128
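For context, the renaming described above comes from `replicate_impl` in the file below, which swaps each wrapped root module's class for a dynamically created subclass named `Replicate<ClassName>`. The following is a minimal sketch of what a check of that behavior might look like; it assumes a distributed environment is already set up, and the import path is inferred from the logger name in this file, so it may differ from the actual test:

```python
import torch.nn as nn

# Assumed import path (matches the logger name "torch.distributed._composable.replicate_with_fsdp")
from torch.distributed._composable.replicate_with_fsdp import replicate

# Assumes torch.distributed is already initialized (e.g. launched via torchrun);
# otherwise replicate() will try to initialize a default process group itself.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
replicate(model)

# The root module's class is replaced with a "Replicate"-prefixed subclass
# (previously the prefix was "FSDP").
assert type(model).__name__ == "ReplicateSequential"
```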
380 lines
13 KiB
Python
# mypy: allow-untyped-defs
from __future__ import annotations

import logging
from typing import Callable, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.fsdp._fully_shard._fsdp_api import (
    MixedPrecisionPolicy,
    OffloadPolicy,
)
from torch.distributed.fsdp._fully_shard._fsdp_common import (
    detect_compiled_autograd,
    HSDPMeshInfo,
)
from torch.distributed.fsdp._fully_shard._fsdp_init import (
    _get_device_from_mesh,
    _get_managed_states,
    _get_post_forward_mesh_info,
    _init_default_fully_shard_mesh,
    _move_states_to_device,
)
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.distributed.fsdp._fully_shard._fsdp_state import (
    _register_group_forward_hooks,
    FSDPState,
)
from torch.distributed.fsdp._fully_shard._fully_shard import (
    _unimplemented_deepcopy,
    FSDPModule,
)
from torch.distributed.tensor import DeviceMesh, init_device_mesh
from torch.distributed.utils import _get_root_modules

from .contract import _get_registry, contract


if TYPE_CHECKING:
    from torch.distributed.tensor import Shard


cls_to_replicate_cls: dict[type, type] = {}

_ROOT_MODULE_PREFIX = ""

logger = logging.getLogger("torch.distributed._composable.replicate_with_fsdp")


class _ReplicateStateContext:
    """This has state shared across Replicate states."""

    def __init__(self) -> None:
        # All Replicate states in the root state's module tree
        self.all_states: list[_ReplicateState] = []
        # Iteration's forward root runs the once-per-forward logic; this root
        # may not be the overall root set by lazy initialization in cases where
        # only a submodule runs forward (e.g. encoder-only for eval)
        self.iter_forward_root: Optional[_ReplicateState] = None
        # Final callback should only be queued once per backward
        self.post_backward_final_callback_queued: bool = False
        # Whether to finalize backward in this backward's final callback
        self.is_last_backward: bool = True
        # Optional user-provided event recorded after optimizer for the
        # all-gather streams to wait on in the root pre-forward
        self.post_optim_event: Optional[torch.Event] = None


def _get_module_replicate_state(module: nn.Module) -> Optional[_ReplicateState]:
    """Returns the module's ``_ReplicateState`` if it has one, else ``None``."""
    state = _get_module_state(module)
    if isinstance(state, _ReplicateState):
        return state
    return None


class _ReplicateState(FSDPState):
    """
    Replicate state functionality is adapted from FSDP state. In the future,
    we could experiment with inheriting from it instead.
    """

    def __init__(self) -> None:
        super().__init__()
        self._state_ctx = _ReplicateStateContext()  # type: ignore[assignment]

    # Define a separate init since `__init__` is called in the contract
    def init(
        self,
        modules: tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
        auto_reshard_after_forward: bool,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._device_handle = _get_device_handle(device.type)
        self._mp_policy = mp_policy
        self._auto_reshard_after_forward = auto_reshard_after_forward
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle

    def _lazy_init(self) -> None:
        """
        Lazy initialization represents when all modules' parallelisms have
        finalized (e.g. Replicate has been applied to all desired modules).
        This means that we can determine which state is the root, and we do
        so by the 1st state to run forward.
        """
        if self._is_root is not None:
            return  # no-op: already initialized
        self._is_root = True
        if len(self._modules) > 1:
            raise RuntimeError(
                f"Replicate requires a single root module but got {self._modules}"
            )
        detect_compiled_autograd()
        root_module = self._modules[0]
        visited_states: set[_ReplicateState] = set()
        for module_name, module in root_module.named_modules():
            if (state := _get_module_replicate_state(module)) is None:
                continue
            if module is not root_module:
                if state not in visited_states and state._is_root is not None:
                    raise RuntimeError(
                        "Replicate state has already been lazily initialized for "
                        f"{module_name}\nReplicate requires running forward through "
                        "the root module first"
                    )
                state._is_root = False
            self._state_ctx.all_states.append(state)
            visited_states.add(state)
        if self._fsdp_param_group and self._auto_reshard_after_forward:
            # For the root, do not reshard after forward since for training,
            # the parameters would be freed and all-gathered immediately
            self._fsdp_param_group.post_forward_mesh_info = None
        self._init_fqns()
        self._init_shared_state()
        # Run parameter group lazy inits after initializing FQNs for improved
        # error messages
        for state in self._state_ctx.all_states:  # type: ignore[assignment]
            if state._fsdp_param_group:  # type: ignore[union-attr]
                state._fsdp_param_group.lazy_init()  # type: ignore[union-attr]


def replicate_impl(
    module,
    mesh: DeviceMesh,
    *,
    device_id: Optional[Union[int, torch.device]] = None,
    reshard_after_forward: Optional[Union[bool, int]] = None,
    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
    ignored_params: Optional[set[nn.Parameter]] = None,
):
    torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp")
    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
        raise ValueError(
            f"replicate does not support containers that do not implement forward: {module}"
        )

    mesh = mesh or _init_default_fully_shard_mesh()
    if mesh.ndim != 2:
        raise ValueError(f"replicate expects a 2D DeviceMesh but got {mesh}")
    else:
        if mesh.mesh_dim_names is None:
            raise AssertionError(
                "Please init the 2D mesh for HSDP with mesh_dim_names specified"
            )
        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
    device = _get_device_from_mesh(mesh)
    auto_reshard_after_forward = reshard_after_forward is None
    # If the user does not provide ``reshard_after_forward``, we set it to True.
    # During lazy_init, we identify which module is the root and override its value to False
    post_forward_mesh_info = _get_post_forward_mesh_info(
        reshard_after_forward if not auto_reshard_after_forward else True,  # type: ignore[arg-type]
        mesh_info,
    )

    arg_module = module
    modules = (
        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
    )
    state = replicate.state(modules[0])  # type: ignore[attr-defined] # see [1]
    state.init(modules, device, mp_policy, auto_reshard_after_forward)

    managed_modules = _get_managed_modules(modules, ignored_params)
    params, buffers = _get_managed_states(managed_modules, ignored_params)

    _move_states_to_device(params, buffers, device)
    if params:
        state._fsdp_param_group = FSDPParamGroup(
            params,
            modules,
            mesh_info,
            post_forward_mesh_info,
            device,
            shard_placement_fn,
            mp_policy,
            offload_policy,
        )

    # Place Replicate leftmost for highest priority in the method resolution order
    for module in modules:
        cls = module.__class__
        new_cls = cls_to_replicate_cls.get(cls, None)
        if not new_cls:
            dct = {"__deepcopy__": _unimplemented_deepcopy}
            new_cls = type(f"Replicate{cls.__name__}", (FSDPModule, cls), dct)
            cls_to_replicate_cls[cls] = new_cls
        module.__class__ = new_cls
    return arg_module


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module.

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if not is_composable_with_replicate(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    device_mesh = kwargs.pop("device_mesh", None)
    if device_mesh is None:
        device_mesh = replicate_mesh()

    module = replicate_impl(module, mesh=device_mesh, **kwargs)
    return module


def _get_managed_modules(
    root_modules: tuple[nn.Module, ...],
    ignored_params: Optional[set[nn.Parameter]] = None,
) -> list[nn.Module]:
    modules: list[nn.Module] = []
    root_modules_set = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
    visited_modules: set[nn.Module] = set()

    def dfs(module: nn.Module) -> None:
        """
        Runs a DFS to collect managed modules, not recursing into modules with
        a non-composable API or ``replicate`` already applied.
        """
        if not is_composable_with_replicate(module):
            return
        elif (
            module not in root_modules_set
            and _get_module_replicate_state(module) is not None
        ):
            return  # nested `replicate` module
        visited_modules.add(module)
        for submodule in module.children():
            if submodule not in visited_modules:
                dfs(submodule)
        modules.append(module)

    for root_module in root_modules:
        dfs(root_module)

    if ignored_params is None:
        return modules

    adjusted_modules = _adjust_managed_modules(modules, ignored_params)
    return adjusted_modules


def is_composable_with_replicate(module: nn.Module) -> bool:
    """Checks if ``replicate`` can be applied to the module."""
    registry = _get_registry(module)
    if registry is None:
        return True
    # The registry is keyed by function name
    return "fully_shard" not in registry


def replicate_mesh():
    """Creates a device mesh for replicate if the user doesn't provide one"""
    if not dist.distributed_c10d.is_initialized():
        dist.distributed_c10d.init_process_group()
    default_pg = dist.distributed_c10d._get_default_group()
    device = torch._C._get_accelerator()
    mesh = init_device_mesh(
        device.type,
        mesh_shape=(default_pg.size(), 1),
        mesh_dim_names=("replicate", "shard"),
    )
    return mesh


def _adjust_managed_modules(
    modules: list[nn.Module], ignored_params: set[nn.Parameter]
) -> list[nn.Module]:
    """
    Adjust the given list of managed modules by removing those with all parameters ignored.
    """
    ignore_decision: dict[nn.Module, bool] = {}
    new_modules = []
    for module in modules:
        ignored = _ignore_module(module, ignored_params, ignore_decision)
        if not ignored:
            new_modules.append(module)
    return new_modules


def _ignore_module(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    ignore_decision: dict[nn.Module, bool],
) -> bool:
    """
    Decides if it is safe to ignore a module when applying replicate.
    """
    if module in ignore_decision:
        return ignore_decision[module]

    if len(list(module.buffers(recurse=False))) > 0:
        # Cannot ignore a module with any buffers
        ignore_decision[module] = False
        return False

    for _, param in module.named_parameters(recurse=False):
        if param not in ignored_params:
            # At least one param is not ignored, so this module should not be either
            ignore_decision[module] = False
            return False

    # Need to consider descendants of the module
    for child in list(module.children()):
        ignore_child = _ignore_module(child, ignored_params, ignore_decision)
        if not ignore_child:
            # Cannot ignore a module if one of its children is not ignored
            ignore_decision[module] = False
            return False

    # Safe to ignore the module
    ignore_decision[module] = True
    return True