Files
pytorch/torch/distributed/tensor/parallel/input_reshard.py
Nikita Shulga 5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Both were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e., cases where a default argument is set to None but the type is not annotated as Optional).
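As a minimal illustration of that pattern (hypothetical function, not code from the PR), the fix is to spell the implicit Optional out in the annotation:

from typing import Optional

# Before (implicit Optional, a PEP 484 violation flagged by newer mypy):
#   def resolve(timeout: float = None) -> float: ...

# After: the None default is reflected in the annotation.
def resolve(timeout: Optional[float] = None) -> float:
    return 30.0 if timeout is None else timeout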
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` that squashes the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00

107 lines
3.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional, Tuple

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

__all__ = [
    "input_reshard",
]


def input_reshard(
    module: torch.nn.Module,
    tp_device_mesh: DeviceMesh,
    input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
"""
Register hooks to an nn.Module with input resharding so that we can shard
per the given `tp_device_mesh` and `input_reshard_dim` and restore the
input back when recomputing the activations in the backward. The reason
why we can do this is that for Tensor Parallel(TP), the input are same
across all TP ranks.
Args:
module (:class:`nn.Module`):
Module to be registered with input resharding.
tp_device_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for Tensor Parallel.
input_reshard_dim (Optional[int]):
The dimension of where we perform the sharding
of input. If set None, there is no sharding of input.
Default: None
Return:
A :class:`nn.Module` object registered with TP input resharding.
"""
    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None

    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
        # Enter a saved_tensors_hooks context before the forward pass so that
        # tensors saved for backward are sharded on save (pack) and restored
        # on use (unpack).
        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
        )
        saved_tensor_hooks.__enter__()
        nonlocal cx
        cx = saved_tensor_hooks  # type: ignore[name-defined]

    def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
        # Exit the saved_tensors_hooks context once the forward pass is done.
        nonlocal cx
        cx.__exit__()  # type: ignore[name-defined, union-attr]

    if input_reshard_dim is None:
        return module
    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
    module.register_forward_hook(input_reshard_backward_hook)
    return module


def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:
    """
    Hook functions called after FWD to shard input.
    """
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(x, device_mesh=mesh)
            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
            .to_local()
        )
    else:
        return x


def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:
    """
    Hook functions called before activation recomputing in BWD to restore input.
    """
    if (
        isinstance(x, DTensor)
        and len(x._spec.placements) == 1
        and x._spec.placements[0].is_shard()
    ):
        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(
                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
            )
            .redistribute(device_mesh=mesh, placements=[Replicate()])
            .to_local()
        )
    else:
        return x
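
A minimal usage sketch of the hooks above, assuming a script launched with torchrun on a single node with a 1-D mesh over all ranks; the module, sizes, and sharding dim are illustrative, not taken from the file:

# Hypothetical usage sketch (not part of the file above).
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel.input_reshard import input_reshard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())  # single-node setup assumed

# 1-D device mesh spanning all ranks.
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

mlp = torch.nn.Sequential(
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)
).cuda()

# Register the pack/unpack hooks: tensors saved for backward are sharded along
# dim 0 and replicated again when activations are needed in backward.
mlp = input_reshard(mlp, mesh, input_reshard_dim=0)

out = mlp(torch.randn(8, 16, device="cuda"))
out.sum().backward()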