# mypy: allow-untyped-defs
from typing import Any, List, Tuple

import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)

__all__ = []  # type: ignore[var-annotated]
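

# NOTE: The helpers in this file let a tensor-parallelized module interoperate
# with DDP: ``_localize_dtensor`` flattens each DTensor parameter into its
# local shard and stashes the sharding info on ``t._st_info``,
# ``_reconstruct_dtensor`` uses that info to rebuild the DTensor via
# ``_unflatten_tensor``, and ``_pre_dp_module_transform`` wires the two up as
# forward pre-/post-hooks.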


def _get_submodule_n_params(module: nn.Module, path: str):
    """
    Get the submodule and the direct path of the parameter from the module.
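
    For example (illustrative only): given ``path="block.linear.weight"``, this
    returns ``(module.block.linear, "weight")``, i.e. the direct parent module
    and the parameter's attribute name on it.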
    """
    if "." in path:
        path_list = path.split(".")
        parent_module_path = ".".join(path_list[:-1])
        module = module.get_submodule(parent_module_path)
        path = path_list[-1]
    return module, path


def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
    """
    Update parameters within the module.
    """
    for item in param_list:
        parent_module, module_path, t = item
        assert hasattr(parent_module, module_path)
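        # Remove the currently registered attribute first so that the new value
        # (which may not be an nn.Parameter, e.g. a reconstructed DTensor) can
        # be assigned under the same name.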
        delattr(parent_module, module_path)
        setattr(parent_module, module_path, t)


def _reconstruct_dtensor(module: nn.Module, _input: Any):
    """
    Reconstruct DTensor parameters from local tensors.
    """
    param_list = []
    # TODO: Add perf optimizations to this iteration.
    for name, t in module.named_parameters():
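        # ``_st_info`` is attached by ``_localize_dtensor`` below; its presence
        # marks a parameter that holds the local shard of a DTensor.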
        if hasattr(t, "_st_info"):
            dtensor = _unflatten_tensor(t, t._st_info)
            param_list.append((*_get_submodule_n_params(module, name), dtensor))
    _update_module_param(param_list)  # type: ignore[arg-type]


def _localize_dtensor(module: nn.Module, *_: Any):
    """
    Convert DTensor parameters to local tensors.
    """
    param_list = []
    for name, param in module.named_parameters():
        t, sharding_info = _flatten_tensor(param)
        if sharding_info is not None:
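            # Re-wrap the local shard as an nn.Parameter so it stays registered
            # on the module (and remains visible to DDP), and remember the
            # sharding info for later reconstruction in ``_reconstruct_dtensor``.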
            t = nn.Parameter(t)
            t._st_info = sharding_info  # type: ignore[attr-defined]
            param_list.append((*_get_submodule_n_params(module, name), t))
    _update_module_param(param_list)  # type: ignore[arg-type]


def _pre_dp_module_transform(module: nn.Module):
    """
    Enable composability between Tensor Parallelism (TP) and Data
    Parallelism (DP) in PyTorch when using DDP. We need to convert parameters
    that are DTensors to local tensors before wrapping them with the data
    parallelism API. We then register two hooks: one to convert local tensors
    back to DTensors before the forward pass, and one to convert DTensors back
    to local tensors after the forward pass. By integrating this way, we avoid
    any special handling of DTensor parameters by DDP and get DTensor's
    gradients propagated back to DP, e.g. into DDP's gradient buckets.

    For now, this API only works with ``DistributedDataParallel``. It will
    later support other DP methods such as FSDP.

    Args:
        module (:class:`nn.Module`):
            Module which has had TP applied to it.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
        >>>
        >>> # Define the module.
        >>> m = module(...)
        >>> parallelize_module(m, PairwiseParallel())
        >>> _pre_dp_module_transform(m)
        >>> m = DDP(m)
        >>>
    """

    _localize_dtensor(module, None, None)
    # TODO: Add test cases and ensure that it works for nested modules.
    module.register_forward_pre_hook(_reconstruct_dtensor)
    module.register_forward_hook(_localize_dtensor)