Fixes#159590
This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (i.e. `_MaskPartial`) as in issue #159601. This adds little to no overhead in eager ([see past benchmarks](https://github.com/pytorch/pytorch/pull/156868#issuecomment-3047831149)).
This also handles cases such as #159590 where dynamo is disabled during tracing by entering the Python Dispatcher ahead of the sharding propogation during compile. Tests are added/modified to handle these, and the list/tuple inputs with the cat op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160798
Approved by: https://github.com/bdhirsh
Fixes#159601
Unfortunately #156868 introduced a couple regressions (see #159590 and #159601). This reverts the commit while I am working on a permanent fix. This means the `in_compiled_autograd_initial_trace` global flag will be removed and the `_are_we_tracing()` will instead be replaced with the symint preprocessing step during sharding prop post init.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159671
Approved by: https://github.com/xmfan
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started.
1. Remove test_deploy_interaction as we no longer need to worry about this
2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1)
3. Remove `USE_DEPLOY` and switch to the default path always
Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started.
1. Remove test_deploy_interaction as we no longer need to worry about this
2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1)
3. Remove `USE_DEPLOY` and switch to the default path always
Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started.
1. Remove test_deploy_interaction as we no longer need to worry about this
2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1)
3. Remove `USE_DEPLOY` and switch to the default path always
Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
Part of https://github.com/pytorch/torchtitan/issues/866
## Context
- Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern because scaled mm only supports 2D input tensors and 2D scales.
- (a,b,c) => (a*b,c)
- (a\*b,c) @ (c,d) = (a\*b,d)
- (a\*b,d) => (a,b,d)
- Currently the implementation does not support scaled mm with rowwise scales **for all cases** of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this [unit test](00a2c68f67/test/distributed/tensor/parallel/test_micro_pipeline_tp.py (L406)), but more involved e2e examples in torchtitan fail silently (more context in final bullet point).
- Previously, the "A tensor" **node** referenced in the async TP graph manipulation code is the 3D+ node before the reshape, but the "A_scale" node is the 2d node from after the reshape, so they are incompatible.
- I previously implemented a simpler solution to this problem in https://github.com/pytorch/pytorch/pull/148001, with a [unit test](https://github.com/pytorch/pytorch/pull/148001/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR406) confirming the fused node is indeed in the graph for the minimal example of the reshape->mm->reshape pattern. I also confirmed via manual e2e testing w/ torchtitan that the crash I was fixing no longer occurred. However, it turns out due to this [bug in torchtitan](https://github.com/pytorch/torchtitan/issues/866) it was causing async TP to fail silently and fall back to vanilla TP, hiding the fact that this original solution fixed the crash but the fusion would not occur for rowwise scales. Thus, more robust solution is needed to support all cases.
## Solution TL;DR
- Use the 2D 'A' tensor and corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
- Track the "pre mm reshape" and "post mm reshape" separately, to be referenced in the `fused_scaled_matmul_reduce_scatter` implementation, to update the scatter dim through the pre-mm reshape, and apply the post-mm reshape before applying the reduce scatter and returning the output tensor.
- Separate the `fused_matmul_reduce_scatter` and the `fused_scaled_matmul_reduce_scatter` code paths, to simplify them both.
- By fixing the bug in torchtitan (PR https://github.com/pytorch/torchtitan/pull/965) and implementing support for rowwise scales in pytorch in this PR, together these changes will solve the problem of how to support rowwise scales with all types of AC.
## Additional details for reviewers
To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:
- Track the pre-mm reshape, as it will affect the scatter dim used in the fused_matmul_reduce_scatter impementation.
- Track the post-mm reshape, as it will affect the output shape used in the fused_matmul_reduce_scatter impementation
- Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor. This is needed because during the pipelined producer mm implementation, the scatter dim is moved to dim 0 (so it can be sharded along the first dim and then get chunks to do mm ops on by indexing into the first dim), then moved back to it's original place before the reduce-scatter.
- Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into 3D outputs needed for 1) the reduce-scatter w/ the original scatter dim, and 2) the expected output shape to prevent shape errors with subsequent ops.
## Test plan
- All existing unit tests passing.
- Expand unit tests for rowwise scales to test more scatter dims
- Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing) which has led to confusion, so this will improve the UX.
- Compared loss curves of bf16 vs float8 w/ rowwise scales to confirm integrity of numerics
- Confirmed via manual testing with torchtitan and inspecting the compile graph that the fusion is working as intended for:
- bfloat16
- float8 with tensorwise scales
- float8 with rowwise scales
## Loss curves
Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:
<img width="1017" alt="loss_async_tp" src="https://github.com/user-attachments/assets/4995db78-7012-490f-a370-f4fecc289a22" />
## Performance
#### Per op SAC
Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:
- bf16 (vanilla TP): TPS 5161.5, peak memory 50.53 GB
- bf16 (async TP): TPS 5229.5, peak memory 50.68 GB
- float8 tensorwise (vanilla TP): TPS: 5959.5, peak memory: 50.47 GB
- float8 tensorwise (async TP): TPS 5964.5, peak memory 50.47 GB
- float8 rowwise (vanilla TP): TPS: 4962.0, peak memory: 50.55 GB
- float8 rowwise (async TP): TPS 4966.5, peak memory 50.65 GB
#### Full AC
Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8
- bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
- bf16 (async TP): TPS 673, peak memory 71.08 (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
- float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): TPS: 540 TPS, peak memory 71.46 GB
- float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP but still unexpectedly lower than bf16)
As you can see, float8 rowwise is working but performance needs to be improved further.
## Other changes
- Added logging so the user will know why fusion failed if it does.
- Remove logic which inserted a reshape node targeting "A scale" to get it to be in 3D like the "A tensor" since it's no longer needed.
## Long term plan
- Add a `scaled_matmul` op in pytorch, which will natively support a 3D+ "A tensor" and allow us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.
## Visualizing fused nodes in graphs for torchtitan training runs
Below are examples of the visualized graph generated by torch compile for torchtitan llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests added) that the implementation is working correctly.
### bf16
<img width="900" alt="bf16-fusion" src="https://github.com/user-attachments/assets/a3bed917-28eb-4a56-8d6e-2d2bf498385c" />
### float8 with tensorwise scales
<img width="900" alt="tensorwise-node" src="https://github.com/user-attachments/assets/b212ec4a-1899-44de-a4de-18c74e1de68a" />
### float8 with rowwise scales
<img width="900" alt="rowwise" src="https://github.com/user-attachments/assets/ed3354a3-894b-4ec9-86d0-f80364bf3d83" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149247
Approved by: https://github.com/kwen2501
Summary: Previously we added support for `all_reduce` to non strict. This PR extends this support to other non-functional collectives that are remapped in Dynamo: `all_gather`, `all_gather_into_tensor`, `all_to_all_single`, `reduce_scatter_tensor`.
Test Plan: added unit tests
Differential Revision: D69813991
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147417
Approved by: https://github.com/angelayi
Fixes https://github.com/pytorch/pytorch/issues/142076. Under compile, functional collectives are supposed to **not** return `AsyncCollectiveTensor`, and instead immediately issue calls to `wait_tensor()` (that we rely on the compiler to reorder as necessary.
This is done with a function `_are_we_tracing()`, that tries to detect if we are running from inside of the compiler. One of the checks it performs is `is_torchdynamo_compiling()` ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L808C8-L808C34)).
Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that this function only returns true if dynamo **intercepts** the bytecode for `is_torchdynamo_compiling()`. However, this function is called during fake-tensor propagation, which is run as part of dynamo, but is not actually intercepted by dynamo itself.
One thing that we know is the case during dynamo tracing, however, is that a `FakeTensorMode` is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode.
This could potentially have consequences for anybody running functional collectives with a fake mode directly, without compile in the loop. Although hopefully it's not too unreasonable to issue wait() calls immediately if you are running with fake tensor (presumably you only care about fake tensor propagation, in which case the wait() calls should technically be a no-op).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: #141725, #141728
Based on discussion here: https://github.com/pytorch/pytorch/pull/138731
Introducing ability for subclass implement type convertion to expected_type.
```
def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[Type] = None
):
```
Here if `expected_type=None` means `SubclassClass` is expected.
E.g. for `DTensor` we may find tangent `AsyncCollectiveTensor` where we expected `Tensor` - in this case
`expected_type=Tensor` will be called during runtime
Adding implementation to AsyncCollectiveTensor, that just triggers `wait()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139095
Approved by: https://github.com/bdhirsh
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
y = x * x
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
assert isinstance(req, torch.distributed.Work)
return y
@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y)
return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
y = all_reduce_eager(x)
z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.
This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.
----
**Update**: Did two items to prevent regression to existing use cases:
1. Added memory-stressed test case to test_c10d_nccl.py `test_unwaited` to cover existing user's "not calling work.wait() for non-functional collective" use case
2. Gated all new `register_work()` / `unregister_work()` calls with `c10d::allow_inflight_collective_as_graph_input()` check, which is a new context manager that requires explicit user enablement (i.e. not on by default, so should not affect existing users).
The risk of this new version of PR causing regression should be very low.
------
Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`
------
Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang