This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

which were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP 484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional), plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert in `torch/optim/optimizer.py` that the Optional list is not None

TODO (in a follow-up PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` that squashes the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link cupti statically, but I could not find where this is done)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
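Most of the typing changes follow one pattern. A hypothetical sketch of the PEP 484 violation being fixed (the function name is made up for illustration):

```python
from typing import Optional

# Before the fix: the default is None but the annotation does not allow it
# (an implicit Optional, which newer mypy versions reject by default).
#
#     def load_model(device: str = None) -> None: ...
#
# After the fix: the annotation is explicitly Optional.
def load_model(device: Optional[str] = None) -> None:
    ...
```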
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional, Tuple

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

__all__ = [
    "input_reshard",
]


def input_reshard(
    module: torch.nn.Module,
    tp_device_mesh: DeviceMesh,
    input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
    """
    Register hooks on an nn.Module so that its input is resharded along
    ``input_reshard_dim`` over the given ``tp_device_mesh`` during the forward
    pass and restored when the activations are recomputed in the backward pass.
    We can do this because, for Tensor Parallel (TP), the input is identical
    across all TP ranks.

    Args:
        module (:class:`nn.Module`):
            Module to be registered with input resharding.
        tp_device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices
            for Tensor Parallel.
        input_reshard_dim (Optional[int]):
            The dimension along which the input is sharded. If set to None,
            the input is not sharded.
            Default: None

    Returns:
        A :class:`nn.Module` object registered with TP input resharding.
    """
    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None

    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
        # Enter a saved-tensors-hooks context so that tensors saved for backward
        # are sharded when packed and restored when unpacked.
        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
        )
        saved_tensor_hooks.__enter__()
        nonlocal cx
        cx = saved_tensor_hooks  # type: ignore[name-defined]

    def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
        # Exit the saved-tensors-hooks context entered by the pre-forward hook.
        nonlocal cx
        cx.__exit__()  # type: ignore[name-defined, union-attr]

    if input_reshard_dim is None:
        return module
    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
    module.register_forward_hook(input_reshard_backward_hook)
    return module

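
# A minimal usage sketch (illustrative only; ``world_size`` and ``MyTPModule``
# are hypothetical). Tensors saved for backward inside ``model.forward`` are
# kept sharded along dim 0 of a 1-D TP mesh and restored to full size only
# when they are needed again in backward (e.g. during activation recomputation):
#
#     mesh = DeviceMesh("cuda", list(range(world_size)))
#     model = MyTPModule()
#     model = input_reshard(model, tp_device_mesh=mesh, input_reshard_dim=0)

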
def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:
    """
    Hook function called after FWD to shard the input saved for backward.
    """
    # Replicated DTensor: reshard it along ``input_reshard_dim``.
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
    # Plain tensor large enough to shard: wrap it as a replicated DTensor,
    # reshard it, and save only the local shard.
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(x, device_mesh=mesh)
            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
            .to_local()
        )
    else:
        return x

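
# Shape sketch for the plain-tensor branch above (hypothetical 1-D mesh of 4
# ranks, ``input_reshard_dim=0``); every rank starts with the same local tensor:
#
#     x = torch.randn(8, 16)                        # replicated input on each rank
#     packed = _pack_hook_tp(mesh, 0, x)            # local shard of shape (2, 16)
#     restored = _unpack_hook_tp(mesh, 0, packed)   # full (8, 16) tensor again

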
def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:
    """
    Hook function called before activation recomputation in BWD to restore the input.
    """
    # Sharded DTensor: restore it to a replicated DTensor.
    if (
        isinstance(x, DTensor)
        and len(x._spec.placements) == 1
        and x._spec.placements[0].is_shard()
    ):
        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
    # Plain local shard produced by the pack hook: wrap it as a sharded DTensor,
    # replicate it, and return the full local tensor.
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(
                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
            )
            .redistribute(device_mesh=mesh, placements=[Replicate()])
            .to_local()
        )
    else:
        return x
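
# A standalone sketch of the saved-tensors-hooks mechanism used by
# ``input_reshard`` (single process, no sharding; ``_pack``/``_unpack`` are
# hypothetical stand-ins for ``_pack_hook_tp``/``_unpack_hook_tp``):
#
#     def _pack(t: torch.Tensor) -> torch.Tensor:
#         return t.clone()   # called when autograd saves ``t`` for backward
#
#     def _unpack(t: torch.Tensor) -> torch.Tensor:
#         return t           # called when backward needs ``t`` again
#
#     a = torch.randn(4, requires_grad=True)
#     with torch.autograd.graph.saved_tensors_hooks(_pack, _unpack):
#         loss = (a * a).sum()   # ``a`` is packed here
#     loss.backward()            # ``a`` is unpacked here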