Before the fix, the unit test will fail at forward Dynamo tracing:
```
File "/data/users/willfeng/pytorch/test/distributed/_composable/test_replicate_with_compiler.py", line 415, in test_ddp_tp
loss = compiled_replicate_model(data).sum()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
torch._dynamo.exc.InternalTorchDynamoError: SymNodeVariable() is not a constant
from user code:
File "/data/users/willfeng/pytorch/torch/distributed/tensor/parallel/_data_parallel_utils.py", line 34, in _unflatten_tensor
result = DTensor.from_local(
```
After the fix, the compilation fails at a later step (Compiled Autograd tracing), due to needing "pre-dispatch tracing of backward graph" feature (see details at https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474).
I believe this PR is a net improvement, because it should also fix the 1D Traceable FSDP2 failure case on internal models (https://github.com/pytorch/pytorch/issues/130978#issuecomment-2319476690), which is much harder to build a minimal unit test for.
Fixes https://github.com/pytorch/pytorch/issues/130978.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135315
Approved by: https://github.com/bdhirsh
Summary: `test/distributed/_composable/test_replicate_with_compiler.py` torch.compiles. This change introduces a version of MultiProcessTestCase that derives from the inductor TestCase class to make sure we always get a clean cache dir.
Test Plan: `python test/distributed/_composable/test_replicate_with_compiler.py`
Differential Revision: [D59925519](https://our.internmc.facebook.com/intern/diff/D59925519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131053
Approved by: https://github.com/eellison
Summary: `test/distributed/_composable/test_replicate_with_compiler.py` exercises inductor. This change introduces a version of MultiProcessTestCase that derives from the inductor TestCase class to make sure we always get a clean cache dir.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129494
Approved by: https://github.com/eellison
## Summary
After this PR, the functional collective Python APIs will stop honoring `TORCH_DISABLE_NATIVE_FUNCOL` and only use native funcol ops. Specifically, this PR:
- Removed `use_native_funcol()`.
- Removed the code path in the Python APIs when `use_native_funcol()` is `False`.
- Changed the CI tests that runs on both native funcol and legacy funcol through the Python API to only run with native funcol.
## Test Changes
`test_functional_api.py`
- Removed the tests where only one of output_split_sizes or input_split_sizes is specified. This behavior is unreliable has been removed from the native funcol.
- Removed `TestWaitiness` which tests an implementation detail of the legacy funcol. We have equivalent tests for native funcol in `test/distributed/test_c10d_functional_native.py` b7fac76fc2/test/distributed/test_c10d_functional_native.py (L114-L116)
`test/distributed/_tensor/test_dtensor.py`
`test/distributed/_tensor/test_dtensor_compile.py`
`test/distributed/test_device_mesh.py`
`test/distributed/_tensor/experimental/test_tp_transform.py`
`test/distributed/_tensor/test_matrix_ops.py`
`test/distributed/test_inductor_collectives.py`
- All these tests were double running with both native funcol and legacy funcol. Changed to only run with native funcol.
`test/distributed/test_c10d_functional_native.py`
- Removed the `run_with_native_funcol` decorators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123777
Approved by: https://github.com/wanchaol
ghstack dependencies: #123776
This PR enables DDP + TP using a TP internal API. This should not be the final implementation. A more sound implementation is to inline the TP internal API in DDP. In other words, DDP needs to be aware of DTensor so that we can support 2D state_dict.
This PR adds a compiled DDP + TP test to ensure the new compiled DDP fusion doesn't break TP all_reduce.
**TODOs**
- [x] Implement DDP allreduce fusion algorithm for Inductor post_grad pass.
- [x] Add unit tests to ensure the fusion doesn't DDP + TP.
- [ ] Group different PG and data type of all_reduces.
- [ ] Mixed precision supports and tests
- [ ] Implement the fusions with Inductor IR.
- [ ] Add auto bucketing based on Inductor profiling.
Differential Revision: [D54105050](https://our.internmc.facebook.com/intern/diff/D54105050/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120479
Approved by: https://github.com/wz337
ghstack dependencies: #113209
Differential Revision: [D49858057](https://our.internmc.facebook.com/intern/diff/D49858057/)
**TL;DR**
This PR implements 2 different DDP all_reduce fusions in Inductor post_grad fx passes. The two fusions are 1) fusion with concat op and 2) fusion with all_reduce_coalesced. When DDP detects that Python reducer is being used, DDP will automatically turn on the fusion.
This PR does not invent any algorithm and simply reflects the bucket size users set to DDP.
**Implementation Details**
*Fusion with concat op*
The idea of this fusion is to use a concat op to concatenate all the gradients into one tensor and perform one `all_reduce`. After the `wait` op of the `all_reduce`, splitting and reshaping will also be perform to get the individual gradient.
Because DDP needs to perform gradient scaling, the benefit of using this fusion is that we could perform the gradient scaling over the the concatenated buffer.
*Fusion with `all_reduce_coalesced`*
The idea of this fusion is to use `all_reduce_coalesced` op to directly perform the `all_reduce` over multiple buffers. This avoid the copy overhead but may not achieve the best NCCL performance. In addition, because there are multiple buffers, we could not do one simple gradient scaling but have to rely on `foreach_div` to help the gradient scaling.
**Limitations**
Current fusions do not distinguish `all_reduce` generated by different DDP modules. This is okay if all DDP instances use the same PG and data type. The support of multiple DDP instances with different PG and data type will come in the later PRs.
**TODOs**
- [x] Implement DDP allreduce fusion algorithm for Inductor post_grad pass.
- [ ] Add unit tests to ensure the fusion doesn't DDP + TP.
- [ ] Group different PG and data type of `all_reduce`s.
- [ ] Mixed precision supports and tests
- [ ] Implement the fusions with Inductor IR.
- [ ] Add auto bucketing based on Inductor profiling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113209
Approved by: https://github.com/yf225
**Summary**
The reducer of `DistributedDataParallel` is implemented with C++ and it is not easy to trace the allreduce launched in the reducer. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. The changes allow us to use `compiled_autograd` to trace the allreduce and later be optimized (fused) in the Inductor.
**Key Logic**
1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used. `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. In the first forward() call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook` are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook` will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter.
**Bucketing**
The compiled backward is slow because there is no bucketing for the allreduces. We rely on Inductor to bucket the allreduces.
The bucketing is done in a separate PR.
Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab