# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional

import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard


__all__ = [
    "input_reshard",
]


def input_reshard(
    module: torch.nn.Module,
    tp_device_mesh: DeviceMesh,
    input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
    """
    Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.

    Register hooks to an nn.Module so that the input is sharded per the given
    `tp_device_mesh` and `input_reshard_dim`, and restored when the activations
    are recomputed in the backward pass. We can do this because, for Tensor
    Parallel (TP), the inputs are the same across all TP ranks.

    Args:
        module (:class:`nn.Module`):
            Module to be registered with input resharding.
        tp_device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for Tensor Parallel.
        input_reshard_dim (Optional[int]):
            The dimension along which the input is sharded.
            If set to None, the input is not sharded.
            Default: None

    Returns:
        A :class:`nn.Module` object registered with TP input resharding.
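
    .. note:: The example below is a minimal sketch, not a prescribed recipe:
        ``block``, ``inp``, and the 8-GPU mesh are illustrative placeholders,
        and activation checkpointing is shown as one way to recompute the
        activations in the backward pass as described above.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.utils.checkpoint import checkpoint
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> block = input_reshard(block, tp_mesh, input_reshard_dim=0)
        >>> out = checkpoint(block, inp, use_reentrant=False)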
    """
    if input_reshard_dim is None:
        return module

    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None

    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None:
        # Enter a saved_tensors_hooks context right before this module's forward,
        # so every tensor saved for backward inside the forward is routed through
        # _pack_hook_tp when saved and _unpack_hook_tp when used in backward.
        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
        )
        saved_tensor_hooks.__enter__()
        nonlocal cx
        cx = saved_tensor_hooks  # type: ignore[name-defined]

    def input_reshard_backward_hook(
        _: torch.nn.Module, _i: tuple[Any, ...], _o: Any
    ) -> Any:
        # Exit the saved_tensors_hooks context once the forward finishes, so
        # tensors saved outside this module are left untouched.
        nonlocal cx
        cx.__exit__()  # type: ignore[name-defined, union-attr]

    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
    module.register_forward_hook(input_reshard_backward_hook)
    return module
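
# NOTE: illustrative sketch only. The pre-hook / forward-hook pair registered
# above is roughly equivalent to wrapping the module's forward in an explicit
# saved_tensors_hooks context (`inp` / `out` are placeholder names):
#
#   with torch.autograd.graph.saved_tensors_hooks(
#       partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
#       partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
#   ):
#       out = module(inp)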


def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:  # noqa: D401
    """Hook function called during FWD to shard the input saved for backward."""
    # A replicated DTensor input: reshard it so each rank keeps only its chunk
    # along `input_reshard_dim`.
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
    # A plain tensor that is large enough to shard: treat it as replicated
    # across the mesh (TP inputs are identical on all ranks), reshard, and
    # save only the local shard.
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(x, device_mesh=mesh)
            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
            .to_local()
        )
    # Anything else (e.g. a small tensor) is saved as-is.
    else:
        return x


def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:  # noqa: D401
    """Hook function called in BWD, before activation recomputation, to restore the input."""
    # A DTensor sharded along a single dim: all-gather it back to replicated.
    if (
        isinstance(x, DTensor)
        and len(x._spec.placements) == 1
        and x._spec.placements[0].is_shard()
    ):
        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
    # A plain local shard produced by _pack_hook_tp: wrap it as a sharded
    # DTensor, all-gather to replicated, and return the full local tensor.
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(
                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
            )
            .redistribute(device_mesh=mesh, placements=[Replicate()])
            .to_local()
        )
    # Anything that was saved unsharded is returned unchanged.
    else:
        return x
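
# A minimal sketch of the round trip the two hooks above perform on a plain
# local tensor (illustrative only: assumes an already-initialized 1-D `mesh`,
# an int `dim`, and an `activation` tensor that is identical on every TP rank):
#
#   saved = _pack_hook_tp(mesh, dim, activation)   # keep only this rank's chunk
#   restored = _unpack_hook_tp(mesh, dim, saved)   # all-gather the full tensor
#
# Between the forward pass and the backward recomputation, each rank therefore
# holds roughly 1 / mesh.size() of each input it shards.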