Commit Graph

38 Commits

a8c528c105 [1/N] Apply UP035 rule in tests (#163947)
Apply the UP035 `ruff` rule in tests, excluding some `fx` and `dynamo` tests where the old typing constructs are themselves the test target.
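
For illustration (not from the PR), the kind of rewrite UP035 (deprecated `typing` imports) asks for:
```
# Before: UP035 flags deprecated aliases imported from `typing`.
from typing import Dict, List

def summarize(xs: List[int]) -> Dict[str, int]:
    return {"n": len(xs), "total": sum(xs)}

# After: use builtin generics (Python 3.9+) and drop the deprecated import.
def summarize_fixed(xs: list[int]) -> dict[str, int]:
    return {"n": len(xs), "total": sum(xs)}
```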

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
2025-09-29 01:42:01 +00:00
c6392fcc06 [2/N] Port 3 fsdp distributed test cases to Intel GPU (#160940)
For https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This is the second PR for FSDP distributed test cases; the first is https://github.com/pytorch/pytorch/pull/160158.
We enable Intel GPU with the following methods while keeping the original code style as much as possible (a minimal sketch follows the list):
- Use `torch.accelerator.current_accelerator()` to determine the accelerator backend
- Enable XPU for some test paths
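
A minimal sketch (not the PR's code) of the device selection this enables:
```
import torch

# Pick the available accelerator instead of hard-coding CUDA; fall back to CPU.
acc = torch.accelerator.current_accelerator()  # returns a torch.device or None
device_type = acc.type if acc is not None else "cpu"  # e.g. "cuda" or "xpu"
device = torch.device(device_type)
print(f"running the ported FSDP tests on {device}")
```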

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160940
Approved by: https://github.com/guangyey, https://github.com/d4l3k
2025-09-17 10:45:28 +00:00
12112fd198 Fix bug in FSDP wrapped module with zero argument (#147771)
Fixes https://github.com/pytorch/pytorch/issues/147531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147771
Approved by: https://github.com/awgu
2025-02-26 01:40:53 +00:00
22a4129a76 Generalization of FSDP common for non-cuda execution (#133209)
## Motivation
The FSDP common code used for FSDP unit test execution is mostly written with CUDA devices in mind. However, other devices such as Intel Gaudi support most of the functionality. We are generalizing the base content so that the unit test content can be used for non-CUDA device execution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133209
Approved by: https://github.com/kwen2501
2024-09-27 00:38:10 +00:00
920f0426ae Add None return type to init -- tests rest (#132376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132376
Approved by: https://github.com/jamesjwu
ghstack dependencies: #132335, #132351, #132352
2024-08-01 15:44:51 +00:00
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods (`update`, set difference, etc.) instead of manually reimplementing them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
8810fdd21e fsdp: Unit test for ModuleWrapPolicy as a Callable (#117395)
We use `_or_policy` to wrap a `ModuleWrapPolicy` instance so that it can be used as a `Callable`.

Fixes https://github.com/pytorch/pytorch/issues/109266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117395
Approved by: https://github.com/wconstab
2024-01-25 17:40:06 +00:00
2a87ab4508 Refactor some tests by using TEST_CUDA & TEST_MULTIGPU instead (#116083)
As noted in https://github.com/pytorch/pytorch/pull/116014#discussion_r1430510759, refactor the related tests to use `TEST_CUDA` and `TEST_MULTIGPU`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116083
Approved by: https://github.com/fduwjj
2024-01-03 08:53:59 +00:00
57fba6fd86 [FSDP][9/N] Introduce CustomPolicy (#104986)
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
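
For instance, a hypothetical `lambda_fn` (the module classes are illustrative) exercising each of the return cases above:
```
from typing import Any, Dict, Union

import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    if isinstance(module, nn.TransformerEncoderLayer):
        return True  # wrap, inheriting the root's FSDP kwargs
    if isinstance(module, nn.Embedding):
        # wrap, overriding this kwarg and inheriting the rest from the root
        return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
    return False  # no wrapping


policy = CustomPolicy(lambda_fn)
```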

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
2023-08-03 12:46:36 +00:00
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
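
A minimal sketch of the `_run_policy` contract described above (the exact signature and internals in the PR may differ):
```
from abc import ABC, abstractmethod
from typing import Any, Dict, Set, Type

import torch.nn as nn


class _Policy(ABC):
    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """Return a mapping from each module to wrap to its wrapping kwargs."""
        ...


class ModuleWrapPolicy(_Policy):
    def __init__(self, module_classes: Set[Type[nn.Module]]) -> None:
        self._module_classes = tuple(module_classes)

    def _run_policy(self, root_module, ignored_modules, root_kwargs):
        return {
            module: dict(root_kwargs)
            for module in root_module.modules()
            if isinstance(module, self._module_classes)
            and module not in ignored_modules
        }
```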

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
a8c52863dd [FSDP][6/N] Check valid param freezing for ModuleWrapPolicy (#104427)
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.

<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, whereas with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.

</details>
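
A minimal sketch (not the PR's exact code) of the stack-based DFS via `named_children()`; reversing its pre-order output yields the first ordering above:
```
from typing import List, Tuple

import torch.nn as nn


def dfs_named_modules(root: nn.Module) -> List[Tuple[str, nn.Module]]:
    # Iterative pre-order DFS using named_children(). Reversing the returned
    # list puts children before parents with siblings in registered order,
    # matching the `named_children()`-based ordering shown above.
    visited: List[Tuple[str, nn.Module]] = []
    stack: List[Tuple[str, nn.Module]] = [("", root)]
    while stack:
        name, module = stack.pop()
        visited.append((name, module))
        for child_name, child in module.named_children():
            stack.append((f"{name}.{child_name}" if name else child_name, child))
    return visited  # iterate reversed(visited) for the reverse topological order
```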

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104427
Approved by: https://github.com/ezyang
2023-08-02 21:44:44 +00:00
6d71b4f9f1 [FSDP][2/N][Easy] Prepare _auto_wrap for fully_shard (#104407)
This mainly just changes the `_auto_wrap()` function signature and generalizes the `_check_nested_wrapping()` to both wrapper and composable paths (though the composable path will not hit in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104407
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:09 +00:00
f3e42f15e9 [FSDP] Start to generalize modules to ignore for mixed precision (#102010)
The main use case here is that folks would like to ignore layer norm for mixed precision. This can now be enabled with:

```
mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
    _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
)
```

This is done by wrapping modules whose classes are listed in `_mixed_precision_module_classes_to_ignore` in their own FSDP unit with mixed precision disabled. This is only enabled for auto wrapping.

We also add module pre- and post-forward hooks to cast inputs up to full precision and downcast back afterward.
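
A hedged sketch (names and exact hook plumbing are assumptions, not the PR's code) of casting inputs for an ignored module and downcasting afterward:
```
import torch
import torch.nn as nn


def register_ignored_module_cast_hooks(module: nn.Module, low_prec_dtype: torch.dtype):
    def cast(x, dtype):
        if torch.is_tensor(x) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    # Pre-forward: upcast floating-point inputs to full precision.
    module.register_forward_pre_hook(
        lambda mod, inputs: tuple(cast(x, torch.float32) for x in inputs)
    )
    # Post-forward: downcast the output back to the low-precision dtype.
    module.register_forward_hook(
        lambda mod, inputs, output: cast(output, low_prec_dtype)
    )
```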

Differential Revision: [D46079957](https://our.internmc.facebook.com/intern/diff/D46079957/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102010
Approved by: https://github.com/awgu
2023-05-25 00:45:54 +00:00
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except one internal benchmark file, which does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
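
A minimal sketch of that layering (the real classes live in `torch.distributed.fsdp.wrap`; details may differ):
```
import functools
from abc import ABC, abstractmethod
from typing import Callable, Iterable, Type

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class FSDPPolicy(ABC):
    """Abstract base: concrete policies expose a `_recursive_wrap`-compatible callable."""

    @property
    @abstractmethod
    def policy(self) -> Callable:
        ...


class ModuleWrapPolicy(FSDPPolicy):
    def __init__(self, module_classes: Iterable[Type[nn.Module]]) -> None:
        # Bind the module classes so callers never deal with `recurse` or
        # `nonwrapped_numel` directly.
        self._policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=set(module_classes),
        )

    @property
    def policy(self) -> Callable:
        return self._policy
```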

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
107f92a683 [FSDP] ufmt FSDP test (#87812)
This applies `ufmt` to all of the FSDP test files in the `test/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users,
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87812
Approved by: https://github.com/rohan-varma
2022-10-27 04:25:55 +00:00
125e9256f4 [FSDP] Add back forward_prefetch (#85177)
- This implements explicit forward prefetching following the static 1st iteration's pre-forward order when `forward_prefetch=True` in the FSDP constructor.
- This has the same unit test coverage as the original `forward_prefetch`.
- I checked via print statements that the prefetches are happening, but since I cannot get a good CPU-bound workload, it is hard to tell via traces that the prefetch is working.
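
Hedged usage sketch (assumes a process group is already initialized; the model choice is illustrative):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Transformer().cuda()
# Prefetch the next layer's all-gather following the first iteration's
# recorded pre-forward order.
fsdp_model = FSDP(model, forward_prefetch=True)
```
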
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85177
Approved by: https://github.com/zhaojuanmao
2022-09-21 14:40:37 +00:00
afcc7c7f5c [FSDP] Generalize prefetching; lower unshard/reshard to handle (#83665)
### Additional Constructor Changes
- `self.sharding_strategy`
    - If the world size is 1, I clamp the sharding strategy to `NO_SHARD`, regardless of the passed-in sharding strategy, since the behavior is fully equivalent. This obviates the need for `p._is_sharded or self.world_size == 1` checks in the core code. Once we fully shift the paradigm to using handles, this should result in a clear net positive. However, for now, we still have some places where we interface directly with the `FlatParameter`, in which case we have some temporary hacky code.
- `HandleConfig`
    - As a part of the new design abstraction, much logic is lowered to the `FlatParamHandle`. This requires that the handle be aware of mixed precision, CPU offloading, sharding strategy, and the process group (for world size > 1). To be less error-prone, I re-defined the `dataclass`es and `enum`s for the handle. These can be removed and coalesced with the existing ones.
    - The drawback is that the `FlattenParamsWrapper` constructor now takes in the `HandleConfig` to forward it to the `FlatParamHandle` constructor. I tolerate this since we plan to retire the FPW. For now, the handle's process group attributes are set later when we call `handle.shard()`.
    - We will dive into this logic lowering later. For now, the idea is we need to pass some extra info to the handle, which must go through the FPW.
- `FullyShardedDataParallel._shard_parameters()` -> `FlatParamHandle.shard()`
- [Important] Generalizing attributes to remove the 1 `FullyShardedDataParallel` : 1 `FlatParameter` assumption
    - **Before:** `_fsdp_graph_order`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`, `reshard_after_forward` are with respect to 1 `FullyShardedDataParallel`
    - **After:** (1) We use `FlatParamHandle` in place of `FullyShardedDataParallel`. (2) The atomic unit for forward and pre-backward is a _group_ of handles involved in the same module's forward/pre-backward. This is represented as `Tuple[FlatParamHandle, ...]`. For now, this is **always a singleton tuple**, but this shift enables a module having multiple FSDP parameters (which we have use cases for).
- `_reset_lazy_init()` attributes
    - The prefetched flags are merged into `self._handles_prefetched`, which is directly defined in the constructor. `reshard_after_forward` is retired since it can be fully determined by other attributes (`_is_root` and `sharding_strategy`).

## FSDP Runtime: Unshard

The first step is to read the existing `_rebuild_full_params()`. A few notable observations:
- It returns `Tuple[Tensor, bool]`. The first element is the _padded unsharded flattened parameter_, and the second element is whether we can free it upon exiting `summon_full_params()`. This return value is **only used in `summon_full_params()`**.
- If parameter mixed precision is enabled and the `FlatParameter` is already unsharded, then the low precision shard (`_mp_shard`) is still re-allocated on GPU. (It is freed at the end of the method.)
- If CPU offloading is enabled and the `FlatParameter` is already unsharded, then there is a no-op `p.data = p.data.to(self.compute_device, non_blocking=True)`.
- Inside `summon_full_params()`, `mixed_precision_cast_ran` is always `False`. Therefore, the return value for the `not p._is_sharded and mixed_precision_cast_ran` branch is unused.
- `summon_full_params()` can only be called (before forward or after backward) or (between forward and backward). Given this, I cannot think of a case where we call `summon_full_params()`, the `FlatParameter` is already unsharded, but `reshard_after_forward` is `True`. The `FlatParameter` should be sharded (before forward or after backward), and the `FlatParameter` may only be unsharded (between forward and backward) if `reshard_after_forward` is `False`.
- If parameter mixed precision is enabled and the sharding strategy is a sharded one, then inside `summon_full_params()`, the `FlatParameter` is unsharded in full precision. This involves allocating a new padded unsharded flattened parameter on GPU in full precision since `_full_param_padded` is in the low precision.

Some comments:
- Ideally, we reduce the complexity of the core code path: i.e. unshard for forward and pre-backward. If the return value is only used for `summon_full_params()`, we should consider if we can compartmentalize that logic.
- The branching is complex, and some return values are never used, where this fact is not immediately obvious. We should see if we can reduce the branch complexity.

Disclaimer: The difference in attribute semantics between `NO_SHARD` and the sharded strategies makes it challenging to unify the cases. This PR does not attempt to address that since it requires more design thought. However, it does attempt to reduce the complexity for the sharded strategies.

### Unshard: Core Code Path
Let us trace through the new logical unshard.
1. `FullyShardedDataParallel._unshard(self, handles: List[FlatParamHandle], prepare_gradient: bool)`
    - This iterates over the handles and calls `handle.pre_unshard()`, `handle.unshard()`, and `handle.post_unshard(prepare_gradient)` in the all-gather stream.
2. `FlatParamHandle.needs_unshard(self)`
    - We take an aside to look at this key subroutine.
    - For `NO_SHARD`, this returns `False`.
    - For sharded strategies, this checks if the padded unsharded flattened parameter is allocated. The padded unsharded flattened parameter is the base tensor for the unpadded unsharded flattened parameter, which is a view into the padded one. Thus, the padded one's allocation fully determines if the `FlatParameter` is unsharded.
    - For sharded strategies, to accommodate the parameter mixed precision + `summon_full_params()` case, we introduce `_full_prec_full_param_padded`, which is the padded unsharded flattened parameter in full precision. The helper `_get_padded_unsharded_flat_param()` takes care of this casing and returns the padded unsharded flattened parameter. Instead of allocating a new tensor each time, we manually manage `_full_prec_full_param_padded`'s storage just like for `_full_param_padded`.
3. `FlatParamHandle.pre_unshard(self)`
    - For sharded strategies, the postcondition is that the handle's `FlatParameter` points to the tensor to all-gather. This should be on the communication device and in the desired precision. The allocation and usage of the low precision shard for parameter mixed precision and the CPU -> GPU copy for CPU offloading both classify naturally in the pre-unshard.
    - For sharded strategies, if the `FlatParameter` does not need to be unsharded, `pre_unshard()` is a no-op. This avoids unnecessarily allocating and freeing the low precision shard.
    - For `NO_SHARD`, we simply preserve the existing semantics.
4. `FlatParamHandle.unshard(self)`
    - If the handle was resharded without freeing the padded unsharded flattened parameter (e.g. `summon_full_params()` between forward and backward when `reshard_after_forward=False`), then the `FlatParameter` points to the sharded flattened parameter. We need to switch to using the unsharded parameter. This is a design choice. Alternatively, we may not switch to using the sharded flattened parameter in `reshard()` if we do not free the padded unsharded flattened parameter. However, the postcondition that the `FlatParameter` points to the sharded flattened parameter after `reshard()` is helpful logically, so I prefer this approach.
    - Otherwise, this allocates the padded unsharded flattened parameter, all-gathers, and switches to using the unpadded unsharded flattened parameter.
    - In the future, we may add an option to `unshard()` that additionally all-gathers the gradient.
5. `FlatParamHandle.post_unshard(self, prepare_gradient: bool)`
    - For sharded strategies, if using parameter mixed precision, this frees the low precision shard. More generally, this should free any sharded allocations made in `pre_unshard()` since the all-gather has been launched. If using CPU offloading, the GPU copy of the local shard goes out of scope after `unshard()` and is able to be garbage collected. **We should understand if there is any performance difference between manually freeing versus deferring to garbage collection since our usage is inconsistent.** For now, I preserve the existing semantics here.
    - `prepare_gradient` is meant to be set to `True` for the pre-backward unshard and `False` for the forward unshard. This runs the equivalent logic of `_prep_grads_for_backward()`.
    - This post-unshard logic (notably the gradient preparation) now runs in the all-gather stream, which is fine because we always have the current stream wait for the all-gather stream immediately after `FullyShardedDataParallel._unshard()`. IIUC, we do not need to call `_mp_shard.record_stream(current_stream)` (where `current_stream` is the default stream) because `_mp_shard` is allocated and freed in the same (all-gather) stream.
    - A postcondition is that the `FlatParameter` is on the compute device. It should also have the unpadded unsharded size (though I do not have a check for this at the moment).
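
A condensed sketch of the unshard control flow in steps 1-5 above (stream and method names follow the description; not the actual implementation):
```
from typing import List

import torch


def unshard(handles: List["FlatParamHandle"],
            all_gather_stream: torch.cuda.Stream,
            prepare_gradient: bool) -> None:
    # Runs in the all-gather stream; the caller then has the current (default)
    # stream wait on the all-gather stream.
    with torch.cuda.stream(all_gather_stream):
        for handle in handles:
            handle.pre_unshard()                   # low-precision shard alloc / H2D copy
            handle.unshard()                       # alloc padded unsharded param + all-gather
            handle.post_unshard(prepare_gradient)  # free pre-unshard allocs, prep grads
```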

### Unshard: `summon_full_params()`
Now that we see how the logical unshard has been reorganized for the core code path, let us dive into `summon_full_params()`.

The two constraints are:
1. If using parameter mixed precision, we should unshard in full precision.
2. We must determine if we should free the padded unsharded flattened parameter upon exiting.

The first constraint is addressed as described before in the core unshard code path, so it remains to explore the second constraint.

I propose a simple rule: **We free iff we actually unshard the `FlatParameter` in `summon_full_params()`** (i.e. it was not already unsharded). We perform a case analysis:

**Parameter mixed precision enabled:**
* `NO_SHARD`: `flat_param.data` points to `flat_param._local_shard`, which is the full precision unsharded flattened parameter. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We force full precision and all-gather to `_full_prec_full_param_padded`. We do not support `nested summon_full_params()`, so `_full_prec_full_param_padded` must be unallocated. We unshard, and it is **safe to free**.

**Parameter mixed precision disabled:**
* `NO_SHARD`: This is the same as with mixed precision enabled. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We all-gather to `_full_param_padded`. It may already be unsharded.
    * Already unsharded: The unshard is a no-op. This is **not safe to free**.
        * For `FULL_SHARD`, this can happen for the root FSDP instance after `forward()` but before backward.
        * For `SHARD_GRAD_OP`, this can happen for all FSDP instances after `forward()` but before backward.
    * Needs unshard: We unshard. This is **safe to free**.

Therefore, we see that it is not safe to free when using `NO_SHARD` and when using a sharded strategy but the `FlatParameter` is already unsharded. This is precisely the proposed rule.
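
The case analysis collapses to a single predicate; a hedged sketch (names are assumptions):
```
def should_free_on_exit(handle) -> bool:
    # Evaluated before unsharding inside summon_full_params(): free on exit
    # iff this call is the one that actually unshards the FlatParameter.
    # - NO_SHARD: needs_unshard() is False -> never free.
    # - Sharded strategy, already unsharded: False -> not safe to free.
    # - Sharded strategy, needs unshard: True -> safe to free.
    return handle.needs_unshard()
```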

There were two notable edge cases that the existing code did not address.
1. The existing code tests if the `FlatParameter` is already unsharded by checking the allocation status of `_full_param_padded`. When using parameter mixed precision, this is the incorrect tensor to check. If `_full_param_padded` is allocated (e.g. when `reshard_after_forward=False` and calling `summon_full_params()` between forward and backward), the already-unsharded check is a false positive, and `summon_full_params()` does not correctly force full precision. https://github.com/pytorch/pytorch/issues/83068
    - This PR's `needs_unshard()` check correctly routes to the appropriate padded unsharded flattened parameter depending on the calling context (i.e. if it needs to force full precision or not).
2. The existing code does not free the GPU copy of the padded unsharded flattened parameter when calling `summon_full_params(offload_to_cpu=True)`. It unshards the `FlatParameter`, moves the padded unsharded flattened parameter to CPU, and sets the `FlatParameter` data to be the appropriate unpadded view into the padded unsharded flattened parameter on CPU. However, `_full_param_padded` still points to the all-gathered padded unsharded flattened parameter on GPU, which is kept in memory. https://github.com/pytorch/pytorch/issues/83076
    - This PR frees the GPU copy and reallocates it upon exiting `summon_full_params()`. This is essential for avoiding peak GPU memory usage from increasing as we recurse through the module tree. There may be some cases where we can avoid reallocation altogether, but that can be addressed in a follow-up PR.
    - This PR offloads the *unpadded* unsharded flattened parameter to CPU directly instead of the *padded* one. As far as I can tell, there is no need to include the padding since unflattening the original parameters does not require the padding.
    - The relevant code is in the context manager `FlatParamHandle.to_cpu()`.

### Unshard: Mixed-Precision Stream

This PR removes the mixed precision stream usage. As is, I do not think there is any extra overlap being achieved by the stream usage.

The low precision shard is allocated and copied to in the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1401-L1412))), and the current stream (in this case the all-gather stream) waits for the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1414))). However, we immediately schedule an all-gather that communicates that exact low precision shard ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L3338))) with no other meaningful computation between. If we remove the mixed precision stream, the low precision shard is allocated and copied to in the all-gather stream (including the non-blocking CPU -> GPU copy if using CPU offloading).

Under this PR's design, we may consider a "pre-unshard" stream for all logical pre-unshard data transfers if we want to overlap in the future. IIUC, the overlap opportunity exists if there are multiple `FlatParameter`s per module, and we only have the all-gather stream wait for the data transfer corresponding to the local shard it communicates, not the others.

If we agree on removing the mixed-precision stream for now, I will remember to delete it from `_init_streams()`.

## FSDP Runtime: Reshard

Like with unshard, the first step is to look at the existing `_free_full_params()` and `_use_param_local_shard()`. A few notable observations:
- For only `NO_SHARD`, `_free_full_params()` includes a call to `_free_mp_shard()`.
- For `summon_full_params()`, there is a separate `_free_full_params_and_use_local_shard()` that duplicates the main logic of `_free_full_params()` and calls `_use_param_local_shard()`.
- In `forward()`, if `reshard_after_forward=True`, we call `_free_full_params()` and then `_free_mp_shard()`. Hence, for `NO_SHARD`, the `_free_mp_shard()` is a no-op.
- In the post-backward hook, we typically call `_free_full_params()` and `_free_mp_shard()`. The `_free_mp_shard()` is a no-op for `NO_SHARD` and if `reshard_after_forward=True`.

Some comments:
- The code certainly works, but some of the no-ops are subtle. When possible, we should make it clear when calls are no-ops or not. It is good that the existing code documents that `_free_mp_shard()` is a no-op in the post-backward hook when `reshard_after_forward=True`. However, there are still some non-obvious no-ops (around `NO_SHARD`).
- We should see if we can avoid the duplicate `_free_full_params_and_use_local_shard()`.

Let us trace through the logical reshard:
1. `FullyShardedDataParallel._reshard(self, handles: List[FlatParamHandle], free_unsharded_flat_params: List[bool])`
    - The two args should have the same length since they are to be zipped.
    - The goal of having `free_unsharded_flat_params` is that the caller should be explicit about whether the (padded) unsharded flattened parameter should be freed. The low precision shard is always meant to be freed (as early as possible), so there is no corresponding `List[bool]`.
2. `FlatParamHandle.reshard(self, free_unsharded_flat_param: bool)`
    - This frees the (padded) unsharded flattened parameter if `free_unsharded_flat_param` and switches to using the sharded flattened parameter.
    - Echoing back to forcing full precision in `summon_full_params()`, `_free_unsharded_flat_param()` frees the correct tensor by using `_get_padded_unsharded_flat_param()`.
3. `FlatParamHandle.post_reshard(self)`
    - I am not fully content with the existence of this method, but this seems to be an unavoidable consequence of `NO_SHARD`. Perhaps, this may be useful in the future for other reasons though.
    - Right now, this method is only meaningful for `NO_SHARD` + parameter mixed precision + outside `summon_full_params()`. `_mp_shard` is not freed in the post-unshard since it is also the low precision _unsharded_ flattened parameter, so we must delay the free until the post-reshard.
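
A condensed sketch of the reshard control flow traced above (not the actual implementation):
```
from typing import List


def reshard(handles: List["FlatParamHandle"],
            free_unsharded_flat_params: List[bool]) -> None:
    # One bool per handle: the caller is explicit about freeing the (padded)
    # unsharded flattened parameter; the low-precision shard is always freed
    # as early as possible.
    for handle, free_unsharded in zip(handles, free_unsharded_flat_params):
        handle.reshard(free_unsharded)  # maybe free padded unsharded param; switch to sharded
        handle.post_reshard()           # e.g. free _mp_shard for NO_SHARD + param mixed precision
```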

Below the `FlatParamHandle.reshard()` and `post_reshard()` layer, there should not be any no-ops.

One final comment I will mention is that I like the `pre_unshard()`, `unshard()`, `post_unshard()`, and `reshard()`, `post_reshard()` organization because it makes it clear what the boundaries are and their temporal relationship. Through that, we can set pre- and post-conditions. Furthermore, we can eventually convert logic to hooks that may be registered on the `FlatParamHandle` (for `pre_unshard()`, `post_unshard()`, and `post_reshard()`). This may improve the customizability of FSDP.

## FSDP Runtime: `forward()`

- This PR reorganizes `forward()` in preparation for non-recursive wrapping, which uses pre-forward and post-forward hooks that expect the signature `hook(module, input)`. For FSDP, the `module` and `input` arguments are not used.
- This PR creates a new method `_fsdp_root_pre_forward()` to handle the logic only the root FSDP should run.

## FSDP Prefetching

Finally, we dive into the prefetching changes. Some highlights:
1. This PR unifies the execution order validation and prefetching implementations.
    - Both involve the execution order and can be unified to share some boilerplate.
2. Execution order validation only runs when the distributed debug level is `INFO`.
    - We have yet to have one success case where we actually catch an unintended source of dynamism. The warning is also too verbose. Hence, we are gating it by the `INFO` level.
3. This PR moves prefetching to be with respect to groups of handles (as mentioned in the constructor comment).
    - This is essential for supporting prefetching with non-recursive wrapping.
4. This PR does not include "bubbles", i.e. modules with no handles, in the recorded execution order(s). This deviates from the existing implementation.
    - This makes prefetching possibly more aggressive (when there are such bubbles), but it should not have significant performance implications either way.
5. This PR changes backward prefetching to reset the post-forward order each iteration (as intended).
6. This PR changes forward prefetching to use the first iteration's pre-forward order instead of the first iteration's post-forward order. (We can discuss whether we want this in this PR or not. Otherwise, I can keep it as using the post-forward order to preserve the existing semantics.) This PR also removes the `all_gather_stream.wait_stream(current_stream)` before forward prefetching because it does not help with high GPU reserved memory. We can add that back if desired.
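
For (2), the validation is gated on the distributed debug level; one hedged way to enable it in a test environment:
```
import os

# Must be set before torch.distributed is initialized for it to take effect.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
```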

### Appendix
#### Reverse Post-Forward Order Is Not Always the Pre-Backward Order
The existing PT-D FSDP pre-backward prefetching uses the reverse post-forward order.
<details>
  <summary>Model Code</summary>

  ```
  class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(4, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=False),
        )
        self.block3 = nn.Linear(12, 8)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(4, 10),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.head(x)

  model = Model().cuda()
  fsdp_kwargs = {}
  model.block1[1] = FSDP(model.block1[1], **fsdp_kwargs)  # BN2d
  model.block2[1] = FSDP(model.block2[1], **fsdp_kwargs)  # BN2d
  model.block1 = FSDP(model.block1, **fsdp_kwargs)
  model.block2 = FSDP(model.block2, **fsdp_kwargs)
  model.block3 = FSDP(model.block3, **fsdp_kwargs)
  model = FSDP(model, **fsdp_kwargs)
  ```
</details>

<details>
  <summary>Execution Orders </summary>

  ```
  Pre-backward hook for ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  Pre-backward hook for ('weight', 'bias') 140339461194656 (block3)
  Pre-backward hook for ('0.weight', '0.bias') 140339520589776 (block2)
  Pre-backward hook for ('weight', 'bias') 140339520587664 (block2 BN)
  Pre-backward hook for ('weight', 'bias') 140339520586656 (block1 BN)
  Pre-backward hook for ('0.weight', '0.bias') 140339520588768 (block1)

  Pre-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('weight', 'bias') 140339461194656 (block3)

  Reverse post-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('weight', 'bias') 140339461194656 (block3)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ```
</details>

Differential Revision: [D39293429](https://our.internmc.facebook.com/intern/diff/D39293429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83665
Approved by: https://github.com/zhaojuanmao
2022-09-13 17:05:10 +00:00
9d5b3e4da8 [FSDP] Remove forward_prefetch (#84600)
We are removing the `forward_prefetch` option. By the nature of async GPU kernel execution, launching the CPU kernel for the next layer's all-gather early does not actually improve performance. Moreover, the existing `forward_prefetch` uses the post-forward order instead of the pre-forward order, which leads to mis-targeted prefetched all-gathers.

Differential Revision: [D39454217](https://our.internmc.facebook.com/intern/diff/D39454217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84600
Approved by: https://github.com/zhaojuanmao
2022-09-13 02:45:07 +00:00
2cb75e1579 [BE][FSDP] Introduce FSDPTestModel interface (#80873)
**Overview**
Please refer to https://github.com/pytorch/pytorch/issues/80867 first.

This addresses:
> Goal 3: Refactor model construction to enable simpler testing for the non-recursive wrapping path.

The idea is that we have an abstract class `FSDPTestModel` that defines the interface expected from the parity check and training boilerplate. This PR refactors the models in `common_fsdp.py` used in `test_fsdp_core.py` to implement this interface. Further unification under this interface is coming in follow-up PRs.
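
A hedged sketch of the interface this describes (method names are assumptions based on the description):
```
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class FSDPTestModel(nn.Module, ABC):
    """Interface expected by the shared parity-check and training boilerplate."""

    @abstractmethod
    def get_input(self, device: torch.device) -> tuple:
        """Return an input for forward() on the given device."""

    @abstractmethod
    def get_loss(self, inp, output) -> torch.Tensor:
        """Return a scalar loss from the forward output."""

    @abstractmethod
    def run_backward(self, loss: torch.Tensor) -> None:
        """Run the backward pass, e.g. loss.backward()."""
```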

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80873
Approved by: https://github.com/rohan-varma
2022-07-13 18:46:28 +00:00
70446c25d7 [FSDP] Add forward prefetching option in FSDP API (#78841)
Fixes #78608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78841
Approved by: https://github.com/zhaojuanmao
2022-06-15 20:59:08 +00:00
c29df68f95 [FSDP] Return original module when fsdp wrapped model call .module (#78671)
Fixes #78607
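
A hedged usage sketch of the fixed behavior (assumes a process group is already initialized):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Linear(8, 8).cuda()
wrapped = FSDP(model)
assert wrapped.module is model  # .module returns the original module
```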

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78671
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-06-03 04:38:19 +00:00
aaf5c32992 device_id
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77321

Approved by: https://github.com/awgu
2022-05-13 22:14:49 +00:00
9493900876 [Reland] Mixed precision batchnorm fix (#77234)
Reland of https://github.com/pytorch/pytorch/pull/77089, which was reverted due to a land race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77234
Approved by: https://github.com/zhaojuanmao
2022-05-11 15:03:34 +00:00
091f8915ae Revert "Mixed Precision batchnorm fix (#77089)"
This reverts commit bf61b795031b4f30b2cf1267c5625bfe36cd5f3c.

Reverted https://github.com/pytorch/pytorch/pull/77089 on behalf of https://github.com/suo
2022-05-11 03:00:33 +00:00
bf61b79503 Mixed Precision batchnorm fix (#77089)
Rehash of https://github.com/pytorch/pytorch/pull/76642, which could not be updated due to a GHF out-of-sync issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77089
Approved by: https://github.com/awgu
2022-05-11 02:22:01 +00:00
3621462ebb [FSDP] Change default auto wrap policy name (#76858)
The current `default_auto_wrap_policy` is not really a recommended default auto wrap policy, so rename it to avoid confusion; the docs are updated as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76858
Approved by: https://github.com/rohan-varma
2022-05-10 13:20:26 +00:00
786903ea29 Provide an auto wrap policy for common transformer models
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76455

Provide an auto wrap policy for common transformer models

Differential Revision: [D35972488](https://our.internmc.facebook.com/intern/diff/D35972488/)

Approved by: https://github.com/rohan-varma
2022-05-01 05:57:42 +00:00
648823b087 [FSDP] Add ignored_modules ctor arg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75431

Approved by: https://github.com/rohan-varma
2022-04-12 19:46:00 +00:00
247e4cca1e [BC-Breaking] Remove redundant fsdp prefix (#73791)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73791

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D34645228

Pulled By: mrshenli

fbshipit-source-id: 269a288050969f720c7995aa1992bd2964a860e7
(cherry picked from commit 5981488be59766aaf6791f33086b0808332da771)
2022-03-12 06:24:58 +00:00
270c27efeb [FSDP] Add always_wrap policy (#73687)
Summary:
Add a small helper policy that always returns True to automatically wrap all FSDP submodules. This is the first and simplest step in providing a set of policies that allow users to seamlessly experiment with different FSDP configs.

More Context: https://github.com/pytorch/pytorch/issues/68789
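
Hedged usage sketch (the policy lives in `torch.distributed.fsdp.wrap`; assumes a process group is already initialized):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy

# Wrap every submodule in its own FSDP unit.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy)
```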

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73687

Reviewed By: jbschlosser, zhaojuanmao

Differential Revision: D34625801

Pulled By: rohan-varma

fbshipit-source-id: f20c951f8d62ea29b504543c93acd546247d8206
(cherry picked from commit 3b0bf02bc8bb236ee09e2fa986d52bbf5231efc3)
2022-03-08 05:37:55 +00:00
2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

make fsdp folder to be public
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe2f391941bb0afa611dd39994585513)
2022-02-02 15:50:14 +00:00
d0ff1f0013 [FSDP] Backward prefetch in recursive call (#71804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804

Add backward prefetch arg when using auto_wrap_policy. Unittests are
updated appropriately.
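
Hedged usage sketch (the model and policy choices are illustrative; assumes a process group is already initialized):
```
import functools

import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

model = nn.Transformer().cuda()
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer, nn.TransformerDecoderLayer},
)
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # also applied to auto-wrapped children
)
```
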
ghstack-source-id: 147753214

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33782346

fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c8fb9879b6c92a934df1a7d2c9e3f3d)
2022-01-28 00:34:08 +00:00
a30b0cf52a [FSDP] Add/refactor unit test for wrap (#71803)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71803

1. Add an extra check for wrapping with override args.

2. Enhance the unit test to make sure `wrap` doesn't wrap outside of the context manager.
ghstack-source-id: 147753225

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33774512

fbshipit-source-id: 1f8d60bdf9b3ba257fee465064a0e25235b3622b
(cherry picked from commit 9ab775b29eddcd193c11398184bee8beffed0327)
2022-01-28 00:34:08 +00:00
29a7cb41d8 [BE] Fix FSDP flaky test (#71525)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71525

Closes https://github.com/pytorch/pytorch/issues/71496. Use file init for the test
as opposed to TCP init, which runs into port-racing conditions as
seen in the failures for that issue.
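
A hedged sketch of file-based initialization (function and argument names are illustrative):
```
import torch.distributed as dist


def init_pg_with_file(rank: int, world_size: int, init_file: str) -> None:
    # File-based init avoids the port races that TCP init can hit in CI.
    dist.init_process_group(
        backend="nccl",
        init_method=f"file://{init_file}",
        rank=rank,
        world_size=world_size,
    )
```
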
ghstack-source-id: 147300691

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33676165

fbshipit-source-id: fcf83f7c7541d3521d3e38481195b0c7cb081691
(cherry picked from commit ea091c4af7d864e4d2ebcda6f72d04e17ae7bd82)
2022-01-21 21:00:13 +00:00
c95277e92a [FSDP] Remove auto_wrap() (#69356)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69356

Per title
ghstack-source-id: 144807949

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32816150

fbshipit-source-id: 6b4eacc63edd267bc1eb8a1c1d6c753bc581d63a
2021-12-06 12:11:14 -08:00
7fad758e02 [FSDP] AutoWrap Main API (#68155)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68155

Per title
ghstack-source-id: 144398229

Test Plan: CI

Reviewed By: pbelevich, mrshenli

Differential Revision: D32327954

fbshipit-source-id: 36bdf06c1c50932a93acbfa97017c549fa490a6c
2021-12-01 00:16:38 -08:00
36d9a74bc6 Enforce that test cases extend from correct TestCase (#67819)
Summary:
Addresses https://github.com/pytorch/pytorch/issues/66903

Main code is in torch/testing/_internal/common_utils.py; everything else fixes the lint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67819

Reviewed By: H-Huang

Differential Revision: D32259978

Pulled By: janeyx99

fbshipit-source-id: 39c5ffbaa510e1e533d6bdacf5c6158a3dd9885d
2021-11-08 18:28:36 -08:00
5ad169b7cc Adding in Wrap functions for FSDP from Fairscale (#67292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67292

as title

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed/fsdp:wrap --keep-going

Reviewed By: rohan-varma

Differential Revision: D31936404

fbshipit-source-id: b7ebead9a649766aec83e5630c2ce1386ad33e11
2021-11-02 13:30:41 -07:00