Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)"

This reverts commit 362ca54f03f9bb72ba7633ed580fb788b1a8dea9.

Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/wdvr because this change breaks our prod training pipeline (verified with bisect) by increasing memory consumption 4x and causing OOMs ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2435962833))
PyTorch MergeBot
2024-10-24 17:46:09 +00:00
parent 8197e4c70d
commit e7f1e306df
13 changed files with 125 additions and 417 deletions
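For context, the functionality being reverted let an eager async_op=True collective be awaited from inside a compiled region by calling wait_tensor() on the collective's output tensor. Below is a minimal sketch of that pattern, condensed from the test_eager_async_allreduce_inductor_wait test removed later in this diff; the helper names are made up here, and it assumes CUDA plus an already-initialized multi-rank NCCL process group.

import torch
import torch.distributed as dist

def issue_allreduce_eager(x):
    # Hypothetical helper: kick off the collective eagerly; it completes
    # asynchronously while the CPU keeps issuing work.
    y = x * x
    work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    return work, y

@torch.compile(backend="inductor", fullgraph=True)
def consume_compiled(y):
    # Inside the compiled region we wait on the tensor rather than on the
    # eager Work handle; the reverted change registered that Work so this
    # call could find it and wait before y is read.
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

work, y = issue_allreduce_eager(torch.ones(128, 128, device="cuda"))
out = consume_compiled(y)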

View File

@@ -405,22 +405,6 @@ class TestWithNCCL(MultiProcessTestCase):
assert output.eq(expect).all()
assert output.completed
@skip_if_lt_x_gpu(2)
def test_wait_tensor(self) -> None:
self._init_process_group()
input = torch.full((10, 10), float(self.rank), device=self.device)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
output = torch.ops._c10d_functional.all_reduce(
input,
"avg",
"default",
)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
torch.ops._c10d_functional.wait_tensor(output)
# `wait_tensor(output)` will pop the work from the work registry immediately
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
@skip_if_lt_x_gpu(2)
def test_unwaited(self) -> None:
# Verify that the process can terminate gracefully
@@ -428,13 +412,11 @@ class TestWithNCCL(MultiProcessTestCase):
self._init_process_group()
input = torch.full((10, 10), float(self.rank), device=self.device)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
output = torch.ops._c10d_functional.all_reduce(
input,
"avg",
"default",
)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
@skip_if_lt_x_gpu(2)
def test_py_work(self) -> None:

View File

@@ -3178,48 +3178,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
c10d.barrier(device_ids=self.rank)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_unwaited(self) -> None:
# Verify that the process can terminate gracefully
# even with unwaited tensors
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
input = torch.full((10240, 10240), float(self.rank), device=f"cuda:{self.rank}")
dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
# Running another collective on the same tensor should still work
dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_wait_tensor(self) -> None:
# Verify that c10d_functional.wait_tensor() can be invoked on
# output tensor of non-functional collective
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
input1 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
torch.ops.c10d_functional.wait_tensor(input1)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
input2 = torch.full((10, 10), float(self.rank), device=f"cuda:{self.rank}")
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
work.wait()
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
self.assertEqual(input1, input2)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["DETAIL"])

View File

@@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import datetime
import functools
import unittest
from unittest.mock import patch
@@ -15,7 +14,7 @@ from torch._C import FileCheck
from torch._dynamo.testing import CompileCounter
from torch._dynamo.utils import same
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
@@ -29,7 +28,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_GPU
@@ -247,74 +245,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@skipIfRocm
def test_eager_async_allreduce_inductor_wait(self):
import torch.distributed as dist
def all_reduce_non_functional_eager(x):
y = x * x
work = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
assert isinstance(work, torch.distributed.Work)
return work, y
def all_reduce_wait(work, y): # potentially compiled
if torch.compiler.is_dynamo_compiling():
torch.ops.c10d_functional.wait_tensor(y)
else:
work.wait(datetime.timedelta(seconds=10))
# Under compile, if `wait_tensor(y)` above is correctly executed,
# `y`'s data is in its final form and the output of this function will match eager;
# otherwise, `y * y` will run in parallel with `all_reduce(y)` and the output of this function
# will not match eager.
return y * y
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
x = torch.ones(12800, 12800, device="cuda") + self.rank
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
# NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU
# and that `y * y` on CPU side will be issued before `all_reduce(y)` on GPU side is done,
# thus guaranteeing that in the bad case `y * y` on GPU side will run in parallel with `all_reduce(y)`
# thus will produce the wrong result that fails the unit test.
# Test: pure-eager
all_reduce_wait_eager = all_reduce_wait
for _ in range(10):
work, y = all_reduce_non_functional_eager(x)
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 1
)
out_ref = all_reduce_wait_eager(work, y)
# `work.wait()` will pop the work from the work registry immediately
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 0
)
# Test: issue comm in eager -> wait for comm in compile
all_reduce_wait_compiled = torch.compile(
all_reduce_wait,
backend="inductor",
fullgraph=True,
)
for _ in range(10):
work, y = all_reduce_non_functional_eager(x)
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 1
)
out_compiled, triton_codes = run_and_get_code(
all_reduce_wait_compiled, work, y
)
# `wait_tensor(y)` will pop the work from the work registry immediately
self.assertEqual(
torch._C._distributed_c10d._get_work_registry_size(), 0
)
FileCheck().check(
"torch.ops._c10d_functional.wait_tensor.default("
).run(triton_codes[0])
self.assertEqual(out_ref, out_compiled)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)

View File

@@ -6,10 +6,80 @@
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <utility>
namespace {
class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
TORCH_CHECK(
inserted || it->second != work,
"The tensor storage is already associated with another work.");
}
c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
return nullptr;
}
auto work = it->second;
registry_.erase(it);
return work;
}
~WorkRegistry() {
// If there are still unwaited work objects, their corresponding process
// groups should have already been destroyed at this stage. Any attempts to
// wait for these work objects or to destroy them will only result in
// confusing errors. Therefore, we simply issue a warning and intentionally
// allow the unwaited work objects to leak.
if (!registry_.empty()) {
TORCH_WARN(
"At the time of process termination, there are still ",
registry_.size(),
" unwaited c10d_functional collective calls. "
"Please review your program to ensure c10d_functional.wait_tensor() "
"is invoked on all tensors returned from c10d_functional collective "
"ops before they are used.");
}
for (auto& it : registry_) {
it.second.release();
}
}
private:
std::unordered_map<
c10::weak_intrusive_ptr<c10::StorageImpl>,
c10::intrusive_ptr<c10d::Work>>
registry_;
std::mutex lock_;
};
static WorkRegistry process_registry;
} // namespace
namespace c10d {
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().register_work(tensor, work);
}
} // namespace c10d
namespace {
const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
{"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
{"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
@@ -42,6 +112,7 @@ at::Tensor& all_reduce_(
std::vector<at::Tensor> inputs{input};
auto group = c10d::resolve_process_group(group_name);
auto work = group->allreduce(inputs, opts);
c10d::register_work(input, work);
return input;
}
@@ -64,6 +135,9 @@ std::vector<at::Tensor> all_reduce_coalesced_(
auto group = c10d::resolve_process_group(group_name);
auto work = group->allreduce_coalesced(inputs, opts);
for (const auto& tensor : inputs) {
c10d::register_work(tensor, work);
}
return inputs;
}
@@ -104,6 +178,9 @@ std::vector<at::Tensor> all_gather_into_tensor_coalesced(
auto group = c10d::resolve_process_group(group_name);
auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
for (const auto& tensor : outputs) {
c10d::register_work(tensor, work);
}
return outputs;
}
@@ -125,6 +202,7 @@ at::Tensor& all_gather_into_tensor_out(
auto group = c10d::resolve_process_group(group_name);
auto work = group->_allgather_base(output, input, opts);
c10d::register_work(output, work);
return output;
}
@@ -160,6 +238,9 @@ std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
auto group = c10d::resolve_process_group(group_name);
auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
for (const auto& tensor : outputs) {
c10d::register_work(tensor, work);
}
return outputs;
}
@@ -191,6 +272,7 @@ at::Tensor all_to_all_single(
const_cast<at::Tensor&>(input),
output_split_sizes,
input_split_sizes);
c10d::register_work(output, work);
return output;
}
@@ -202,6 +284,7 @@ at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
auto group = c10d::resolve_process_group(group_name);
auto work = group->broadcast(inputs, opts);
c10d::register_work(input, work);
return input;
}
@@ -213,6 +296,14 @@ at::Tensor broadcast(
return broadcast_(output, src, std::move(group_name));
}
at::Tensor wait_tensor(const at::Tensor& tensor) {
auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
if (work != nullptr) {
work->wait();
}
return tensor;
}
} // namespace
TORCH_LIBRARY(_c10d_functional, m) {
@@ -298,7 +389,7 @@ TORCH_LIBRARY(_c10d_functional, m) {
m.def(
"wait_tensor(Tensor tensor) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor),
c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
{at::Tag::pt2_compliant_tag});
}
@@ -347,7 +438,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(c10d::wait_tensor)>()
.typed<decltype(wait_tensor)>()
.call(out);
return {out, at::Tensor(), at::Tensor(), at::Tensor()};
@@ -402,7 +493,7 @@ class ReduceScatterTensor
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(c10d::wait_tensor)>()
.typed<decltype(wait_tensor)>()
.call(out);
return {
@@ -458,7 +549,7 @@ class AllGatherIntoTensor
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(c10d::wait_tensor)>()
.typed<decltype(wait_tensor)>()
.call(out);
return {
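Taken together, the Functional.cpp hunks above move the work registry back into Functional.cpp: each functional collective registers its Work against the output tensor's storage, and wait_tensor() pops that single entry and waits on it. Below is a rough, illustrative Python model of just that bookkeeping, not the actual C++ code; the real registry keys on a weak reference to the StorageImpl, which is approximated here with the storage's data_ptr().

class FunctionalWorkRegistrySketch:
    """Illustrative model: at most one pending work per tensor storage."""

    def __init__(self):
        self._registry = {}  # storage data_ptr -> pending work

    def register_work(self, tensor, work):
        key = tensor.untyped_storage().data_ptr()
        # Mirrors try_emplace: the first registration for a storage wins.
        self._registry.setdefault(key, work)

    def pop_work(self, tensor):
        key = tensor.untyped_storage().data_ptr()
        return self._registry.pop(key, None)


def wait_tensor_sketch(registry, tensor):
    # Pop-and-wait, matching the restored wait_tensor(): the entry is
    # removed from the registry as soon as it is waited on.
    work = registry.pop_work(tensor)
    if work is not None:
        work.wait()
    return tensor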

View File

@@ -1,3 +1,11 @@
#pragma once
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
namespace c10d {
C10_EXPORT void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work);
} // namespace c10d

View File

@@ -1,6 +1,5 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <c10/util/Logging.h>
#include <fmt/format.h>
@@ -160,137 +159,3 @@ void ProcessGroup::release_resources() {
}
} // namespace c10d
namespace {
class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
if (!tensor.has_storage()) {
TORCH_WARN_ONCE(
"Registering collective work for tensor without storage is not supported. "
"Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. "
"Unsupported tensor type: " +
tensor.toString());
return;
}
auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
registry_.emplace(
std::move(storage),
std::vector<c10::intrusive_ptr<c10d::Work>>{work});
} else {
// There is no guarantee that the previous work object for this
// tensor storage is completed before the new work object is registered.
// Therefore we need to maintain a list of work objects for each tensor
// storage.
it->second.push_back(work);
}
}
std::vector<c10::intrusive_ptr<c10d::Work>> pop_works(
const at::Tensor& tensor) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
return {};
}
auto works = it->second;
registry_.erase(it);
return works;
}
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
std::unique_lock lock(lock_);
for (auto it = registry_.begin(); it != registry_.end();) {
std::vector<c10::intrusive_ptr<c10d::Work>> nonmatching_works;
for (const auto& _work : it->second) {
if (_work != work) {
nonmatching_works.push_back(_work);
}
}
if (nonmatching_works.empty()) {
it = registry_.erase(it);
} else {
it->second = std::move(nonmatching_works);
++it;
}
}
}
size_t get_work_registry_size() {
std::unique_lock lock(lock_);
size_t total_size = 0;
for (const auto& [storage, works] : registry_) {
total_size += works.size();
}
return total_size;
}
~WorkRegistry() {
// If there are still unwaited work objects, their corresponding process
// groups should have already been destroyed at this stage. Any attempts to
// wait for these work objects or to destroy them will only result in
// confusing errors. Therefore, we simply issue a warning and intentionally
// allow the unwaited work objects to leak.
size_t registry_size = get_work_registry_size();
if (registry_size > 0) {
TORCH_WARN(
"At the time of process termination, there are still ",
registry_size,
" unwaited collective calls. "
"Please review your program to ensure that:\n"
"1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n"
"2. work.wait() is invoked on work object returned from torch.distributed collective with async_op=True,\n"
"before the output tensors of the collective are used.");
}
for (auto& it : registry_) {
for (auto& work : it.second) {
work.release();
}
}
}
private:
std::unordered_map<
c10::weak_intrusive_ptr<c10::StorageImpl>,
std::vector<c10::intrusive_ptr<c10d::Work>>>
registry_;
std::mutex lock_;
};
static WorkRegistry process_registry;
} // namespace
namespace c10d {
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().register_work(tensor, work);
}
at::Tensor wait_tensor(const at::Tensor& tensor) {
auto works = RankLocal<WorkRegistry>::get().pop_works(tensor);
for (const auto& work : works) {
work->wait();
}
return tensor;
}
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
RankLocal<WorkRegistry>::get().unregister_work(work);
}
size_t get_work_registry_size() {
return RankLocal<WorkRegistry>::get().get_work_registry_size();
}
} // namespace c10d
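For contrast, the registry deleted from ProcessGroup.cpp above was broader: every collective routed through ProcessGroup (including eager async_op=True calls) registered its Work against each output storage, a storage could accumulate several pending works, and a Work removed itself on wait()/synchronize(), which is why the removed tests could assert that the registry size drops back to zero after work.wait(). A rough Python model of that removed bookkeeping, again keyed by data_ptr() purely for illustration:

class PerRankWorkRegistrySketch:
    """Illustrative model of the removed ProcessGroup.cpp registry."""

    def __init__(self):
        self._registry = {}  # storage data_ptr -> list of pending works

    def register_work(self, tensor, work):
        key = tensor.untyped_storage().data_ptr()
        # Several unwaited collectives may target the same storage, so works
        # accumulate in a list instead of being overwritten.
        self._registry.setdefault(key, []).append(work)

    def pop_works(self, tensor):
        key = tensor.untyped_storage().data_ptr()
        return self._registry.pop(key, [])

    def unregister_work(self, work):
        # Called when a Work is waited on directly (work.wait() or
        # synchronize()), so completed collectives do not linger here.
        for key in list(self._registry):
            remaining = [w for w in self._registry[key] if w is not work]
            if remaining:
                self._registry[key] = remaining
            else:
                del self._registry[key]

    def size(self):
        return sum(len(works) for works in self._registry.values())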

View File

@@ -1,7 +1,6 @@
#pragma once
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -24,16 +23,6 @@ constexpr auto kProcessGroupDefaultTimeout =
namespace c10d {
C10_EXPORT void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
C10_EXPORT size_t get_work_registry_size();
// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
@@ -169,18 +158,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// It's awakward to unbox the opts here and box them again in the custom C++
// op. But it's also complicated to make opts as a CustomClassHolder. Leave
// it as it is now.
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count()));
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce(
@@ -197,17 +181,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::optional<at::Tensor>& sparse_indices,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.timeout.count()));
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -221,16 +200,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
auto work = op.call(
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.timeout.count());
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce(
@@ -245,18 +219,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
int64_t,
int64_t)>();
auto work = op.call(
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.timeout.count());
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> allgather(
@@ -273,18 +242,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
return work;
}
// Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -305,15 +267,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
bool,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
c10d::register_work(outputBuffer, work);
return work;
}
// This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -332,17 +291,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
return op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
for (const auto& tensor_list : outputTensorLists) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
return work;
}
// This function is a coalesced version of `allgather_into_tensor` (currently
@@ -360,15 +312,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> gather(
@@ -383,19 +330,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
auto work = op.call(
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.timeout.count());
for (const auto& tensor_list : outputTensors) {
for (const auto& tensor : tensor_list) {
c10d::register_work(tensor, work);
}
}
return work;
}
virtual c10::intrusive_ptr<Work> scatter(
@@ -413,18 +353,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
int64_t,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count()));
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -441,17 +376,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count()));
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -468,16 +398,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
c10d::register_work(outputBuffer, work);
return work;
}
// This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -497,17 +424,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
auto work = op.call(
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count());
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -525,16 +447,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<int64_t>,
std::vector<int64_t>,
int64_t)>();
auto work = op.call(
return op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.timeout.count());
c10d::register_work(outputBuffer, work);
return work;
}
virtual c10::intrusive_ptr<Work> alltoall(
@@ -550,16 +469,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
auto work = std::get<1>(op.call(
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
for (const auto& tensor : outputTensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual void monitoredBarrier(
@@ -635,15 +549,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
auto work = op.call(
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
dstRank,
tag);
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> recv(
@@ -657,15 +567,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
auto work = op.call(
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
srcRank,
tag);
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -677,14 +583,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
auto work = op.call(
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
tag);
for (const auto& tensor : tensors) {
c10d::register_work(tensor, work);
}
return work;
}
virtual c10::intrusive_ptr<Work> barrier(
@@ -716,13 +618,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<int64_t>&,
int64_t)>();
auto work = op.call(
return op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.timeout.count());
c10d::register_work(tensor, work);
return work;
}
bool hasBackends() {
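Each ProcessGroup.hpp wrapper above loses the same few lines: after dispatching the op, it no longer loops over the output tensors calling c10d::register_work(tensor, work) before handing the work back. A tiny, hypothetical Python rendering of that deleted pattern, reusing a registry object shaped like the sketch earlier in this diff:

def registering(collective_fn, registry):
    # Hypothetical wrapper illustrating the removed pattern: run the
    # dispatched collective, register the returned work against every
    # output tensor, then return the work to the caller as before.
    def wrapped(output_tensors, *args, **kwargs):
        work = collective_fn(output_tensors, *args, **kwargs)
        for t in output_tensors:
            registry.register_work(t, work)
        return work
    return wrapped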

View File

@@ -5,7 +5,6 @@
#include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <chrono>
#include <exception>
@@ -575,9 +574,6 @@ bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
// Completes the Work object and throws the exception.
finishAndThrow(exception);
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupGloo::SendWork>::unsafe_reclaim_from_nonowning(this));
return sendCompleted;
}
@@ -625,9 +621,6 @@ bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {
// Completes the Work object and throws the exception.
finishAndThrow(exception);
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupGloo::RecvWork>::unsafe_reclaim_from_nonowning(this));
return recvCompleted;
}

View File

@@ -7,7 +7,6 @@
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
@@ -199,9 +198,6 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
populateException();
std::rethrow_exception(exception_);
}
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
// Always return true, because abort API is not implemented.
return true;
}

View File

@@ -720,9 +720,6 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
void ProcessGroupNCCL::WorkNCCL::synchronize() {
synchronizeStream();
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this));
}
void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {

View File

@@ -2,7 +2,6 @@
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/env.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
@@ -274,9 +273,6 @@ bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
Work::recordFunctionEndCallback_();
Work::recordFunctionEndCallback_ = nullptr;
}
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
return true;
}

View File

@@ -1,5 +1,4 @@
#include <ATen/ThreadLocalState.h>
#include <distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>
@@ -71,10 +70,7 @@ std::vector<at::Tensor> Work::result() {
TORCH_CHECK(false, "result() not implemented.");
}
void Work::synchronize() {
c10d::unregister_work(
c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
}
void Work::synchronize() {}
bool Work::wait(std::chrono::milliseconds timeout) {
std::unique_lock<std::mutex> lock(mutex_);

View File

@@ -933,10 +933,6 @@ This class does not support ``__members__`` property.)");
py::arg("tensor"),
py::arg("work"));
module.def("_get_work_registry_size", []() {
return ::c10d::get_work_registry_size();
});
// Remove a group from the native registry
module.def(
"_unregister_process_group",