[c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)

This PR aims to support the following use case (the snippet comes from a multi-process test, hence the `self.rank` reference):
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in the compiled region.

This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.
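
For concreteness, here is a minimal sketch of that overlap pattern, assuming the context manager is importable from `torch.distributed._functional_collectives`; the names `train_step`, `dense_arch`, `over_arch`, and the single `all_to_all_single` are illustrative stand-ins, not real TorchRec APIs:

```python
import torch
import torch.distributed as dist
# Assumed import path for the context manager introduced by this PR:
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

@torch.compile(fullgraph=True)
def over_arch(a2a_out, dense_out):
    # Waiting here, inside the compiled region, lets the in-flight all_to_all
    # overlap with the DenseArch compute that ran before this call.
    torch.ops.c10d_functional.wait_tensor(a2a_out)
    return a2a_out.sum() + dense_out.sum()

def train_step(sparse_input, dense_input, dense_arch):
    with allow_inflight_collective_as_graph_input_ctx():
        # "SparseArch": issue the collective eagerly; deliberately do not
        # call work.wait() here.
        a2a_out = torch.empty_like(sparse_input)
        dist.all_to_all_single(a2a_out, sparse_input, async_op=True)
        # "DenseArch": runs while the all_to_all is still in flight.
        dense_out = dense_arch(dense_input)
        # "OverArch": the compiled region waits on the collective's output.
        return over_arch(a2a_out, dense_out)
```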

----

**Update**: Made two changes to prevent regressions for existing use cases:

1. Added a memory-stress test case to `test_unwaited` in test_c10d_nccl.py to cover the existing use case of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check. The flag it reads is only set inside the new `allow_inflight_collective_as_graph_input_ctx()` context manager, which requires explicit user opt-in (i.e. it is off by default), so existing users should not be affected (see the sketch below).
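
As a sketch of the intended gating semantics (not actual test code; same assumed import path as above, with a process group already initialized):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,  # assumed import path
)

x = torch.ones(8, device="cuda")

# Default behavior (gate off): no work registration happens, so existing
# users who call work.wait() themselves (or never wait at all, as in
# `test_unwaited`) see no change.
work = dist.all_reduce(x, async_op=True)
work.wait()

# Opt-in behavior: inside the context manager, each collective's Work handle
# is registered against its output tensor, so a later wait_tensor() call
# (eager or compiled) can find the in-flight work and wait on it.
with allow_inflight_collective_as_graph_input_ctx():
    y = torch.ones(8, device="cuda")
    dist.all_reduce(y, async_op=True)
    torch.ops.c10d_functional.wait_tensor(y)  # finds the registered work and waits
```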

The risk of this new version of the PR causing regressions should be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
Author: Will Feng
Date: 2024-10-28 14:52:18 -07:00
Committed by: PyTorch MergeBot
Parent: d8f99f39cb
Commit: 4ee514144b
15 changed files with 625 additions and 112 deletions

torch/csrc/distributed/c10d/ProcessGroup.hpp (excerpt):

@@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -23,6 +24,31 @@ constexpr auto kProcessGroupDefaultTimeout =
namespace c10d {
// We only call `register_work()` in two cases:
// 1. If the work object is created from a functional collective call.
// 2. If the work object is created from a non-functional collective call within
// the `with allow_inflight_collective_as_graph_input_ctx()` context manager.
C10_EXPORT void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
// We only call `unregister_work()` in one case:
// 1. If the work object is created from a non-functional collective call within
// the `with allow_inflight_collective_as_graph_input_ctx()` context manager.
//
// Q: What about the functional collective case?
// A: The unregistration of work object for functional collective is done in
// the required user-side explicit call to `wait_tensor()`.
C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT size_t get_work_registry_size();
C10_EXPORT void set_allow_inflight_collective_as_graph_input(bool value);
C10_EXPORT bool allow_inflight_collective_as_graph_input();
// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
@@ -158,13 +184,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// It's awkward to unbox the opts here and box them again in the custom C++
// op. But it's also complicated to make opts as a CustomClassHolder. Leave
// it as it is now.
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce(
@@ -181,12 +214,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::optional<at::Tensor>& sparse_indices,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -200,11 +240,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce(
@@ -219,13 +266,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allgather(
@@ -242,11 +296,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
// Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -267,12 +330,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
// This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -291,10 +359,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
// This function is a coalesced version of `allgather_into_tensor` (currently
@@ -312,10 +389,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> gather(
@@ -330,12 +414,21 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
virtual c10::intrusive_ptr<Work> scatter(
@@ -353,13 +446,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -376,12 +476,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -398,13 +505,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
// This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -424,12 +536,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -447,13 +566,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<int64_t>,
std::vector<int64_t>,
int64_t)>();
return op.call(
auto work = op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> alltoall(
@@ -469,11 +593,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual void monitoredBarrier(
@@ -549,11 +680,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
dstRank,
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> recv(
@@ -567,11 +704,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
srcRank,
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -583,10 +726,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> barrier(
@@ -618,11 +767,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<int64_t>&,
int64_t)>();
return op.call(
auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);
}
return work;
}
bool hasBackends() {