mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)"

This reverts commit a688c57033b4536ef59356cdad241d65ca52a869. Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/yf225 due to Seems to have bad interaction with latest commits on trunk, reverting to be safe ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2442527696))
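For context, the change being reverted let a collective issued eagerly with `async_op=True` be waited on inside a compiled region via `wait_tensor()`, provided the collective ran under `allow_inflight_collective_as_graph_input_ctx()`. A minimal sketch of that usage pattern, adapted from the docstring and tests removed in this diff (process-group initialization and a CUDA device per rank are assumed):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives


def all_reduce_eager(x):
    y = x * x
    # Issued eagerly; returns while the collective may still be in flight.
    dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    return y


@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    # Inside compile, waiting is expressed as wait_tensor() on the output tensor.
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y


# Assumes dist.init_process_group(backend="nccl", ...) has already run on each rank.
x = torch.ones(1280, 1280, device="cuda") + dist.get_rank()
with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)          # Work object is recorded in the work registry
    z = all_reduce_wait_compiled(y)  # wait_tensor(y) finds and waits on that Work
```

Without the context manager, `wait_tensor()` in the compiled region finds no registered work object and returns without waiting, which is the failure case the removed `test_eager_async_allreduce_inductor_wait` exercised.
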
@@ -405,22 +405,6 @@ class TestWithNCCL(MultiProcessTestCase):
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_wait_tensor(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
        torch.ops._c10d_functional.wait_tensor(output)
        # `wait_tensor(output)` will pop the work from the work registry immediately
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
@@ -428,13 +412,11 @@ class TestWithNCCL(MultiProcessTestCase):
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)

    @skip_if_lt_x_gpu(2)
    def test_py_work(self) -> None:

@@ -20,7 +20,6 @@ from unittest import mock, SkipTest

import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives


if not c10d.is_available() or not c10d.is_nccl_available():
@@ -3241,86 +3240,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
        with self.assertRaisesRegex(TypeError, "Invalid function argument"):
            c10d.barrier(device_ids=self.rank)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
        )

        # Case 1: Run collectives under context manager, and don't call wait on them.
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            input = torch.full(
                (10240, 10240), float(self.rank), device=f"cuda:{self.rank}"
            )
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
            # Non-functional collectives run under the context manager are registered in the work registry.
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            # Running another collective on the same tensor should still work
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)

        # Case 2: Run collectives not under context manager, and don't call wait on them.
        # NOTE: Here we intentionally test memory-stressed case.
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
        for _ in range(50000):
            input = torch.full(
                (1024, 1024), float(self.rank), device=f"cuda:{self.rank}"
            )
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
        # Work registry size is unchanged, since non-functional collectives not run under
        # the context manager are not registered in the work registry.
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_wait_tensor(self) -> None:
        # Verify that c10d_functional.wait_tensor() can be invoked on
        # output tensor of non-functional collective
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
        )

        # Case 1: under context manager (i.e. work is registered in registry)
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            torch.ops.c10d_functional.wait_tensor(input1)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

            input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            work.wait()
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            self.assertEqual(input1, input2)

        # Case 2: not under context manager (i.e. work is not registered in registry)
        input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        # this does not take effect, since the underlying wait_tensor() logic would not
        # be able to find the corresponding work object (because it's not registered in registry)
        torch.ops.c10d_functional.wait_tensor(input1)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

        input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        work.wait()
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        self.assertEqual(input1, input2)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["DETAIL"])

@@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import datetime
import functools
import unittest
from unittest.mock import patch
@@ -29,7 +28,6 @@ from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_GPU

@@ -247,90 +245,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
        )
        self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_eager_async_allreduce_inductor_wait(self):
        import torch.distributed as dist
        from torch._inductor.utils import run_and_get_code

        def all_reduce_non_functional_eager(x):
            y = x * x
            work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
            assert isinstance(work, torch.distributed.Work)
            return work, y

        def all_reduce_wait(work, y):  # potentially compiled
            if torch.compiler.is_dynamo_compiling():
                torch.ops.c10d_functional.wait_tensor(y)
            else:
                work.wait(datetime.timedelta(seconds=10))
            # Under compile, if `wait_tensor(y)` above is correctly executed,
            # `y`'s data is in its final form and the output of this function will match eager;
            # otherwise, `y * y` will run in parallel with `all_reduce(y)` and the output of this function
            # will not match eager.
            return y * y

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            x = torch.ones(12800, 12800, device="cuda") + self.rank
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

            # NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU
            # and that `y * y` on CPU side will be issued before `all_reduce(y)` on GPU side is done,
            # thus guaranteeing that in the bad case `y * y` on GPU side will run in parallel with `all_reduce(y)`
            # and will produce the wrong result that fails the unit test.

            def _run_loop_collective_wait(x, wait_fn, expected_registry_size):
                for _ in range(10):
                    self.assertEqual(
                        torch._C._distributed_c10d._get_work_registry_size(), 0
                    )
                    work, y = all_reduce_non_functional_eager(x)
                    self.assertEqual(
                        torch._C._distributed_c10d._get_work_registry_size(),
                        expected_registry_size,
                    )
                    out = wait_fn(work, y)
                    self.assertEqual(
                        torch._C._distributed_c10d._get_work_registry_size(), 0
                    )
                return work, y, out

            # Test: Pure-eager
            all_reduce_wait_eager = all_reduce_wait
            work, y, out_ref = _run_loop_collective_wait(
                x,
                wait_fn=all_reduce_wait_eager,
                expected_registry_size=0,
            )

            all_reduce_wait_compiled = torch.compile(
                all_reduce_wait,
                backend="inductor",
                fullgraph=True,
            )

            # Test: Issue comm in eager -> wait for comm in compile. Use the context manager.
            with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
                work, y, out_compiled = _run_loop_collective_wait(
                    x, wait_fn=all_reduce_wait_compiled, expected_registry_size=1
                )
            self.assertEqual(out_ref, out_compiled)

            # Check that `wait_tensor()` is in the Inductor generated code
            _, triton_codes = run_and_get_code(all_reduce_wait_compiled, work, y)
            FileCheck().check("torch.ops._c10d_functional.wait_tensor.default(").run(
                triton_codes[0]
            )

            # Failure Case: Issue comm in eager -> wait for comm in compile. Doesn't use the context manager.
            _, _, out_compiled = _run_loop_collective_wait(
                x, wait_fn=all_reduce_wait_compiled, expected_registry_size=0
            )
            # In this case `.wait_tensor(y)` in compiled region will not be able to find the corresponding work object
            # to invoke the wait, thus the result will not match eager.
            self.assertNotEqual(out_ref, out_compiled)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)

@@ -628,11 +628,6 @@ def _register_process_group(
) -> None: ...
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
def _get_work_registry_size() -> int: ...
def _set_allow_inflight_collective_as_graph_input(
    value: bool,
) -> None: ...
def _allow_inflight_collective_as_graph_input() -> bool: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...

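The private bindings removed above are what the deleted tests used to observe the registry. A small illustrative snippet of how they were exercised before this revert, assuming an initialized NCCL process group and one CUDA device per rank (mirroring the removed `test_wait_tensor`):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives

# Assumes dist.init_process_group(backend="nccl", ...) has already been called.
c10d = torch._C._distributed_c10d

assert c10d._get_work_registry_size() == 0
with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
    rank = dist.get_rank()
    t = torch.full((10, 10), float(rank), device=f"cuda:{rank}")
    dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # The async collective's Work object is now tracked per tensor storage.
    assert c10d._get_work_registry_size() == 1
    torch.ops.c10d_functional.wait_tensor(t)
    # wait_tensor() pops the registered Work object and waits on it.
    assert c10d._get_work_registry_size() == 0
```
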
@@ -6,10 +6,80 @@
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <utility>

namespace {

class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      const c10::intrusive_ptr<c10d::Work>& work) {
    auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
    TORCH_CHECK(
        inserted || it->second != work,
        "The tensor storage is already associated with another work.");
  }

  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    if (it == registry_.end()) {
      return nullptr;
    }
    auto work = it->second;
    registry_.erase(it);
    return work;
  }

  ~WorkRegistry() {
    // If there are still unwaited work objects, their corresponding process
    // groups should have already been destroyed at this stage. Any attempts to
    // wait for these work objects or to destroy them will only result in
    // confusing errors. Therefore, we simply issue a warning and intentionally
    // allow the unwaited work objects to leak.
    if (!registry_.empty()) {
      TORCH_WARN(
          "At the time of process termination, there are still ",
          registry_.size(),
          " unwaited c10d_functional collective calls. "
          "Please review your program to ensure c10d_functional.wait_tensor() "
          "is invoked on all tensors returned from c10d_functional collective "
          "ops before they are used.");
    }
    for (auto& it : registry_) {
      it.second.release();
    }
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      c10::intrusive_ptr<c10d::Work>>
      registry_;
  std::mutex lock_;
};

static WorkRegistry process_registry;

} // namespace

namespace c10d {

void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work) {
  RankLocal<WorkRegistry>::get().register_work(tensor, work);
}

} // namespace c10d

namespace {

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
@@ -226,6 +296,14 @@ at::Tensor broadcast(
  return broadcast_(output, src, std::move(group_name));
}

at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
  if (work != nullptr) {
    work->wait();
  }
  return tensor;
}

} // namespace

TORCH_LIBRARY(_c10d_functional, m) {
@@ -311,7 +389,7 @@ TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor),
          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
      {at::Tag::pt2_compliant_tag});
}

@@ -360,7 +438,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {out, at::Tensor(), at::Tensor(), at::Tensor()};
@@ -415,7 +493,7 @@ class ReduceScatterTensor
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {
@@ -471,7 +549,7 @@ class AllGatherIntoTensor
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {

@@ -1,3 +1,11 @@
#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d {

C10_EXPORT void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work);

} // namespace c10d

@@ -1,6 +1,5 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>

#include <c10/util/Logging.h>
#include <fmt/format.h>
@@ -12,6 +11,7 @@
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#include <utility>

namespace c10d {

@@ -159,172 +159,3 @@ void ProcessGroup::release_resources() {
}

} // namespace c10d

namespace {

class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      const c10::intrusive_ptr<c10d::Work>& work) {
    if (!tensor.has_storage()) {
      TORCH_WARN_ONCE(
          "Registering collective work for tensor without storage is not supported. "
          "Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. "
          "Unsupported tensor type: " +
          tensor.toString());
      return;
    }
    auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);

    auto it = registry_.find(storage);
    if (it == registry_.end()) {
      registry_.emplace(
          std::move(storage),
          std::vector<c10::intrusive_ptr<c10d::Work>>{work});
    } else {
      // There is no guarantee that the previous work object for this
      // tensor storage is completed before the new work object is registered.
      // Therefore we need to maintain a list of work objects for each tensor
      // storage.

      // Check if work is already in the list
      bool work_exists = false;
      for (const auto& existing_work : it->second) {
        if (existing_work == work) {
          work_exists = true;
          break;
        }
      }

      // Only append if work is not already in the list
      if (!work_exists) {
        it->second.push_back(work);
      }
    }
  }

  std::vector<c10::intrusive_ptr<c10d::Work>> pop_works(
      const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    if (it == registry_.end()) {
      return {};
    }
    auto works = it->second;
    registry_.erase(it);
    return works;
  }

  void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
    std::unique_lock lock(lock_);
    for (auto it = registry_.begin(); it != registry_.end();) {
      std::vector<c10::intrusive_ptr<c10d::Work>> nonmatching_works;
      for (const auto& _work : it->second) {
        if (_work != work) {
          nonmatching_works.push_back(_work);
        }
      }
      if (nonmatching_works.empty()) {
        it = registry_.erase(it);
      } else {
        it->second = std::move(nonmatching_works);
        ++it;
      }
    }
  }

  size_t get_work_registry_size() {
    std::unique_lock lock(lock_);
    size_t total_size = 0;
    for (const auto& [storage, works] : registry_) {
      total_size += works.size();
    }
    return total_size;
  }

  void set_allow_inflight_collective_as_graph_input(bool value) {
    std::unique_lock lock(lock_);
    allow_inflight_collective_as_graph_input_ = value;
  }

  bool allow_inflight_collective_as_graph_input() {
    std::unique_lock lock(lock_);
    return allow_inflight_collective_as_graph_input_;
  }

  ~WorkRegistry() {
    // If there are still unwaited work objects, their corresponding process
    // groups should have already been destroyed at this stage. Any attempts to
    // wait for these work objects or to destroy them will only result in
    // confusing errors. Therefore, we simply issue a warning and intentionally
    // allow the unwaited work objects to leak.
    size_t registry_size = get_work_registry_size();
    if (registry_size > 0) {
      TORCH_WARN(
          "At the time of process termination, there are still ",
          registry_size,
          " unwaited collective calls. "
          "Please review your program to ensure that:\n"
          "1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n"
          "2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective "
          "called under `with allow_inflight_collective_as_graph_input():`,\n"
          "before the output tensors of the collective are used.");
    }
    for (auto& it : registry_) {
      for (auto& work : it.second) {
        work.release();
      }
    }
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      std::vector<c10::intrusive_ptr<c10d::Work>>>
      registry_;
  bool allow_inflight_collective_as_graph_input_ = false;
  std::mutex lock_;
};

static WorkRegistry process_registry;

} // namespace

namespace c10d {

void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work) {
  RankLocal<WorkRegistry>::get().register_work(tensor, work);
}

at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto works = RankLocal<WorkRegistry>::get().pop_works(tensor);
  for (const auto& work : works) {
    work->wait();
  }
  return tensor;
}

void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
  RankLocal<WorkRegistry>::get().unregister_work(work);
}

size_t get_work_registry_size() {
  return RankLocal<WorkRegistry>::get().get_work_registry_size();
}

void set_allow_inflight_collective_as_graph_input(bool value) {
  return RankLocal<WorkRegistry>::get()
      .set_allow_inflight_collective_as_graph_input(value);
}

bool allow_inflight_collective_as_graph_input() {
  return RankLocal<WorkRegistry>::get()
      .allow_inflight_collective_as_graph_input();
}

} // namespace c10d

@@ -1,7 +1,6 @@
#pragma once

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -24,20 +23,6 @@ constexpr auto kProcessGroupDefaultTimeout =

namespace c10d {

C10_EXPORT void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work);

C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);

C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);

C10_EXPORT size_t get_work_registry_size();

C10_EXPORT void set_allow_inflight_collective_as_graph_input(bool value);

C10_EXPORT bool allow_inflight_collective_as_graph_input();

// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
@@ -173,20 +158,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    // It's awkward to unbox the opts here and box them again in the custom C++
    // op. But it's also complicated to make opts a CustomClassHolder. Leave
    // it as it is now.
    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allreduce(
@@ -203,19 +181,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const std::optional<at::Tensor>& sparse_indices,
        int64_t)>();

    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.sparseIndices,
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -229,18 +200,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ReduceOp>&,
        int64_t)>();

    auto work = op.call(
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> reduce(
@@ -255,20 +219,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        int64_t,
        int64_t,
        int64_t)>();
    auto work = op.call(
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allgather(
@@ -285,20 +242,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t)>();

    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensors) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -319,17 +267,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        bool,
        int64_t)>();

    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }

  // This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -348,19 +291,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const at::TensorList&,
        const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    auto work = op.call(
    return op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensorLists) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
@@ -378,17 +312,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const at::TensorList,
        const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    auto work = op.call(
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> gather(
@@ -403,21 +330,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t,
        int64_t)>();
    auto work = op.call(
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensors) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> scatter(
@@ -435,20 +353,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        int64_t,
        bool,
        int64_t)>();
    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -465,19 +376,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        const c10::intrusive_ptr<::c10d::ReduceOp>&,
        int64_t)>();
    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -494,18 +398,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ReduceOp>&,
        bool,
        int64_t)>();
    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }

  // This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -525,19 +424,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ReduceOp>&,
        int64_t)>();

    auto work = op.call(
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -555,18 +447,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        std::vector<int64_t>,
        std::vector<int64_t>,
        int64_t)>();
    auto work = op.call(
    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> alltoall(
@@ -582,18 +469,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const at::TensorList&,
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t)>();
    auto work = std::get<1>(op.call(
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));

    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual void monitoredBarrier(
@@ -669,17 +549,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t,
        int64_t)>();
    auto work = op.call(
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> recv(
@@ -693,17 +567,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t,
        int64_t)>();
    auto work = op.call(
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -715,16 +583,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        at::TensorList,
        const c10::intrusive_ptr<::c10d::ProcessGroup>&,
        int64_t)>();
    auto work = op.call(
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> barrier(
@@ -756,15 +618,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
        const std::vector<int64_t>&,
        int64_t)>();

    auto work = op.call(
    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(tensor, work);
    }
    return work;
  }

  bool hasBackends() {

@@ -5,7 +5,6 @@

#include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <chrono>
#include <exception>

@@ -575,9 +574,6 @@ bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {

  // Completes the Work object and throws the exception.
  finishAndThrow(exception);
  c10d::unregister_work(
      c10::intrusive_ptr<
          ProcessGroupGloo::SendWork>::unsafe_reclaim_from_nonowning(this));
  return sendCompleted;
}

@@ -625,9 +621,6 @@ bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {

  // Completes the Work object and throws the exception.
  finishAndThrow(exception);
  c10d::unregister_work(
      c10::intrusive_ptr<
          ProcessGroupGloo::RecvWork>::unsafe_reclaim_from_nonowning(this));
  return recvCompleted;
}

@@ -7,7 +7,6 @@

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
@@ -199,9 +198,6 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
    populateException();
    std::rethrow_exception(exception_);
  }
  c10d::unregister_work(
      c10::intrusive_ptr<
          ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
  // Always return true, because abort API is not implemented.
  return true;
}

@@ -725,9 +725,6 @@ void ProcessGroupNCCL::WorkNCCL::handleException(

void ProcessGroupNCCL::WorkNCCL::synchronize() {
  synchronizeStream();
  c10d::unregister_work(
      c10::intrusive_ptr<
          ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this));
}

void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {

@@ -2,7 +2,6 @@

#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/env.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
@@ -274,9 +273,6 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }
  c10d::unregister_work(
      c10::intrusive_ptr<
          ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
  return true;
}

@@ -1,5 +1,4 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>
@@ -71,10 +70,7 @@ std::vector<at::Tensor> Work::result() {
  TORCH_CHECK(false, "result() not implemented.");
}

void Work::synchronize() {
  c10d::unregister_work(
      c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
}
void Work::synchronize() {}

bool Work::wait(std::chrono::milliseconds timeout) {
  std::unique_lock<std::mutex> lock(mutex_);

@@ -937,21 +937,6 @@ This class does not support ``__members__`` property.)");
      py::arg("tensor"),
      py::arg("work"));

  module.def("_get_work_registry_size", []() {
    return ::c10d::get_work_registry_size();
  });

  module.def(
      "_set_allow_inflight_collective_as_graph_input",
      [](bool value) {
        return ::c10d::set_allow_inflight_collective_as_graph_input(value);
      },
      py::arg("value"));

  module.def("_allow_inflight_collective_as_graph_input", []() {
    return ::c10d::allow_inflight_collective_as_graph_input();
  });

  // Remove a group from the native registry
  module.def(
      "_unregister_process_group",

@@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import contextlib
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
@@ -817,43 +816,6 @@ def _maybe_wrap_tensor(self) -> torch.Tensor:
    return cast(torch.Tensor, res)


@contextlib.contextmanager
def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
    """
    Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs.
    Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region:
    ```
    def all_reduce_eager(x):
        y = x * x
        req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
        return y

    @torch.compile(fullgraph=True)
    def all_reduce_wait_compiled(y):
        torch.ops.c10d_functional.wait_tensor(y)
        return y * y

    x = torch.ones(1280, 1280, device="cuda") + self.rank
    # the context manager ensures that `wait_tensor(y)` will wait on the correct work object
    with allow_inflight_collective_as_graph_input_ctx():
        y = all_reduce_eager(x)
        z = all_reduce_wait_compiled(y)
    ```
    With this context manager, when a collective is called, under the hood the work object of the collective
    will be registered in the work registry, and the wait_tensor() in compiled region called on
    the output tensor of the collective will wait on the correct work object.
    """
    previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()

    try:
        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
        yield
    finally:
        torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
            previous
        )


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())