[C10d] Cleanup collective sequence number. (#109136)

Sequence numbers must be associated with a Work object
if we want to use it as a way to report collective progress.

The API surface change is introducing Work::getSequenceNumber, which
should eventually be exposed to python.

The bulk of this change is changing gloo to make the sequence number
be always in use and weave it to the dozens subclasses of Work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109136
Approved by: https://github.com/fduwjj
This commit is contained in:
Rodrigo Kumpera
2023-09-25 11:33:05 -07:00
committed by PyTorch MergeBot
parent 818f2297e6
commit 317e39a8ad
8 changed files with 157 additions and 71 deletions

View File

@ -15,7 +15,6 @@
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/sequence_num.hpp>
constexpr auto kBackendDefaultTimeout =
std::chrono::milliseconds(30 * 60 * 1000);
@ -368,8 +367,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// appropriate logging etc.
void init();
// Optional sequence number structure for matching collectives.
c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
const int rank_;
const int size_;
// Debug level setting. It is parsed once when ProcessGroup is constructed and

View File

@ -692,8 +692,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const int size_;
const c10::intrusive_ptr<Options> options_;
const BackendType backendType_;
// Optional sequence number structure for matching collectives.
c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.

View File

@ -513,6 +513,7 @@ inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo(
ProcessGroupGloo::AsyncWork::AsyncWork(
std::vector<std::vector<at::Tensor>> outputTensors,
OpType opType,
uint64_t seq,
const char* profilingTitle,
const c10::optional<std::vector<at::Tensor>>& inputTensors)
// Profiler: Pass nullptr as profilingTitle to parent constructor to
@ -520,12 +521,17 @@ ProcessGroupGloo::AsyncWork::AsyncWork(
// correct timestamps for work that is asynchronously executed.
: Work(-1, opType, nullptr, inputTensors),
outputTensors_(std::move(outputTensors)),
future_(createFutureAsOutput(outputTensors_)) {
future_(createFutureAsOutput(outputTensors_)),
seq_(seq) {
if (profilingTitle != nullptr) {
recordAsyncWorkProfilingInfo(profilingTitle, inputTensors);
}
}
uint64_t ProcessGroupGloo::AsyncWork::getSequencenumber() const {
return seq_;
}
void ProcessGroupGloo::AsyncWork::finishWorkGlooError(std::exception_ptr eptr) {
future_->setError(eptr);
finish(eptr);
@ -538,14 +544,20 @@ void ProcessGroupGloo::AsyncWork::finishWorkGloo() {
ProcessGroupGloo::SendWork::SendWork(
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer)
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
uint64_t seq)
: Work(
-1,
OpType::SEND,
"gloo:send",
c10::optional<std::vector<at::Tensor>>({tensor})),
tensor_(tensor),
buffer_(std::move(buffer)) {}
buffer_(std::move(buffer)),
seq_(seq) {}
uint64_t ProcessGroupGloo::SendWork::getSequencenumber() const {
return seq_;
}
bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
bool sendCompleted = false;
@ -573,6 +585,7 @@ ProcessGroupGloo::RecvWork::RecvWork(
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
OpType opType,
uint64_t seq,
const char* profilingTitle)
: Work(
-1,
@ -581,7 +594,12 @@ ProcessGroupGloo::RecvWork::RecvWork(
c10::optional<std::vector<at::Tensor>>({tensor})),
tensor_(tensor),
buffer_(std::move(buffer)),
srcRank_(-1) {}
srcRank_(-1),
seq_(seq) {}
uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
return seq_;
}
int ProcessGroupGloo::RecvWork::sourceRank() const {
std::lock_guard<std::mutex> lock(mutex_);
@ -838,10 +856,6 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
std::unique_lock<std::mutex> lock(workMutex_);
// Bump collective counter
if (sequenceNum_) {
sequenceNum_->increment();
}
workQueue_.push_back(std::move(work));
lock.unlock();
@ -859,10 +873,12 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
std::vector<at::Tensor>& inputs,
int rootRank,
int rootTensor,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{inputs},
OpType::BROADCAST,
seq,
"gloo:broadcast",
inputs),
context(context),
@ -906,8 +922,9 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork {
std::vector<at::Tensor>& inputs,
int rootRank,
int rootTensor,
uint32_t tag)
: AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) {
uint32_t tag,
uint64_t seq)
: AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag, seq) {
initializeStreamsEvents(inputs, streams, events);
// Create pinned host side tensors.
@ -980,12 +997,13 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast(
c10::intrusive_ptr<AsyncBroadcastWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncBroadcastWork>(
std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncBroadcastCUDAWork>(
std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_);
} else {
TORCH_CHECK(false, "Invalid backend");
}
@ -1002,10 +1020,12 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
ReduceOp reduceOp,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{inputs},
OpType::ALLREDUCE,
seq,
"gloo:all_reduce",
inputs),
context(context),
@ -1051,8 +1071,9 @@ class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
ReduceOp reduceOp,
uint32_t tag)
: AsyncAllreduceWork(context, inputs, reduceOp, tag) {}
uint32_t tag,
uint64_t seq)
: AsyncAllreduceWork(context, inputs, reduceOp, tag, seq) {}
void run() override {
allreduceCoalesced(inputs);
@ -1082,10 +1103,12 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
AsyncSparseAllreduceWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{inputs},
OpType::_ALLREDUCE_SPARSE,
seq,
"gloo:sparse_all_reduce",
inputs),
context(context),
@ -1355,8 +1378,9 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
ReduceOp reduceOp,
uint32_t tag)
: AsyncAllreduceWork(context, inputs, reduceOp, tag) {
uint32_t tag,
uint64_t seq)
: AsyncAllreduceWork(context, inputs, reduceOp, tag, seq) {
initializeStreamsEvents(inputs, streams, events);
// Kick off copy from CUDA tensors to pinned CPU tensors.
@ -1404,8 +1428,9 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork {
AsyncSparseAllreduceCUDAWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
uint32_t tag)
: AsyncSparseAllreduceWork(context, inputs, tag) {
uint32_t tag,
uint64_t seq)
: AsyncSparseAllreduceWork(context, inputs, tag, seq) {
initializeStreamsEvents(inputs, streams, events);
// Kick off copy from CUDA tensors to CPU tensors.
@ -1487,23 +1512,24 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce(
c10::intrusive_ptr<AsyncWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
if (layout == c10::kStrided) {
work = c10::make_intrusive<AsyncAllreduceWork>(
std::move(context), inputs, opts.reduceOp, tag);
std::move(context), inputs, opts.reduceOp, tag, seq_);
} else if (layout == c10::kSparse) {
work = c10::make_intrusive<AsyncSparseAllreduceWork>(
std::move(context), inputs, tag);
std::move(context), inputs, tag, seq_);
} else {
invalidArgument("unsupported layout");
}
} else if (device.type() == at::kCUDA) {
if (layout == c10::kStrided) {
work = c10::make_intrusive<AsyncAllreduceCUDAWork>(
std::move(context), inputs, opts.reduceOp, tag);
std::move(context), inputs, opts.reduceOp, tag, seq_);
} else if (layout == c10::kSparse) {
work = c10::make_intrusive<AsyncSparseAllreduceCUDAWork>(
std::move(context), inputs, tag);
std::move(context), inputs, tag, seq_);
} else {
invalidArgument("unsupported layout");
}
@ -1560,10 +1586,11 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce_coalesced(
c10::intrusive_ptr<AsyncWork> work;
const uint32_t tag = nextTag();
std::shared_ptr<gloo::Context> context = getContext(tag);
++seq_;
if (device.type() == c10::kCPU) {
if (layout == c10::kStrided) {
work = c10::make_intrusive<AsyncAllreduceCoalescedWork>(
std::move(context), tensors, opts.reduceOp, tag);
std::move(context), tensors, opts.reduceOp, tag, seq_);
} else {
invalidArgument("unsupported layout");
}
@ -1584,10 +1611,12 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork {
int rootRank,
int rootTensor,
ReduceOp reduceOp,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{inputs},
OpType::REDUCE,
seq,
"gloo:reduce",
inputs),
context(context),
@ -1641,8 +1670,16 @@ class AsyncReduceCUDAWork : public AsyncReduceWork {
int rootRank,
int rootTensor,
ReduceOp reduceOp,
uint32_t tag)
: AsyncReduceWork(context, inputs, rootRank, rootTensor, reduceOp, tag) {
uint32_t tag,
uint64_t seq)
: AsyncReduceWork(
context,
inputs,
rootRank,
rootTensor,
reduceOp,
tag,
seq) {
initializeStreamsEvents(inputs, streams, events);
// Kick off copy from CUDA tensors to pinned CPU tensors.
@ -1715,6 +1752,7 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
c10::intrusive_ptr<AsyncReduceWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncReduceWork>(
std::move(context),
@ -1722,7 +1760,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
opts.rootRank,
opts.rootTensor,
opts.reduceOp,
tag);
tag,
seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncReduceCUDAWork>(
std::move(context),
@ -1730,7 +1769,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
opts.rootRank,
opts.rootTensor,
opts.reduceOp,
tag);
tag,
seq_);
} else {
TORCH_CHECK(false, "Invalid backend");
}
@ -1746,10 +1786,12 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
outputs,
OpType::ALLGATHER,
seq,
"gloo:all_gather",
inputs),
context(context),
@ -1801,8 +1843,9 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs,
uint32_t tag)
: AsyncAllgatherWork(context, outputs, inputs, tag) {
uint32_t tag,
uint64_t seq)
: AsyncAllgatherWork(context, outputs, inputs, tag, seq) {
initializeStreamsEvents(inputs, inputStreams, inputEvents);
initializeStreamsEvents(outputs, outputStreams, outputEvents);
@ -1922,12 +1965,13 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allgather(
c10::intrusive_ptr<AsyncAllgatherWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncAllgatherWork>(
std::move(context), outputs, inputs, tag);
std::move(context), outputs, inputs, tag, seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncAllgatherCUDAWork>(
std::move(context), outputs, inputs, tag);
std::move(context), outputs, inputs, tag, seq_);
} else {
TORCH_CHECK(false, "Invalid backend");
}
@ -1943,10 +1987,12 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
const std::shared_ptr<gloo::Context>& context,
std::vector<std::vector<at::Tensor>>& output_lists,
std::vector<at::Tensor>& input_list,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
output_lists,
OpType::ALLGATHER_COALESCED,
seq,
"gloo:all_gather",
input_list),
context(context),
@ -2053,8 +2099,9 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced(
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
auto work = c10::make_intrusive<AsyncAllgatherCoalescedWork>(
std::move(context), output_lists, input_list, tag);
std::move(context), output_lists, input_list, tag, seq_);
enqueue(work);
return work;
}
@ -2075,10 +2122,12 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs,
int root,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
outputs,
OpType::GATHER,
seq,
"gloo:gather",
inputs),
context(context),
@ -2138,8 +2187,9 @@ class AsyncGatherCUDAWork : public AsyncGatherWork {
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs,
int root,
uint32_t tag)
: AsyncGatherWork(context, outputs, inputs, root, tag) {
uint32_t tag,
uint64_t seq)
: AsyncGatherWork(context, outputs, inputs, root, tag, seq) {
initializeStreamsEvents(inputs, inputStreams, inputEvents);
initializeStreamsEvents(outputs, outputStreams, outputEvents);
@ -2254,12 +2304,13 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::gather(
c10::intrusive_ptr<AsyncGatherWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncGatherWork>(
std::move(context), outputs, inputs, opts.rootRank, tag);
std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncGatherCUDAWork>(
std::move(context), outputs, inputs, opts.rootRank, tag);
std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
} else {
TORCH_CHECK(false, "Invalid backend");
}
@ -2276,10 +2327,12 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
std::vector<at::Tensor>& outputs,
std::vector<std::vector<at::Tensor>>& inputs,
int root,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{outputs},
OpType::SCATTER,
seq,
"gloo:scatter",
!inputs.empty() ? c10::optional<std::vector<at::Tensor>>(inputs[0])
: c10::nullopt),
@ -2325,8 +2378,9 @@ class AsyncScatterCUDAWork : public AsyncScatterWork {
std::vector<at::Tensor>& outputs,
std::vector<std::vector<at::Tensor>>& inputs,
int root,
uint32_t tag)
: AsyncScatterWork(context, outputs, inputs, root, tag) {
uint32_t tag,
uint64_t seq)
: AsyncScatterWork(context, outputs, inputs, root, tag, seq) {
initializeStreamsEvents(inputs, inputStreams, inputEvents);
initializeStreamsEvents(outputs, outputStreams, outputEvents);
@ -2438,12 +2492,13 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::scatter(
c10::intrusive_ptr<AsyncScatterWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncScatterWork>(
std::move(context), outputs, inputs, opts.rootRank, tag);
std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncScatterCUDAWork>(
std::move(context), outputs, inputs, opts.rootRank, tag);
std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
} else {
TORCH_CHECK(false, "Invalid backend");
}
@ -2468,10 +2523,12 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork {
at::Tensor& inputTensor,
std::vector<int64_t>& outputCounts,
std::vector<int64_t>& inputCounts,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{{outputTensor}},
OpType::ALLTOALL,
seq,
"gloo:all_to_all",
c10::optional<std::vector<at::Tensor>>({inputTensor})),
context(context),
@ -2530,14 +2587,16 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork {
at::Tensor& inputTensor,
std::vector<int64_t>& outputCounts,
std::vector<int64_t>& inputCounts,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: AsyncAlltoallWork(
context,
outputTensor,
inputTensor,
outputCounts,
inputCounts,
tag) {
tag,
seq) {
initializeStreamsEvents({inputTensor}, inputStreams, inputEvents);
initializeStreamsEvents({outputTensor}, outputStreams, outputEvents);
@ -2603,6 +2662,7 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
c10::intrusive_ptr<AsyncAlltoallWork> work;
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
if (device.type() == at::kCPU) {
work = c10::make_intrusive<AsyncAlltoallWork>(
@ -2611,7 +2671,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
inputTensor,
outputCounts,
inputCounts,
tag);
tag,
seq_);
} else if (device.type() == at::kCUDA) {
work = c10::make_intrusive<AsyncAlltoallCUDAWork>(
std::move(context),
@ -2619,7 +2680,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
inputTensor,
outputCounts,
inputCounts,
tag);
tag,
seq_);
} else {
invalidArgument(c10::str("unsupported device type ", device.type()));
}
@ -2659,10 +2721,11 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::send(
auto context = getContext(tag);
auto buf = context->createUnboundBuffer(const_cast<void*>(ptr), size);
buf->send(dstRank, utag);
++seq_;
// The work captures the tensor to prevent it being deallocated and
// the unbound buffer to synchronize on completion of the send.
return c10::make_intrusive<SendWork>(tensor, std::move(buf));
return c10::make_intrusive<SendWork>(tensor, std::move(buf), seq_);
}
c10::intrusive_ptr<Work> ProcessGroupGloo::recv(
@ -2678,11 +2741,12 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::recv(
auto context = getContext(tag);
auto buf = context->createUnboundBuffer(ptr, size);
buf->recv(srcRank, utag);
++seq_;
// The work captures the tensor to prevent it being deallocated and
// the unbound buffer to synchronize on completion of the recv.
return c10::make_intrusive<RecvWork>(
tensor, std::move(buf), OpType::RECV, "gloo:recv");
tensor, std::move(buf), OpType::RECV, seq_, "gloo:recv");
}
c10::intrusive_ptr<Work> ProcessGroupGloo::recvAnysource(
@ -2707,11 +2771,16 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::recvAnysource(
}
buf->recv(srcRanks, utag);
++seq_;
// The work captures the tensor to prevent it being deallocated and
// the unbound buffer to synchronize on completion of the recv.
return c10::make_intrusive<RecvWork>(
tensor, std::move(buf), OpType::RECVANYSOURCE, "gloo:recvAnySource");
tensor,
std::move(buf),
OpType::RECVANYSOURCE,
seq_,
"gloo:recvAnySource");
}
namespace {
@ -2721,10 +2790,12 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork {
AsyncBarrierWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork,
uint32_t tag)
uint32_t tag,
uint64_t seq)
: ProcessGroupGloo::AsyncWork(
{},
OpType::BARRIER,
seq,
"gloo:barrier",
c10::nullopt),
context(context),
@ -2767,8 +2838,9 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::barrier(const BarrierOptions& opts) {
auto tag = nextTag();
auto context = getContext(tag);
++seq_;
auto work = c10::make_intrusive<AsyncBarrierWork>(
std::move(context), std::move(priorWork), tag);
std::move(context), std::move(priorWork), tag, seq_);
enqueue(work);
return work;
}
@ -2891,14 +2963,10 @@ void ProcessGroupGloo::monitoredBarrier(
}
void ProcessGroupGloo::setSequenceNumberForGroup() {
sequenceNum_ = c10d::SequenceNum(0);
}
} // Gloo just starts sequence numbers at 0.
uint64_t ProcessGroupGloo::getSequenceNumberForGroup() {
if (sequenceNum_ == c10::nullopt) {
return 0;
}
return sequenceNum_->get();
return seq_;
}
void ProcessGroupGloo::enableCollectivesTiming() {

View File

@ -72,6 +72,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
explicit AsyncWork(
std::vector<std::vector<at::Tensor>> outputTensors,
OpType opType,
uint64_t seq,
const char* profilingTitle = nullptr,
const c10::optional<std::vector<at::Tensor>>& inputTensors =
c10::nullopt);
@ -85,6 +86,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
std::vector<at::Tensor> result() override;
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
uint64_t getSequencenumber() const override;
protected:
friend class ProcessGroupGloo;
@ -99,6 +101,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
const std::vector<std::vector<at::Tensor>> outputTensors_;
c10::intrusive_ptr<at::ivalue::Future> future_;
std::function<void()> recordFunctionBeforeCallback_;
const uint64_t seq_;
};
// Wrap c10d store as Gloo store
@ -184,15 +187,19 @@ class TORCH_API ProcessGroupGloo : public Backend {
public:
explicit SendWork(
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer);
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
uint64_t seq);
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
void abort() override;
uint64_t getSequencenumber() const override;
protected:
at::Tensor tensor_;
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
const uint64_t seq_;
};
class TORCH_API RecvWork : public Work {
@ -201,6 +208,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
OpType opType,
uint64_t seq,
const char* profilingTitle = nullptr);
int sourceRank() const override;
@ -209,10 +217,13 @@ class TORCH_API ProcessGroupGloo : public Backend {
void abort() override;
uint64_t getSequencenumber() const override;
protected:
at::Tensor tensor_;
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
int srcRank_;
const uint64_t seq_;
};
struct TORCH_API Options : public Backend::Options {
@ -410,6 +421,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
std::mutex workMutex_;
std::condition_variable workProduceCV_;
std::condition_variable workConsumeCV_;
uint64_t seq_{0};
};
} // namespace c10d

View File

@ -1586,6 +1586,9 @@ float ProcessGroupNCCL::WorkNCCL::getDuration() const {
"getDuration can only be called after work is succeeded.")
return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}
uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
return seq_;
}
void ProcessGroupNCCL::workEnqueue(
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {

View File

@ -166,6 +166,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
float getDuration() const override;
uint64_t getSequencenumber() const override;
// Helper function that sets an exception_ptr on the WorkNCCL object.
void setException(std::exception_ptr exception_ptr);

View File

@ -128,7 +128,11 @@ void Work::finishAndThrow(std::exception_ptr exception) {
}
float Work::getDuration() const {
TORCH_CHECK(false, "Only ProcessGrouppNCCL::WorkNCCL supports getDuration.");
TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}
uint64_t Work::getSequencenumber() const {
TORCH_CHECK(false, "This Backend doesn't support getSequencenumber.");
}
class FutureWrappingWork : public Work {

View File

@ -109,6 +109,8 @@ class TORCH_API Work : public torch::CustomClassHolder {
virtual float getDuration() const;
virtual uint64_t getSequencenumber() const;
OpType retrieveOpType() const;
static c10::intrusive_ptr<Work> create_from_future(