mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160449 Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/dcci
211 lines
5.7 KiB
C++
211 lines
5.7 KiB
C++
#include <ATen/ThreadLocalState.h>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
|
|
|
|
#include <torch/csrc/distributed/c10d/Work.hpp>
|
|
#include <utility>
|
|
|
|
namespace c10d {
|
|
|
|
// Constructs a Work bound to `rank` and `opType`.
//
// When `profilingTitle` is non-null and the freshly created RecordFunction
// reports itself active, a profiling event named `profilingTitle` is opened
// here and a callback that ends it is stored in recordFunctionEndCallback_
// (invoked later by finish() / finishAndThrow()). `inputTensors`, when
// present, are forwarded to the event as its inputs.
Work::Work(
    int rank,
    OpType opType,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : rank_(rank), opType_(opType) {
  // Without a title there is nothing to record.
  if (profilingTitle == nullptr) {
    return;
  }

  auto recorder =
      std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
  if (!recorder->isActive()) {
    return;
  }

  // Work events follow a future-like pattern and may be marked complete from
  // a different thread, so the event is explicitly flagged as async.
  recorder->_setAsync();

  // Handing the input tensors to the RecordFunction makes shape information
  // available in the profiling output.
  std::vector<c10::IValue> eventInputs;
  if (inputTensors.has_value()) {
    eventInputs.reserve(inputTensors->size());
    for (const auto& t : *inputTensors) {
      eventInputs.emplace_back(t);
    }
  }
  recorder->before(
      profilingTitle,
      c10::ArrayRef<const c10::IValue>(
          eventInputs.data(), eventInputs.size()));

  // The callback shares ownership of the recorder, keeping it alive until the
  // event is ended.
  std::function<void()> endEvent = [recorder]() { recorder->end(); };
  recordFunctionEndCallback_ = at::wrapPropagateTLSState(endEvent);
}
|
|
|
|
// Returns the operation type this Work was constructed with.
OpType Work::retrieveOpType() const {
  const OpType type = opType_;
  return type;
}
|
|
|
|
// Out-of-line defaulted destructor: no work beyond destroying the members.
Work::~Work() = default;
|
|
|
|
bool Work::isCompleted() {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
return completed_;
|
|
}
|
|
|
|
bool Work::isSuccess() const {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
return !exception_;
|
|
}
|
|
|
|
std::exception_ptr Work::exception() const {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
return exception_;
|
|
}
|
|
|
|
// Base implementation: always fails via TORCH_CHECK. The message states that
// the query is only meaningful for recv / recv-from-any work objects.
int Work::sourceRank() const {
  TORCH_CHECK(
      false,
      "sourceRank() may only be called on work objects "
      "that correspond to a recv or recv-from-any call.");
}
|
|
|
|
// Base implementation: not supported, always fails via TORCH_CHECK.
std::vector<at::Tensor> Work::result() {
  TORCH_CHECK(
      false,
      "result() not implemented.");
}
|
|
|
|
void Work::synchronize() {
|
|
if (c10d::allow_inflight_collective_as_graph_input()) {
|
|
c10d::unregister_work(
|
|
c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
|
|
}
|
|
}
|
|
|
|
bool Work::wait(std::chrono::milliseconds timeout) {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (timeout == kNoTimeout) {
|
|
// This waits without a timeout.
|
|
cv_.wait(lock, [&] { return completed_; });
|
|
} else {
|
|
// Waits for the user-provided timeout.
|
|
cv_.wait_for(lock, timeout, [&] { return completed_; });
|
|
if (!completed_) {
|
|
// Throw exception if the wait operation timed out and the work was not
|
|
// completed.
|
|
TORCH_CHECK(false, "Operation timed out!");
|
|
}
|
|
}
|
|
if (exception_) {
|
|
std::rethrow_exception(exception_);
|
|
}
|
|
synchronize();
|
|
// Always return true, because abort API is not implemented.
|
|
return true;
|
|
}
|
|
|
|
void Work::blockCurrentStream() {
|
|
// block cuda stream indefinitely until work is completed.
|
|
std::shared_ptr<c10d::cuda::StreamBlock> handle =
|
|
c10d::cuda::block_stream(std::chrono::milliseconds(0));
|
|
|
|
getFuture()->addCallback(
|
|
[handle](c10::ivalue::Future& future) { handle->abort(); });
|
|
}
|
|
|
|
// Base implementation: aborting is not supported, always fails via
// TORCH_CHECK.
void Work::abort() {
  TORCH_CHECK(
      false,
      "Work::abort not implemented.");
}
|
|
|
|
// Base implementation: no future is available, always fails via TORCH_CHECK.
c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
  TORCH_CHECK(false, "Work::getFuture not implemented.");
}
|
|
|
|
// Base implementation: no result future is available, always fails via
// TORCH_CHECK.
c10::intrusive_ptr<c10::ivalue::Future> Work::getFutureResult() {
  TORCH_CHECK(false, "Work::getFutureResult not implemented.");
}
|
|
|
|
void Work::finish(std::exception_ptr exception) {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
completed_ = true;
|
|
exception_ = std::move(exception);
|
|
if (recordFunctionEndCallback_) {
|
|
recordFunctionEndCallback_();
|
|
recordFunctionEndCallback_ = nullptr;
|
|
}
|
|
lock.unlock();
|
|
cv_.notify_all();
|
|
}
|
|
|
|
void Work::finishAndThrow(std::exception_ptr exception) {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
completed_ = true;
|
|
exception_ = std::move(exception);
|
|
if (recordFunctionEndCallback_) {
|
|
recordFunctionEndCallback_();
|
|
recordFunctionEndCallback_ = nullptr;
|
|
}
|
|
if (exception_) {
|
|
std::rethrow_exception(exception_);
|
|
}
|
|
}
|
|
|
|
// Base implementation: duration queries are not supported, always fails via
// TORCH_CHECK.
float Work::getDuration() const {
  TORCH_CHECK(
      false,
      "This Backend doesn't support getDuration.");
}
|
|
|
|
// Base implementation: sequence numbers are not supported, always fails via
// TORCH_CHECK.
uint64_t Work::getSequencenumber() const {
  TORCH_CHECK(
      false,
      "This Backend doesn't support getSequencenumber.");
}
|
|
|
|
class FutureWrappingWork : public Work {
|
|
public:
|
|
FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
|
|
: _fut(std::move(fut)) {}
|
|
|
|
~FutureWrappingWork() override = default;
|
|
|
|
bool isCompleted() override {
|
|
return _fut->completed();
|
|
}
|
|
|
|
bool isSuccess() const override {
|
|
return _fut->hasValue();
|
|
}
|
|
|
|
std::exception_ptr exception() const override {
|
|
return _fut->exception_ptr();
|
|
}
|
|
|
|
int sourceRank() const override {
|
|
TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
|
|
}
|
|
|
|
std::vector<at::Tensor> result() override {
|
|
return _fut->value().toPyObjectHolder()->extractTensors();
|
|
}
|
|
|
|
bool wait(std::chrono::milliseconds timeout) override {
|
|
// FIXME
|
|
TORCH_CHECK(
|
|
timeout == kNoTimeout,
|
|
"FutureWrappingWork::wait() with finite timeout not implemented");
|
|
_fut->wait();
|
|
return true;
|
|
}
|
|
|
|
void abort() override {
|
|
TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
|
|
}
|
|
|
|
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
|
|
return _fut;
|
|
}
|
|
|
|
private:
|
|
c10::intrusive_ptr<c10::ivalue::Future> _fut;
|
|
};
|
|
|
|
// Wraps `future` in a FutureWrappingWork and returns it as a Work pointer.
c10::intrusive_ptr<Work> Work::create_from_future(
    const c10::intrusive_ptr<c10::ivalue::Future>& future) {
  c10::intrusive_ptr<Work> wrapped =
      c10::make_intrusive<FutureWrappingWork>(future);
  return wrapped;
}
|
|
|
|
} // namespace c10d
|