[c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)

This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in a compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager mode for the SparseArch all_to_all but want to wait on them in the compiled region at the beginning of the OverArch, so that the all_to_all can overlap with the DenseArch compute that runs in parallel.
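
For illustration, a minimal sketch of that overlap pattern (the `sparse_arch` / `dense_arch` / `over_arch` names and the input tensors are hypothetical stand-ins, not TorchRec's actual API; an initialized process group is assumed):
```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

def sparse_arch(x):
    # Hypothetical SparseArch stand-in: issue the all_to_all asynchronously
    # in eager mode and return its (still in-flight) output tensor.
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, async_op=True)
    return out

def dense_arch(x):
    # Hypothetical DenseArch stand-in: compute that overlaps with the comm.
    return x @ x

@torch.compile(fullgraph=True)
def over_arch(sparse_out, dense_out):
    # Wait on the in-flight collective at the start of the compiled region.
    torch.ops.c10d_functional.wait_tensor(sparse_out)
    return sparse_out.sum() + dense_out.sum()

sparse_input = torch.randn(1024, 1024, device="cuda")
dense_input = torch.randn(1024, 1024, device="cuda")
with allow_inflight_collective_as_graph_input_ctx():
    sparse_out = sparse_arch(sparse_input)     # comm issued, not yet waited on
    dense_out = dense_arch(dense_input)        # overlaps with the in-flight comm
    result = over_arch(sparse_out, dense_out)  # wait happens here, in compile
```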

----

**Update**: Made two changes to prevent regressions in existing use cases:

1. Added a memory-stressed test case to `test_unwaited` in test_c10d_nccl.py to cover the existing use case of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check. This flag is only set by the new `allow_inflight_collective_as_graph_input_ctx()` context manager, which requires explicit user enablement (i.e. it is off by default, so existing users should not be affected); a minimal sketch follows below.

The risk of this version of the PR causing regressions should be very low.
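
To make the default-off gating concrete, here is a minimal sketch mirroring the `test_wait_tensor` cases in the diff below (it assumes an initialized NCCL process group with one CUDA device per rank):
```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

rank = dist.get_rank()
t = torch.ones(10, 10, device=f"cuda:{rank}")

# Off by default: the async collective's work object is NOT registered,
# so wait_tensor() finds nothing to wait on and has no effect.
work = dist.all_reduce(t, async_op=True)
assert torch._C._distributed_c10d._get_work_registry_size() == 0
torch.ops.c10d_functional.wait_tensor(t)  # no-op here
work.wait()  # must wait explicitly instead

# Under the context manager: the work object is registered, and
# wait_tensor() pops it from the registry and waits on it.
with allow_inflight_collective_as_graph_input_ctx():
    dist.all_reduce(t, async_op=True)
    assert torch._C._distributed_c10d._get_work_registry_size() == 1
    torch.ops.c10d_functional.wait_tensor(t)
    assert torch._C._distributed_c10d._get_work_registry_size() == 0
```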

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
Author: Will Feng
Date: 2024-10-28 14:52:18 -07:00
Committed by: PyTorch MergeBot
Parent: d8f99f39cb
Commit: 4ee514144b
15 changed files with 625 additions and 112 deletions

----

@ -405,6 +405,22 @@ class TestWithNCCL(MultiProcessTestCase):
assert output.eq(expect).all()
assert output.completed
@skip_if_lt_x_gpu(2)
def test_wait_tensor(self) -> None:
self._init_process_group()
input = torch.full((10, 10), float(self.rank), device=self.device)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
output = torch.ops._c10d_functional.all_reduce(
input,
"avg",
"default",
)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
torch.ops._c10d_functional.wait_tensor(output)
# `wait_tensor(output)` will pop the work from the work registry immediately
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
@skip_if_lt_x_gpu(2)
def test_unwaited(self) -> None:
# Verify that the process can terminate gracefully
@ -412,11 +428,13 @@ class TestWithNCCL(MultiProcessTestCase):
self._init_process_group()
input = torch.full((10, 10), float(self.rank), device=self.device)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
output = torch.ops._c10d_functional.all_reduce(
input,
"avg",
"default",
)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
@skip_if_lt_x_gpu(2)
def test_py_work(self) -> None:

----

@ -20,6 +20,7 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
if not c10d.is_available() or not c10d.is_nccl_available():
@ -3218,6 +3219,86 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
c10d.barrier(device_ids=self.rank)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_unwaited(self) -> None:
# Verify that the process can terminate gracefully
# even with unwaited tensors
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
# Case 1: Run collectives under context manager, and don't call wait on them.
with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
input = torch.full(
(10240, 10240), float(self.rank), device=f"cuda:{self.rank}"
)
dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
# Non-functional collectives run under the context manager are registered in the work registry.
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
# Running another collective on the same tensor should still work
dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
# Case 2: Run collectives not under context manager, and don't call wait on them.
# NOTE: Here we intentionally test memory-stressed case.
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
for _ in range(50000):
input = torch.full(
(1024, 1024), float(self.rank), device=f"cuda:{self.rank}"
)
dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
# Work registry size is unchanged, since non-functional collectives not run under
# the context manager are not registered in the work registry.
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_wait_tensor(self) -> None:
# Verify that c10d_functional.wait_tensor() can be invoked on
# output tensor of non-functional collective
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
# Case 1: under context manager (i.e. work is registered in registry)
with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
torch.ops.c10d_functional.wait_tensor(input1)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
work.wait()
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
self.assertEqual(input1, input2)
# Case 2: not under context manager (i.e. work is not registered in registry)
input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
# This has no effect, since the underlying wait_tensor() logic cannot
# find the corresponding work object (it is not registered in the registry).
torch.ops.c10d_functional.wait_tensor(input1)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
work.wait()
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
self.assertEqual(input1, input2)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["DETAIL"])

----

@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import datetime
import functools
import unittest
from unittest.mock import patch
@ -28,6 +29,7 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_GPU
@ -245,6 +247,90 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@skipIfRocm
def test_eager_async_allreduce_inductor_wait(self):
import torch.distributed as dist
from torch._inductor.utils import run_and_get_code
def all_reduce_non_functional_eager(x):
y = x * x
work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
assert isinstance(work, torch.distributed.Work)
return work, y
def all_reduce_wait(work, y): # potentially compiled
if torch.compiler.is_dynamo_compiling():
torch.ops.c10d_functional.wait_tensor(y)
else:
work.wait(datetime.timedelta(seconds=10))
# Under compile, if `wait_tensor(y)` above is correctly executed,
# `y`'s data is in its final form and the output of this function will match eager;
# otherwise, `y * y` will run in parallel with `all_reduce(y)` and the output of this function
# will not match eager.
return y * y
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
x = torch.ones(12800, 12800, device="cuda") + self.rank
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
# NOTE: We run for 10 iterations each, to ensure that GPU execution is far behind the CPU
# and that `y * y` is issued on the CPU side before `all_reduce(y)` finishes on the GPU side.
# This guarantees that in the bad case `y * y` runs on the GPU in parallel with `all_reduce(y)`
# and produces a wrong result that fails the unit test.
def _run_loop_collective_wait(x, wait_fn, expected_registry_size):
for _ in range(10):
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 0
)
work, y = all_reduce_non_functional_eager(x)
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(),
expected_registry_size,
)
out = wait_fn(work, y)
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 0
)
return work, y, out
# Test: Pure-eager
all_reduce_wait_eager = all_reduce_wait
work, y, out_ref = _run_loop_collective_wait(
x,
wait_fn=all_reduce_wait_eager,
expected_registry_size=0,
)
all_reduce_wait_compiled = torch.compile(
all_reduce_wait,
backend="inductor",
fullgraph=True,
)
# Test: Issue comm in eager -> wait for comm in compile. Use the context manager.
with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
work, y, out_compiled = _run_loop_collective_wait(
x, wait_fn=all_reduce_wait_compiled, expected_registry_size=1
)
self.assertEqual(out_ref, out_compiled)
# Check that `wait_tensor()` is in the Inductor generated code
_, triton_codes = run_and_get_code(all_reduce_wait_compiled, work, y)
FileCheck().check("torch.ops._c10d_functional.wait_tensor.default(").run(
triton_codes[0]
)
# Failure Case: Issue comm in eager -> wait for comm in compile. Doesn't use the context manager.
_, _, out_compiled = _run_loop_collective_wait(
x, wait_fn=all_reduce_wait_compiled, expected_registry_size=0
)
# In this case `.wait_tensor(y)` in compiled region will not be able to find the corresponding work object
# to invoke the wait, thus the result will not match eager.
self.assertNotEqual(out_ref, out_compiled)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)

----

@ -628,6 +628,11 @@ def _register_process_group(
) -> None: ...
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
def _get_work_registry_size() -> int: ...
def _set_allow_inflight_collective_as_graph_input(
value: bool,
) -> None: ...
def _allow_inflight_collective_as_graph_input() -> bool: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...

----

@ -6,80 +6,10 @@
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <utility>
namespace {
class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
TORCH_CHECK(
inserted || it->second != work,
"The tensor storage is already associated with another work.");
}
c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
return nullptr;
}
auto work = it->second;
registry_.erase(it);
return work;
}
~WorkRegistry() {
// If there are still unwaited work objects, their corresponding process
// groups should have already been destroyed at this stage. Any attempts to
// wait for these work objects or to destroy them will only result in
// confusing errors. Therefore, we simply issue a warning and intentionally
// allow the unwaited work objects to leak.
if (!registry_.empty()) {
TORCH_WARN(
"At the time of process termination, there are still ",
registry_.size(),
" unwaited c10d_functional collective calls. "
"Please review your program to ensure c10d_functional.wait_tensor() "
"is invoked on all tensors returned from c10d_functional collective "
"ops before they are used.");
}
for (auto& it : registry_) {
it.second.release();
}
}
private:
std::unordered_map<
c10::weak_intrusive_ptr<c10::StorageImpl>,
c10::intrusive_ptr<c10d::Work>>
registry_;
std::mutex lock_;
};
static WorkRegistry process_registry;
} // namespace
namespace c10d {
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().register_work(tensor, work);
}
} // namespace c10d
namespace {
const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
{"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
{"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
@ -296,14 +226,6 @@ at::Tensor broadcast(
return broadcast_(output, src, std::move(group_name));
}
at::Tensor wait_tensor(const at::Tensor& tensor) {
auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
if (work != nullptr) {
work->wait();
}
return tensor;
}
} // namespace
TORCH_LIBRARY(_c10d_functional, m) {
@ -389,7 +311,7 @@ TORCH_LIBRARY(_c10d_functional, m) {
m.def(
"wait_tensor(Tensor tensor) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor),
{at::Tag::pt2_compliant_tag});
}
@ -438,7 +360,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.typed<decltype(c10d::wait_tensor)>()
.call(out);
return {out, at::Tensor(), at::Tensor(), at::Tensor()};
@ -493,7 +415,7 @@ class ReduceScatterTensor
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.typed<decltype(c10d::wait_tensor)>()
.call(out);
return {
@ -549,7 +471,7 @@ class AllGatherIntoTensor
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.typed<decltype(c10d::wait_tensor)>()
.call(out);
return {

----

@ -1,11 +1,3 @@
#pragma once
#include <torch/csrc/distributed/c10d/Work.hpp>
namespace c10d {
C10_EXPORT void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work);
} // namespace c10d
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

----

@ -1,5 +1,6 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <c10/util/Logging.h>
#include <fmt/format.h>
@ -11,7 +12,6 @@
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#include <utility>
namespace c10d {
@ -159,3 +159,172 @@ void ProcessGroup::release_resources() {
}
} // namespace c10d
namespace {
class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
if (!tensor.has_storage()) {
TORCH_WARN_ONCE(
"Registering collective work for tensor without storage is not supported. "
"Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. "
"Unsupported tensor type: " +
tensor.toString());
return;
}
auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
registry_.emplace(
std::move(storage),
std::vector<c10::intrusive_ptr<c10d::Work>>{work});
} else {
// There is no guarantee that the previous work object for this
// tensor storage is completed before the new work object is registered.
// Therefore we need to maintain a list of work objects for each tensor
// storage.
// Check if work is already in the list
bool work_exists = false;
for (const auto& existing_work : it->second) {
if (existing_work == work) {
work_exists = true;
break;
}
}
// Only append if work is not already in the list
if (!work_exists) {
it->second.push_back(work);
}
}
}
std::vector<c10::intrusive_ptr<c10d::Work>> pop_works(
const at::Tensor& tensor) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
return {};
}
auto works = it->second;
registry_.erase(it);
return works;
}
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
std::unique_lock lock(lock_);
for (auto it = registry_.begin(); it != registry_.end();) {
std::vector<c10::intrusive_ptr<c10d::Work>> nonmatching_works;
for (const auto& _work : it->second) {
if (_work != work) {
nonmatching_works.push_back(_work);
}
}
if (nonmatching_works.empty()) {
it = registry_.erase(it);
} else {
it->second = std::move(nonmatching_works);
++it;
}
}
}
size_t get_work_registry_size() {
std::unique_lock lock(lock_);
size_t total_size = 0;
for (const auto& [storage, works] : registry_) {
total_size += works.size();
}
return total_size;
}
void set_allow_inflight_collective_as_graph_input(bool value) {
std::unique_lock lock(lock_);
allow_inflight_collective_as_graph_input_ = value;
}
bool allow_inflight_collective_as_graph_input() {
std::unique_lock lock(lock_);
return allow_inflight_collective_as_graph_input_;
}
~WorkRegistry() {
// If there are still unwaited work objects, their corresponding process
// groups should have already been destroyed at this stage. Any attempts to
// wait for these work objects or to destroy them will only result in
// confusing errors. Therefore, we simply issue a warning and intentionally
// allow the unwaited work objects to leak.
size_t registry_size = get_work_registry_size();
if (registry_size > 0) {
TORCH_WARN(
"At the time of process termination, there are still ",
registry_size,
" unwaited collective calls. "
"Please review your program to ensure that:\n"
"1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n"
"2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective "
"called under `with allow_inflight_collective_as_graph_input_ctx():`,\n"
"before the output tensors of the collective are used.");
}
for (auto& it : registry_) {
for (auto& work : it.second) {
work.release();
}
}
}
private:
std::unordered_map<
c10::weak_intrusive_ptr<c10::StorageImpl>,
std::vector<c10::intrusive_ptr<c10d::Work>>>
registry_;
bool allow_inflight_collective_as_graph_input_ = false;
std::mutex lock_;
};
static WorkRegistry process_registry;
} // namespace
namespace c10d {
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().register_work(tensor, work);
}
at::Tensor wait_tensor(const at::Tensor& tensor) {
auto works = RankLocal<WorkRegistry>::get().pop_works(tensor);
for (const auto& work : works) {
work->wait();
}
return tensor;
}
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().unregister_work(work);
}
size_t get_work_registry_size() {
return RankLocal<WorkRegistry>::get().get_work_registry_size();
}
void set_allow_inflight_collective_as_graph_input(bool value) {
return RankLocal<WorkRegistry>::get()
.set_allow_inflight_collective_as_graph_input(value);
}
bool allow_inflight_collective_as_graph_input() {
return RankLocal<WorkRegistry>::get()
.allow_inflight_collective_as_graph_input();
}
} // namespace c10d

----

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <memory>
#include <unordered_map>
#include <utility>
@ -23,6 +24,31 @@ constexpr auto kProcessGroupDefaultTimeout =
namespace c10d {
// We only call `register_work()` in two cases:
// 1. If the work object is created from a functional collective call.
// 2. If the work object is created from a non-functional collective call within
// the `with allow_inflight_collective_as_graph_input_ctx()` context manager.
C10_EXPORT void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
// We only call `unregister_work()` in one case:
// 1. If the work object is created from a non-functional collective call within
// the `with allow_inflight_collective_as_graph_input_ctx()` context manager.
//
// Q: What about the functional collective case?
// A: The unregistration of work object for functional collective is done in
// the required user-side explicit call to `wait_tensor()`.
C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT size_t get_work_registry_size();
C10_EXPORT void set_allow_inflight_collective_as_graph_input(bool value);
C10_EXPORT bool allow_inflight_collective_as_graph_input();
// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
@ -158,13 +184,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// It's awkward to unbox the opts here and box them again in the custom C++
// op. But it's also complicated to make opts as a CustomClassHolder. Leave
// it as it is now.
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce(
@ -181,12 +214,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::optional<at::Tensor>& sparse_indices,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@ -200,11 +240,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce(
@ -219,13 +266,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> allgather(
@ -242,11 +296,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
// Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@ -267,12 +330,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
// This function is deprecated and will be moved out of ProcessGroup to comms:
@ -291,10 +359,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
// This function is a coalesced version of `allgather_into_tensor` (currently
@ -312,10 +389,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> gather(
@ -330,12 +414,21 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
}
return work;
}
virtual c10::intrusive_ptr<Work> scatter(
@ -353,13 +446,20 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce_scatter(
@ -376,12 +476,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@ -398,13 +505,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
// This function is a coalesced version of `reduce_scatter_tensor` (currently
@ -424,12 +536,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> alltoall_base(
@ -447,13 +566,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<int64_t>,
std::vector<int64_t>,
int64_t)>();
return op.call(
auto work = op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(outputBuffer, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> alltoall(
@ -469,11 +593,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual void monitoredBarrier(
@ -549,11 +680,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
dstRank,
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> recv(
@ -567,11 +704,17 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
srcRank,
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> recvAnysource(
@ -583,10 +726,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return op.call(
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
tag);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> barrier(
@ -618,11 +767,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<int64_t>&,
int64_t)>();
return op.call(
auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);
}
return work;
}
bool hasBackends() {

----

@ -5,6 +5,7 @@
#include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <chrono>
#include <exception>
@ -574,6 +575,11 @@ bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
// Completes the Work object and throws the exception.
finishAndThrow(exception);
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupGloo::SendWork>::unsafe_reclaim_from_nonowning(this));
}
return sendCompleted;
}
@ -621,6 +627,11 @@ bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {
// Completes the Work object and throws the exception.
finishAndThrow(exception);
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupGloo::RecvWork>::unsafe_reclaim_from_nonowning(this));
}
return recvCompleted;
}

----

@ -7,6 +7,7 @@
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
@ -198,6 +199,11 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
populateException();
std::rethrow_exception(exception_);
}
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
}
// Always return true, because abort API is not implemented.
return true;
}

----

@ -725,6 +725,11 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
void ProcessGroupNCCL::WorkNCCL::synchronize() {
synchronizeStream();
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this));
}
}
void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {

----

@ -2,6 +2,7 @@
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/env.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
@ -273,6 +274,11 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
Work::recordFunctionEndCallback_();
Work::recordFunctionEndCallback_ = nullptr;
}
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
}
return true;
}

----

@ -1,4 +1,5 @@
#include <ATen/ThreadLocalState.h>
#include <distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>
@ -70,7 +71,12 @@ std::vector<at::Tensor> Work::result() {
TORCH_CHECK(false, "result() not implemented.");
}
void Work::synchronize() {}
void Work::synchronize() {
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
}
}
bool Work::wait(std::chrono::milliseconds timeout) {
std::unique_lock<std::mutex> lock(mutex_);

----

@ -937,6 +937,21 @@ This class does not support ``__members__`` property.)");
py::arg("tensor"),
py::arg("work"));
module.def("_get_work_registry_size", []() {
return ::c10d::get_work_registry_size();
});
module.def(
"_set_allow_inflight_collective_as_graph_input",
[](bool value) {
return ::c10d::set_allow_inflight_collective_as_graph_input(value);
},
py::arg("value"));
module.def("_allow_inflight_collective_as_graph_input", []() {
return ::c10d::allow_inflight_collective_as_graph_input();
});
// Remove a group from the native registry
module.def(
"_unregister_process_group",

----

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import contextlib
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
@ -816,6 +817,43 @@ def _maybe_wrap_tensor(self) -> torch.Tensor:
return cast(torch.Tensor, res)
@contextlib.contextmanager
def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
"""
Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs.
A common use case is when the collective is issued in eager mode (with `async_op=True`) but waited on in a compiled region:
```
def all_reduce_eager(x):
y = x * x
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
return y
@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y)
return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank
# the context manager ensures that `wait_tensor(y)` will wait on the correct work object
with allow_inflight_collective_as_graph_input_ctx():
y = all_reduce_eager(x)
z = all_reduce_wait_compiled(y)
```
With this context manager, when a collective is called, its work object is registered
in the work registry under the hood, and a wait_tensor() call in the compiled region on
the output tensor of the collective will wait on the correct work object.
"""
previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
try:
torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
yield
finally:
torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
previous
)
def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
def mk_out_tensor(shard):
out_size = list(shard.size())