#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <ATen/record_function.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace {
// If the type is a subtype of a named type, return its qualified name;
// otherwise return its type string.
std::string getTypeStr(const c10::TypePtr& type) {
  switch (type->kind()) {
    case c10::TypeKind::FunctionType:
      return type->castRaw<c10::FunctionType>()->name()->qualifiedName();
    case c10::TypeKind::TupleType:
      return type->castRaw<c10::TupleType>()->name()->qualifiedName();
    case c10::TypeKind::ClassType:
      return type->castRaw<c10::ClassType>()->name()->qualifiedName();
    case c10::TypeKind::InterfaceType:
      return type->castRaw<c10::InterfaceType>()->name()->qualifiedName();
    default:
      return type->annotation_str();
  }
}

} // namespace

namespace torch {
namespace distributed {
namespace rpc {

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

////////////////////////// RRefForkData /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent,
    std::string typeStr)
    : ownerId_(ownerId),
      rrefId_(rrefId),
      forkId_(forkId),
      parent_(parent),
      typeStr_(std::move(typeStr)) {}

////////////////////////////// RRef /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
    : RRefInterface(),
      ownerId_(ownerId),
      rrefId_(rrefId),
      type_(std::move(type)) {}
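
// Creates the fork metadata for this RRef: the owner id, the RRef id, a
// freshly generated globally unique ForkId, the id of the calling worker
// (which becomes the parent of the new fork), and a string describing the
// value type.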
RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_,
      rrefId_,
      ctx.genGloballyUniqueId(),
      ctx.getWorkerId(),
      getTypeStr(type_));
}
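
// Dispatches on the error type: TIMEOUT and INTENTIONAL_FAILURE (the latter
// is presumably injected by test-only faulty agents) both mark this RRef as
// timed out, while any other error is rethrown via
// RRefContext::handleException.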
void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) {
  // Note: this map must not be `static`: the lambdas below capture `this`,
  // so a static map would permanently bind the handlers to whichever RRef
  // instance happened to populate it first.
  std::unordered_map<
      RPCErrorType,
      std::function<void(const JitFuture& jitFuture)>,
      std::hash<int>>
      errorHandlers = {
          {RPCErrorType::TIMEOUT,
           [this](const JitFuture& /* unused */) { setTimedOut(); }},
          {RPCErrorType::INTENTIONAL_FAILURE,
           [this](const JitFuture& /* unused */) { setTimedOut(); }},
          {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) {
             // Default error handler
             RRefContext::handleException(jitFuture);
           }}};
  errorHandlers.find(errorType)->second(jitFuture);
}

////////////////////////// UserRRef /////////////////////////////////////

UserRRef::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    TypePtr type)
    : RRef(ownerId, rrefId, std::move(type)),
      forkId_(forkId),
      confirmedByOwner_(false) {
  // Do nothing:
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
  //     a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or PythonRemoteCall
  //     will properly notify the owner.
}
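
// Asks the owner (via RRefContext::delUser) to delete this fork. Guarded by
// deletedOnOwnerMutex_ and the deletedOnOwner_ flag so the delete message is
// sent at most once; errors are logged rather than rethrown, since this runs
// on cleanup paths such as release_resources().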
void UserRRef::tryDel() {
  std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_);
  if (!deletedOnOwner_) {
    try {
      RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
      deletedOnOwner_ = true;
    } catch (const std::exception& ex) {
      LOG(ERROR) << "Error occurred when deleting " << *this << " : "
                 << ex.what();
    } catch (...) {
      LOG(ERROR) << "Error occurred when deleting " << *this << " : "
                 << "unknown error";
    }
  }
}

void UserRRef::release_resources() {
  tryDel();
}

const ForkId& UserRRef::forkId() const {
  return forkId_;
}
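
// Blocking fetch of the RRef's value from the owner. From Python this backs
// RRef.to_here(); a hypothetical usage sketch (the worker name and arguments
// below are assumed for illustration, not taken from this file):
//
//   rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
//   value = rref.to_here()  # blocks until the owner returns the value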
IValue UserRRef::toHere(const float timeoutSeconds) const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  // see Note [Best-Effort Check on Deleted UserRRefs]
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot call to_here() on it after deletion.");
  auto toHereKey = std::string("");
  if (torch::autograd::profiler::profilerEnabled()) {
    toHereKey = fmt::format(
        "to_here#({})->({})",
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_);
  }
  RECORD_USER_SCOPE(toHereKey);
  TORCH_CHECK(
      !type_->is_module(),
      *this,
      " is an RRef to a ScriptModule. "
      "It can't be sent through RPC "
      "from owner, ",
      ownerWorkerInfo(),
      ", to user, ",
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo(),
      ".");

  auto agent = RpcAgent::getCurrentRpcAgent();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response
  // could potentially contain tensors.
  c10::intrusive_ptr<Message> msgToSend;

  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  // toHere is profiled as a blocking call, and does not execute operations on
  // the remote node. Hence, don't wrap it with a profiling message since we
  // don't need the profiler to be enabled remotely.
  auto jitFuture = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */,
      timeoutSeconds,
      true /* forceDisableProfiling */);

  // TODO: we should ideally be able to interrupt this blocking wait if we
  // check getTimedOut() and it is true
  // (https://github.com/pytorch/pytorch/issues/39411).
  jitFuture->waitAndThrow();
  auto messagePtr = jitFuture->constValue().toCustomClass<Message>();
  MessageType msgType = messagePtr->type();
  auto response = deserializeResponse(*messagePtr, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  RpcCommandBase& rpc = *response;
  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
  if (isPyObj()) {
    // Wrap the Python-serialized vector of IValues into a tuple so that the
    // C++ toHere interface returns a single IValue.
    return ivalue::Tuple::create(rrefFetchRet.values());
  } else {
    return rrefFetchRet.values().front();
  }
}

RRefForkData UserRRef::fork() const {
  // Note [Best-Effort Check on Deleted UserRRefs]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // This check does not guarantee correctness, as there could be another
  // thread trying to delete this UserRRef concurrently. Passing this check
  // does not mean this RRef will be alive throughout this function. This is
  // just our best-effort attempt to raise proper error messages. The behavior
  // of using deleted UserRRefs is undefined.
  //
  // The reasons for not implementing strict checks are:
  // 1. This would need to acquire the lock on deletedOnOwnerMutex_, which
  //    would introduce unnecessary overhead for most normal use cases.
  // 2. This would introduce a lot of complexities to get the behavior
  //    correct. Assume we acquired the lock here, and there is another
  //    thread X blocked waiting in tryDel() on the lock. Exiting this fork
  //    function would unblock thread X. However, while X proceeds with
  //    deleting this UserRRef, the call site of fork() might have added the
  //    UserRRef to the pendingChildren_ map, but up to this point, nothing
  //    prevents X from deleting this RRef even if it shouldn't do so due to
  //    the state change in pendingChildren_. We might be able to get it
  //    right for now by locking and checking pendingChildren_ in X, but the
  //    gain does not seem worth the complexity.
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot fork a UserRRef after deletion.");
  return RRef::fork();
}

////////////////////////// OwnerRRef /////////////////////////////////////

OwnerRRef::OwnerRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    TypePtr type,
    std::vector<c10::Device> devices)
    : OwnerRRef(ownerId, rrefId, type, /* value */ {}, std::move(devices)) {}
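
// The value is stored in a JitFuture: an OwnerRRef created without a value
// holds an incomplete future, and setValue()/setError() below complete it;
// consumers block on the future in getValue() or chain on getFuture().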
OwnerRRef::OwnerRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    TypePtr type,
    c10::optional<IValue> value,
    std::vector<c10::Device> devices)
    : RRef(ownerId, rrefId, type) {
  future_ = c10::make_intrusive<JitFuture>(type_, std::move(devices));

  if (value.has_value()) {
    future_->markCompleted(value.value());
  }
}
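
// Blocks until the value (or an error) is set by the owner, then returns a
// reference to it. Like toHere(), this first fails fast if RRef creation
// already timed out.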
const IValue& OwnerRRef::getValue() const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  future_->waitAndThrow();
  return future_->constValue();
}

bool OwnerRRef::hasValue() const {
  return future_->completed();
}
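
// Returns the underlying future so callers can react asynchronously instead
// of blocking in getValue(). A hypothetical sketch (callback body assumed,
// not taken from this file):
//
//   ownerRRef->getFuture()->addCallback(
//       [](JitFuture& fut) { /* use fut.constValue() */ });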
c10::intrusive_ptr<JitFuture> OwnerRRef::getFuture() {
  return future_;
}

void OwnerRRef::setValue(IValue&& value) {
  // Move rather than copy: the parameter is an rvalue reference.
  future_->markCompleted(std::move(value));
}

void OwnerRRef::setError(std::exception_ptr eptr) {
  future_->setErrorIfNeeded(std::move(eptr));
}
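
// Debug printer. Output looks roughly like the line below (the exact id
// format depends on GloballyUniqueId's operator<<):
//
//   OwnerRRef(rref_id=GloballyUniqueId(created_on=0, local_id=3))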
std::ostream& operator<<(std::ostream& os, const RRef& rref) {
  if (rref.isOwner()) {
    return os << "OwnerRRef("
              << "rref_id=" << rref.rrefId() << ")";
  } else {
    return os << "UserRRef("
              << "rref_id=" << rref.rrefId()
              << ", fork_id=" << static_cast<const UserRRef*>(&rref)->forkId()
              << ")";
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch