Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)"
This reverts commit 362ca54f03f9bb72ba7633ed580fb788b1a8dea9. Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/wdvr due to this change is breaking our prod training pipeline (verified with bisect) by increasing memory consumption 4x and causing OOM ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2435962833))
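The reverted PR had added a process-wide tensor-to-Work registry to c10d (the `register_work` / `wait_tensor` / `unregister_work` / `get_work_registry_size` declarations removed in the diff below, which targets the ProcessGroup header judging by the class and namespace): every collective wrapper recorded its output tensors against the returned Work handle so that a later `wait_tensor()`, including one issued from inside a compiled region, could find and wait on the pending collective. The following is a minimal sketch of such a registry, assuming a design the removed declarations merely imply; the storage-pointer key, the mutex, and the strong `shared_ptr<Work>` ownership are all illustrative assumptions, not the actual c10d implementation. It also shows the lifetime hazard consistent with the revert reason: each entry pins its Work alive until someone waits or unregisters.

```cpp
// Hypothetical tensor -> Work registry sketch (not the real c10d code).
#include <algorithm>
#include <cstddef>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

namespace sketch {

struct Work {
  virtual ~Work() = default;
  virtual void wait() {}  // stand-in for c10d::Work::wait()
};

class WorkRegistry {
 public:
  // register_work analogue: remember that `work` must complete before the
  // tensor backed by `storage` may be read.
  void register_work(const void* storage, std::shared_ptr<Work> work) {
    std::lock_guard<std::mutex> guard(mu_);
    pending_[storage].push_back(std::move(work));
  }

  // wait_tensor analogue: wait on (and drop) everything recorded for this
  // storage. Waiting happens outside the lock so other threads can register.
  void wait_tensor(const void* storage) {
    std::vector<std::shared_ptr<Work>> todo;
    {
      std::lock_guard<std::mutex> guard(mu_);
      auto it = pending_.find(storage);
      if (it == pending_.end()) {
        return;  // nothing pending: waiting is a no-op
      }
      todo = std::move(it->second);
      pending_.erase(it);
    }
    for (auto& w : todo) {
      w->wait();
    }
  }

  // unregister_work analogue: drop a handle everywhere without waiting.
  void unregister_work(const std::shared_ptr<Work>& work) {
    std::lock_guard<std::mutex> guard(mu_);
    for (auto& entry : pending_) {
      auto& v = entry.second;
      v.erase(std::remove(v.begin(), v.end(), work), v.end());
    }
  }

  // get_work_registry_size analogue: storages with unconsumed work. If every
  // collective registers but callers rarely wait, this grows without bound,
  // and each entry keeps its Work (and whatever that Work retains) alive.
  std::size_t size() {
    std::lock_guard<std::mutex> guard(mu_);
    return pending_.size();
  }

 private:
  std::mutex mu_;
  std::unordered_map<const void*, std::vector<std::shared_ptr<Work>>> pending_;
};

}  // namespace sketch
```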
```diff
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Work.hpp>
 #include <memory>
 #include <unordered_map>
 #include <utility>
@@ -24,16 +23,6 @@ constexpr auto kProcessGroupDefaultTimeout =
 
 namespace c10d {
 
-C10_EXPORT void register_work(
-    const at::Tensor& tensor,
-    const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);
-
-C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);
-
-C10_EXPORT size_t get_work_registry_size();
-
 // ProcessGroup is a base class that captures collective and point to
 // point communication in a fixed set of processes.
 //
@@ -169,18 +158,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // It's awakward to unbox the opts here and box them again in the custom C++
     // op. But it's also complicated to make opts as a CustomClassHolder. Leave
     // it as it is now.
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.rootRank,
         opts.rootTensor,
         opts.asyncOp,
         opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce(
@@ -197,17 +181,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const std::optional<at::Tensor>& sparse_indices,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.sparseIndices,
         opts.timeout.count()));
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allreduce_coalesced(
@@ -221,16 +200,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             int64_t)>();
 
-    auto work = op.call(
+    return op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce(
@@ -245,18 +219,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             int64_t,
             int64_t,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.rootRank,
         opts.rootTensor,
         opts.timeout.count());
-
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> allgather(
@@ -273,18 +242,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.timeout.count()));
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
@@ -305,15 +267,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             bool,
             int64_t)>();
 
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputBuffer,
         inputBuffer,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.asyncOp,
         opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is deprecated and will be moved out of ProcessGroup to comms:
@@ -332,17 +291,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const at::TensorList&,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
         outputTensorLists,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor_list : outputTensorLists) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   // This function is a coalesced version of `allgather_into_tensor` (currently
@@ -360,15 +312,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const at::TensorList,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
 
-    auto work = op.call(
+    return op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> gather(
@@ -383,19 +330,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
            const c10::intrusive_ptr<::c10d::ProcessGroup>&,
            int64_t,
            int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
-
-    for (const auto& tensor_list : outputTensors) {
-      for (const auto& tensor : tensor_list) {
-        c10d::register_work(tensor, work);
-      }
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> scatter(
@@ -413,18 +353,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             int64_t,
             bool,
             int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.rootRank,
         opts.asyncOp,
         opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> reduce_scatter(
@@ -441,17 +376,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
         opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
@@ -468,16 +398,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             bool,
             int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   // This function is a coalesced version of `reduce_scatter_tensor` (currently
@@ -497,17 +424,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall_base(
@@ -525,16 +447,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             std::vector<int64_t>,
             std::vector<int64_t>,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
-
-    c10d::register_work(outputBuffer, work);
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> alltoall(
@@ -550,16 +469,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const at::TensorList&,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t)>();
-    auto work = std::get<1>(op.call(
+    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
-
-    for (const auto& tensor : outputTensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual void monitoredBarrier(
@@ -635,15 +549,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recv(
@@ -657,15 +567,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> recvAnysource(
@@ -677,14 +583,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             at::TensorList,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             int64_t)>();
-    auto work = op.call(
+    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
-    for (const auto& tensor : tensors) {
-      c10d::register_work(tensor, work);
-    }
-    return work;
   }
 
   virtual c10::intrusive_ptr<Work> barrier(
@@ -716,13 +618,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const std::vector<int64_t>&,
             int64_t)>();
 
-    auto work = op.call(
+    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
-    c10d::register_work(tensor, work);
-    return work;
   }
 
   bool hasBackends() {
```
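Every hunk above repeats the same local transformation. As a standalone illustration of just that difference, the sketch below uses hypothetical stand-in types (`Tensor`, `Work`, and `call_collective` are placeholders, not the real c10d API) so it compiles on its own:

```cpp
#include <memory>
#include <vector>

struct Tensor {};
struct Work {};

// Placeholder for the dispatcher call each wrapper makes.
std::shared_ptr<Work> call_collective(const std::vector<Tensor>& /*outs*/) {
  return std::make_shared<Work>();
}

void register_work(const Tensor& /*t*/, const std::shared_ptr<Work>& /*w*/) {}

// Shape the revert removes: the Work must be named so every output tensor
// can be registered before the handle is handed back.
std::shared_ptr<Work> collective_with_registry(const std::vector<Tensor>& outs) {
  auto work = call_collective(outs);
  for (const auto& t : outs) {
    register_work(t, work);
  }
  return work;
}

// Shape the revert restores: return the dispatcher's result directly and
// leave all waiting to the caller's Work handle.
std::shared_ptr<Work> collective_plain(const std::vector<Tensor>& outs) {
  return call_collective(outs);
}
```

The registered shape couples each output tensor to its pending Work at creation time, which is what let a later `wait_tensor()` locate the collective; the plain shape keeps the wrapper free of bookkeeping and leaves the tensor-to-Work association entirely to the caller.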