pytorch/torch/testing/_internal/common_dist_composable.py
Andrew Gu aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.
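The orderings above can be reproduced with a minimal, torch-free sketch. The `TREE` dictionary and the helper names below are hypothetical stand-ins for the example module tree, not FSDP code; child lists are assumed to mirror `.children()` order.

```python
from typing import Dict, List

# Hypothetical stand-in for the example module tree above.
# Each key maps to its children in `.children()` order.
TREE: Dict[str, List[str]] = {
    "mod": ["submod1", "submod2"],
    "submod1": [],
    "submod2": ["subsubmod"],
    "subsubmod": [],
}


def dfs_left_to_right(root: str) -> List[str]:
    """Non-recursive pre-order DFS that visits children left to right."""
    order: List[str] = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        # Push children reversed so the leftmost child is popped first.
        stack.extend(reversed(TREE[node]))
    return order


def dfs_right_to_left(root: str) -> List[str]:
    """Same traversal, but the rightmost child is popped first."""
    order: List[str] = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(TREE[node])
    return order


def post_order(root: str) -> List[str]:
    """Recursive DFS that records a node after its children, which is how
    a recursive auto-wrap pass would initialize modules."""
    order: List[str] = []
    for child in TREE[root]:
        order.extend(post_order(child))
    order.append(root)
    return order


ltr = dfs_left_to_right("mod")
rtl = dfs_right_to_left("mod")
reverse_ltr = list(reversed(ltr))
wrapper_order = post_order("mod")
```

Reversing `reverse_ltr` once more gives back `ltr`, which is the point of the final bullet: both paths end with handles in left-to-right DFS order.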

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The DFSs may look strange because I implemented them non-recursively, which requires an explicit stack.
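As a rough illustration of a stack-based DFS that does not descend into incompatible subtrees, consider the sketch below. The tree, the `INCOMPATIBLE` set, and the function name are all hypothetical; the actual helpers operate on `nn.Module` objects and their composable-API registry rather than on strings.

```python
from typing import Dict, List, Set

# Hypothetical module tree; "replicated" stands in for a submodule annotated
# with an API that is incompatible with `fully_shard` (e.g. `replicate`).
CHILDREN: Dict[str, List[str]] = {
    "root": ["a", "replicated", "b"],
    "a": ["a0"],
    "a0": [],
    "replicated": ["r0", "r1"],
    "r0": [],
    "r1": [],
    "b": [],
}
INCOMPATIBLE: Set[str] = {"replicated"}


def visit_compatible(root: str) -> List[str]:
    """Left-to-right pre-order DFS that skips incompatible subtrees."""
    visited: List[str] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node in INCOMPATIBLE:
            # Do not record the node and do not push its children,
            # so the whole subtree is pruned.
            continue
        visited.append(node)
        # Reverse so that the leftmost child is processed first.
        stack.extend(reversed(CHILDREN[node]))
    return visited


compatible_order = visit_compatible("root")
```

The remaining nodes keep the same relative order that an unpruned left-to-right DFS would give them, so pruning does not disturb the `.modules()`-style ordering.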

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
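The difference between the two import styles under a cycle can be reproduced with a small, self-contained demo. The module names below (`cycdemo_*`) are hypothetical and the flat two-module layout is a simplification of the real package structure; the sketch only shows the general Python behavior that a module-style import tolerates a partially initialized module while a `from ... import name` does not.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical modules written to a temporary directory.
# Each "registry" module imports its "user" before defining `helper`,
# so the user sees a partially initialized registry module.
SOURCES = {
    "cycdemo_registry": """
        import cycdemo_user_ok

        def helper():
            return "helper"
    """,
    "cycdemo_user_ok": """
        # Module-style import: binds the partially initialized module object.
        import cycdemo_registry as registry

        def use():
            # The attribute is looked up at call time, after both modules load.
            return registry.helper()
    """,
    "cycdemo_registry_bad": """
        import cycdemo_user_bad

        def helper():
            return "helper"
    """,
    "cycdemo_user_bad": """
        # Name-style import: needs `helper` to exist right now, but it does not.
        from cycdemo_registry_bad import helper
    """,
}

with tempfile.TemporaryDirectory() as tmp:
    for name, source in SOURCES.items():
        Path(tmp, f"{name}.py").write_text(textwrap.dedent(source))
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        # The cycle resolves because only the module object is bound early.
        importlib.import_module("cycdemo_registry")
        module_import_result = sys.modules["cycdemo_user_ok"].use()

        # The same cycle fails when a name is pulled out at import time.
        try:
            importlib.import_module("cycdemo_registry_bad")
            from_import_failed = False
        except ImportError:
            from_import_failed = True
    finally:
        sys.path.remove(tmp)
```

Deferring attribute access to call time is what makes the `import ... as traversal_utils` form safe in this kind of cycle.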
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00

107 lines
3.2 KiB
Python

# Owner(s): ["oncall: distributed"]

from typing import Tuple

import torch
import torch.nn as nn


class UnitModule(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.seq(self.l1(x)))


class CompositeModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.u2(self.u1(self.l1(x))))


class UnitParamModule(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        return torch.mm(self.seq(self.l(x)), self.p)


class CompositeParamModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        a = self.u2(self.u1(self.l(x)))
        b = self.p
        return torch.mm(a, b)


class FakeSequential(nn.Module):
    # Define this class to achieve a desired nested wrapping using the module
    # wrap policy with `nn.Sequential`
    def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
        super().__init__()
        self._module_sequence = list(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self._module_sequence:
            x = module(x)
        return x


class NestedSequentialModel(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # This nested structure exercises traversal order to catch differences
        # between valid traversals (e.g. BFS and DFS variations).
        self.seq1 = nn.Sequential(
            nn.Linear(1, 1, device=device),
            FakeSequential(
                nn.Linear(1, 1, device=device),
                nn.ReLU(),
                FakeSequential(
                    nn.Linear(1, 1, device=device),
                ),
                nn.ReLU(),
            ),
            nn.Linear(1, 2, device=device),
        )
        self.lin = nn.Linear(2, 2, device=device)
        self.seq2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2, 3, device=device),
            FakeSequential(
                nn.Linear(3, 2, bias=False, device=device),
                nn.Linear(2, 4, bias=False, device=device),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.lin(self.seq1(x)))