Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00.
This PR aims to support the following use case:

```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```

where the collective is issued in eager mode (with `async_op=True`) but waited on in the compiled region. This is important for internal use cases such as TorchRec, where we issue collectives in eager mode for the SparseArch all_to_all but want to wait on them in the compiled region at the beginning of the OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel; a sketch of this overlap pattern follows.
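Below is a minimal, hypothetical sketch of that overlap pattern, not TorchRec code: `sparse_comm`, `dense_compute`, and `over_arch` are stand-in names, `all_reduce` stands in for the all_to_all, the import path for `allow_inflight_collective_as_graph_input_ctx` is an assumption, and an initialized process group is assumed.

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,  # assumed import path
)

def sparse_comm(x):
    # Stand-in for SparseArch's all_to_all: issue the collective
    # asynchronously and return without waiting on the Work handle.
    dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    return x

def dense_compute(x):
    # Stand-in for the DenseArch compute that runs while the
    # collective is still in flight.
    return x @ x

@torch.compile(fullgraph=True)
def over_arch(comm_out, dense_out):
    # Wait on the in-flight collective at the start of the compiled
    # region, so the communication overlaps with dense_compute above.
    torch.ops.c10d_functional.wait_tensor(comm_out)
    return comm_out + dense_out

# Assumes dist.init_process_group(backend="nccl", ...) has already run.
with allow_inflight_collective_as_graph_input_ctx():
    a = sparse_comm(torch.ones(1024, 1024, device="cuda"))
    b = dense_compute(torch.ones(1024, 1024, device="cuda"))
    out = over_arch(a, b)
```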
----

**Update**: We did two things to prevent regressions in existing use cases:

1. Added a memory-stressed test case `test_unwaited` to test_c10d_nccl.py to cover the existing pattern of "not calling `work.wait()` for a non-functional collective" (a minimal sketch of this pattern appears at the end of this description).
2. Gated all new `register_work()` / `unregister_work()` calls behind a `c10d::allow_inflight_collective_as_graph_input()` check, controlled by a new context manager that requires explicit user opt-in (i.e. not on by default, so existing users are unaffected).

The risk of this version of the PR causing a regression should be very low.

------

Test commands:

- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
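For reference, here is a minimal sketch of the "unwaited collective" pattern that item 1's `test_unwaited` protects. This is an illustration under an assumed initialized process group, not the literal test; the memory-stress aspect is omitted.

```python
import torch
import torch.distributed as dist

def unwaited_allreduce(x: torch.Tensor) -> torch.Tensor:
    # A non-functional collective issued with async_op=True whose Work
    # handle is deliberately dropped: work.wait() is never called.
    dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    return x

# Outside allow_inflight_collective_as_graph_input_ctx(), the new
# register_work()/unregister_work() calls are skipped, so this existing
# pattern keeps its prior memory behavior.
```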
The file view below shows a minimal C++ header (4 lines, 70 B):
```cpp
#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
```