Commit Graph

22 Commits

Author SHA1 Message Date
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
57fba6fd86 [FSDP][9/N] Introduce CustomPolicy (#104986)
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
2023-08-03 12:46:36 +00:00
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
a8c52863dd [FSDP][6/N] Check valid param freezing for ModuleWrapPolicy (#104427)
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.

<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104427
Approved by: https://github.com/ezyang
2023-08-02 21:44:44 +00:00
b65b9e6ff4 [PT][FSDP] Combine _utils.py into _common_utils.py [1/3] (#105857)
Summary:
https://github.com/pytorch/pytorch/issues/97813

This diffs moves `_override_module_mixed_precision`

Test Plan: CI

Differential Revision: D47706059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105857
Approved by: https://github.com/awgu
2023-07-25 17:37:08 +00:00
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3dd4e66059522bf5f5c1ba0431e2069.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262f82bbc76aa776119c9fea079fbffe3.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
610f74627e [FSDP][4/N] Remove _get_fully_sharded_module_to_states (#104409)
`_get_fully_sharded_module_to_states()` was used to emulate auto wrapping without actually calling `fully_shard`. Since we committed to unifying (see previous PR), we can remove this function and its helpers/tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104409
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:14 +00:00
6d71b4f9f1 [FSDP][2/N][Easy] Prepare _auto_wrap for fully_shard (#104407)
This mainly just changes the `_auto_wrap()` function signature and generalizes the `_check_nested_wrapping()` to both wrapper and composable paths (though the composable path will not hit in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104407
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:09 +00:00
d58f75be8b [FSDP][1/N] Move wrapper ModuleWrapPolicy to new path (#104346)
This PR is the first in refactoring the auto wrapping, only affecting `ModuleWrapPolicy` for wrapper `FullyShardedDataParallel`. The end goal is to improve the auto wrapping infra to support:
- Checking valid frozen parameters (uniform frozenness per FSDP)
- Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
- Writing auto wrapping policies that may take multiple passes over the module tree
- Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)

The way I envision achieving this is that, we decouple the actual "wrapping" (which is `_post_order_apply()` in this PR) from constructing the wrapping targets and kwargs (which is `target_module_to_kwargs` in this PR). In that way, a policy reduces to just constructing that latter `target_module_to_kwargs` mapping.

I do not personally recommend the size-based policy, but if we wanted to implement that under this new organization, the tracking of wrapped/nonwrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.

The change to how `old_dtype` is handled is mainly to avoid keeping a reference to `_override_module_mixed_precision()` function closure in each hook and to allow the function to take in all module clases at once to return which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104346
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:07 +00:00
f3e42f15e9 [FSDP] Start to generalize modules to ignore for mixed precision (#102010)
The main use case here is that folks would like to ignore layer norm for mixed precision. This can now be enabled with:

```
mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
        )
```

This is done by classes of types in `_mixed_precision_module_classes_to_ignore` being wrapped in their own FSDP unit with mixed preicsion disabled. This is only enabled for auto wrapping.

We also add module pre and post hooks to cast / downcast inputs to the appropriate full precision.

Differential Revision: [D46079957](https://our.internmc.facebook.com/intern/diff/D46079957/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102010
Approved by: https://github.com/awgu
2023-05-25 00:45:54 +00:00
5ee230face [FSDP][1/N] Refactor module materialization (#94196)
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameters/buffers to group together by those variables and because materialization may create new variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
2023-02-13 21:43:00 +00:00
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
32fde53713 [FSDP][5/N] Add manual "wrapping" support for fully_shard (#90874)
This PR adds manual "wrapping" support for `fully_shard`. For example, for
```
fully_shard(mod.sub)
fully_shard(mod)
```
`mod.sub` and `mod` will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for when each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise and error early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90874
Approved by: https://github.com/mrshenli
2022-12-20 16:49:15 +00:00
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded module **in all cases** are `mod`, `mod.sub1`, `mod.sub2`, and notably, `subsub1` and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
7a08261a9c Fix fully_shard error when policy is not provided (#90151)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90151
Approved by: https://github.com/awgu
2022-12-05 15:21:47 +00:00
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap"  suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
c1e28731b3 [FSDP()][10/N][11/N] Introduce composable (ctor only) (#87924)
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
78170701a3 [FSDP()][9/N] Refactor ctor (continued) (#87923)
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls.

This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
2022-11-01 12:39:21 +00:00