Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)"

This reverts commit 362ca54f03f9bb72ba7633ed580fb788b1a8dea9.

Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/wdvr because this change broke our prod training pipeline (verified with a bisect) by increasing memory consumption 4x and causing OOMs ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2435962833))
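For context on what is being reverted: the original PR made every collective in `ProcessGroup.hpp` register its output tensors against the returned `Work` handle (the `c10d::register_work` calls removed in the diff below), so that a later `wait_tensor()` call, including one inside a compiled region, could look up and complete the pending eager `async_op=True` collective. The following is a minimal sketch of that registry shape only, not the reverted implementation: the global map, keying by `TensorImpl*`, and erase-on-wait policy are assumptions. What it illustrates is that such a registry holds a strong reference to every registered `Work` until the tensor is waited on, which is one plausible way unwaited or late-waited collectives could inflate memory as described above.

```cpp
// Minimal sketch only -- not the reverted implementation. Assumes a global
// map keyed by TensorImpl*; the actual data structure, keying, and
// unregistration policy in the reverted PR may differ.
#include <mutex>
#include <unordered_map>

#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d {

static std::mutex gWorkRegistryMutex;
// Strong references: every registered Work stays alive (along with any
// buffers/events it holds) until wait_tensor() erases its entry.
static std::unordered_map<c10::TensorImpl*, c10::intrusive_ptr<Work>>
    gWorkRegistry;

void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<Work>& work) {
  std::lock_guard<std::mutex> lock(gWorkRegistryMutex);
  gWorkRegistry[tensor.unsafeGetTensorImpl()] = work;
}

at::Tensor wait_tensor(const at::Tensor& tensor) {
  c10::intrusive_ptr<Work> work;
  {
    std::lock_guard<std::mutex> lock(gWorkRegistryMutex);
    auto it = gWorkRegistry.find(tensor.unsafeGetTensorImpl());
    if (it != gWorkRegistry.end()) {
      work = it->second;
      gWorkRegistry.erase(it);
    }
  }
  if (work) {
    work->wait(); // block until the collective producing `tensor` completes
  }
  return tensor;
}

} // namespace c10d
```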
PyTorch MergeBot
2024-10-24 17:46:09 +00:00
parent 8197e4c70d
commit e7f1e306df
13 changed files with 125 additions and 417 deletions
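Every hunk in the diff below (from `torch/csrc/distributed/c10d/ProcessGroup.hpp`) applies the same mechanical change to one collective: drop the per-output-tensor registration and return the `Work` handle directly. Schematically, with the per-op argument lists elided:

```cpp
// Before the revert (the pattern removed by every hunk below): keep the
// Work handle and register it against each output tensor before returning.
auto work = std::get<1>(op.call(/* collective arguments */));
for (const auto& tensor : tensors) {
  c10d::register_work(tensor, work);
}
return work;

// After the revert: return the Work handle straight to the caller.
return std::get<1>(op.call(/* collective arguments */));
```

Single-buffer ops (`_allgather_base`, `_reduce_scatter_base`, `alltoall_base`, `barrier`) register their one output buffer instead of looping, and list-of-list ops (`allgather`, `allgather_coalesced`, `gather`) use a nested loop, but the shape of the change is identical throughout.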

torch/csrc/distributed/c10d/ProcessGroup.hpp

@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Work.hpp>
 #include <memory>
 #include <unordered_map>
 #include <utility>
@@ -24,16 +23,6 @@ constexpr auto kProcessGroupDefaultTimeout =
 
 namespace c10d {
 
-C10_EXPORT void register_work(
-    const at::Tensor& tensor,
-    const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
-
-C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT size_t get_work_registry_size();
-
 // ProcessGroup is a base class that captures collective and point to
 // point communication in a fixed set of processes.
 //
@@ -169,18 +158,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // It's awkward to unbox the opts here and box them again in the custom C++
     // op. But it's also complicated to make opts as a CustomClassHolder. Leave
     // it as it is now.
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce(
@@ -197,17 +181,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const std::optional<at::Tensor>& sparse_indices,
            int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.sparseIndices,
        opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -221,16 +200,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce(
@@ -245,18 +219,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            int64_t,
            int64_t,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allgather(
@@ -273,18 +242,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -305,15 +267,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            bool,
            int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -332,17 +291,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const at::TensorList&,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor_list : outputTensorLists) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // This function is a coalesced version of `allgather_into_tensor` (currently
@@ -360,15 +312,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const at::TensorList,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> gather(
@@ -383,19 +330,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> scatter(
@@ -413,18 +353,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            int64_t,
            bool,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -441,17 +376,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -468,16 +398,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            bool,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -497,17 +424,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ReduceOp>&,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -525,16 +447,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            std::vector<int64_t>,
            std::vector<int64_t>,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall(
@@ -550,16 +469,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const at::TensorList&,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual void monitoredBarrier(
@@ -635,15 +549,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recv(
@@ -657,15 +567,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -677,14 +583,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            at::TensorList,
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> barrier(
@@ -716,13 +618,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const std::vector<int64_t>&,
            int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
-    c10d::register_work(tensor, work);
-    return work;
   }
 
   bool hasBackends() {