Files
pytorch/torch/lib/c10d/ProcessGroup.cpp
Pieter Noordhuis 4ec6bd7356 Add sourceRank() to ProcessGroup::Work (#14453)
Summary:
This function is only implemented for the subclasses where it makes
sense. If it's not overridden it will throw an error. Having this
function removes the need for a pointer passing hack to pass the
source rank of a recv operation back to the caller. Instead, the
caller can now call `source_rank` on the work object and achieve
the same result.

Closes #11804.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14453

Differential Revision: D13230898

Pulled By: pietern

fbshipit-source-id: ef38f48bfaca8ef9a364e5be122951bafc9f8e49
2018-11-29 09:16:53 -08:00

53 lines
1.2 KiB
C++

#include "ProcessGroup.hpp"
namespace c10d {
// Out-of-line destructor definition for Work; it performs no work beyond
// destroying the members.
ProcessGroup::Work::~Work() = default;
bool ProcessGroup::Work::isCompleted() {
std::lock_guard<std::mutex> lock(mutex_);
return completed_;
}
bool ProcessGroup::Work::isSuccess() const {
std::lock_guard<std::mutex> lock(mutex_);
return !exception_;
}
std::exception_ptr ProcessGroup::Work::exception() const {
std::lock_guard<std::mutex> lock(mutex_);
return exception_;
}
// Base implementation of sourceRank(): it never returns a rank and always
// throws.
//
// Throws std::runtime_error explaining that the call is only meaningful for
// work objects that correspond to a recv or recv-from-any call.
int ProcessGroup::Work::sourceRank() const {
  const char* const message =
      "sourceRank() may only be called on work objects "
      "that correspond to a recv or recv-from-any call.";
  throw std::runtime_error(message);
}
// Base implementation of synchronize(): intentionally does nothing.
// NOTE(review): presumably a hook for subclasses to override; confirm
// against the declaration in ProcessGroup.hpp.
void ProcessGroup::Work::synchronize() {
  // Nothing to do here.
}
void ProcessGroup::Work::wait() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
cv_.wait(lock);
}
if (exception_) {
std::rethrow_exception(exception_);
}
synchronize();
}
void ProcessGroup::Work::finish(std::exception_ptr exception) {
std::lock_guard<std::mutex> lock(mutex_);
completed_ = true;
exception_ = exception;
cv_.notify_all();
}
ProcessGroup::ProcessGroup(int rank, int size) : rank_(rank), size_(size) {}
// Out-of-line destructor definition for ProcessGroup; it performs no work
// beyond destroying the members.
ProcessGroup::~ProcessGroup() = default;
} // namespace c10d