This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior (see the sketch below).
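A minimal usage sketch of that behavior, assuming the composable `fully_shard` and `replicate` APIs from `torch.distributed._composable` (import paths may differ across releases) and an already-initialized process group:
```
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

# Illustrative only: assumes torch.distributed.init_process_group() has
# already been called and a default process group is available.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
replicate(model[1])  # this submodule's subtree is replicated (DDP-style)
fully_shard(model)   # shards the rest; the `replicate` subtree is ignored
```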
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference in initialization order between the two paths.
```
mod: Module(
submod1: Submodule()
submod2: Submodule(
subsubmod: Subsubmodule(),
),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`; left-to-right vs. right-to-left corresponds to visiting `.children()` vs. `reversed(.children())`, respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`. Reversing the reverse left-to-right DFS order recovers the left-to-right DFS order, so `state._handles` ends up in the same order for both paths (see the sketch below).
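A toy illustration of that double-reversal argument, reusing the module names from the example above (the `handle(...)` strings are stand-ins, not real flat-parameter handles):
```
dfs_order = ["mod", "submod1", "submod2", "subsubmod"]  # left-to-right DFS

handles = []
for module in reversed(dfs_order):  # composable init: reverse left-to-right DFS
    handles.append(f"handle({module})")
handles.reverse()  # the final reversal of `state._handles`

assert handles == [f"handle({m})" for m in dfs_order]  # left-to-right DFS order
```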
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.
The DFS implementations may look unusual because they are written non-recursively, which requires an explicit stack.
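A minimal sketch of such a stack-based left-to-right DFS, with a hypothetical `should_traverse` predicate standing in for the real compatibility check (e.g. skipping `replicate`-annotated subtrees); this is illustrative, not the actual `_traversal_utils` implementation:
```
from typing import Callable, List

import torch.nn as nn


def left_to_right_dfs(
    root: nn.Module,
    should_traverse: Callable[[nn.Module], bool] = lambda m: True,
) -> List[nn.Module]:
    # Preorder DFS with an explicit stack. Pushing children in reverse means
    # the leftmost child is popped first, reproducing `root.modules()` order
    # (for models without shared submodules).
    order: List[nn.Module] = []
    stack = [root]
    while stack:
        module = stack.pop()
        if not should_traverse(module):
            continue  # skip this entire subtree
        order.append(module)
        stack.extend(reversed(list(module.children())))
    return order
```
For the example structure above, this visits `mod`, `submod1`, `submod2`, `subsubmod`.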
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which, when imported, imports `composable/fully_shard.py`, which in turn requires the traversals.
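The whole-module import defers name resolution to call time, which is what breaks the cycle. A minimal sketch of the pattern; `some_fsdp_helper` is hypothetical, while `_get_fsdp_states` is one of the moved traversal helpers:
```
# `from torch.distributed.fsdp._traversal_utils import _get_fsdp_states` would
# need `_get_fsdp_states` to already exist when this module is imported, which
# fails if `_traversal_utils` is only partially initialized mid-cycle.
import torch.distributed.fsdp._traversal_utils as traversal_utils


def some_fsdp_helper(module):  # hypothetical caller, for illustration only
    # Attribute lookup happens at call time, after every module in the cycle
    # has finished importing.
    return traversal_utils._get_fsdp_states(module)
```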
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
# Owner(s): ["oncall: distributed"]

from typing import Tuple

import torch
import torch.nn as nn


class UnitModule(nn.Module):
    # A small building block: linear -> (ReLU, linear, ReLU) -> linear
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.seq(self.l1(x)))


class CompositeModel(nn.Module):
    # Composes two `UnitModule`s between two linear layers
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.u2(self.u1(self.l1(x))))


class UnitParamModule(nn.Module):
    # Similar to `UnitModule`, but with a directly registered parameter `p`
    # instead of a second output linear layer
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        return torch.mm(self.seq(self.l(x)), self.p)


class CompositeParamModel(nn.Module):
    # Similar to `CompositeModel`, but with a directly registered parameter `p`
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        a = self.u2(self.u1(self.l(x)))
        b = self.p
        return torch.mm(a, b)


class FakeSequential(nn.Module):
    # Define this class to achieve a desired nested wrapping using the module
    # wrap policy with `nn.Sequential`
    def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
        super().__init__()
        self._module_sequence = list(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self._module_sequence:
            x = module(x)
        return x


class NestedSequentialModel(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # This nested structure exercises traversal order to catch differences
        # between valid traversals (e.g. BFS and DFS variations).
        self.seq1 = nn.Sequential(
            nn.Linear(1, 1, device=device),
            FakeSequential(
                nn.Linear(1, 1, device=device),
                nn.ReLU(),
                FakeSequential(
                    nn.Linear(1, 1, device=device),
                ),
                nn.ReLU(),
            ),
            nn.Linear(1, 2, device=device),
        )
        self.lin = nn.Linear(2, 2, device=device)
        self.seq2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2, 3, device=device),
            FakeSequential(
                nn.Linear(3, 2, bias=False, device=device),
                nn.Linear(2, 4, bias=False, device=device),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.lin(self.seq1(x)))