mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32959 In the RPC TorchScript call path we need to pickle/unpickle RRefs; this diff makes the JIT pickler/unpickler able to pickle/unpickle an RRef. It is similar to what is implemented for PyRRef::pickle() and PyRRef::unpickle(). The pickling/unpickling design assumes it is always coupled with RPC calls. It is not needed for checkpointing a model that holds an RRef: before checkpointing the model, the user should call rref.to_here() to get the value inside the RRef. The pickling process is: 1. push the torch.distributed.rpc.rref global string; 2. call rref.fork() to create the rrefForkData, which is a few IDs plus the type str of the value held inside the RRef (the IDs include the rref id, fork id, caller worker id, callee worker id, and owner worker id); 3. push the rrefForkData. The unpickling process is: 1. read the torch.distributed.rpc.rref global string and retrieve the cached global lambda function; 2. the global lambda function gets the rrefForkData; 3. if the callee is also the owner worker, get the owner RRef based on the IDs inside the rrefForkData and return the OwnerRRef; 4. if the callee is not the owner worker, create a user RRef from the rrefForkData and return the UserRRef; 5. meanwhile, the owner RRef is notified and does the reference counting correctly. During unpickling, a type_resolver is needed to parse the type str. This type_resolver has a Python dependency, so we get it from rpc_agent and pass it to the unpickler during construction; hence this diff adds a type_resolver argument to the JIT unpickler constructor. ghstack-source-id: 98814793 Test Plan: unit test Differential Revision: D19713293 fbshipit-source-id: 4fd776cdd4ce8f457c4034d79acdfb4cd095c52e
180 lines
5.4 KiB
C++
180 lines
5.4 KiB
C++
#include <torch/csrc/distributed/rpc/rref_impl.h>
|
|
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
|
|
#include <torch/csrc/distributed/autograd/utils.h>
|
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
#include <torch/csrc/distributed/rpc/rref_proto.h>
|
|
#include <torch/csrc/distributed/rpc/utils.h>
|
|
|
|
namespace {
|
|
// If the type is subtype of named type, return its qualifiedname, otherwise
|
|
// return its type str.
|
|
std::string getTypeStr(const c10::TypePtr& type) {
|
|
switch (type->kind()) {
|
|
case c10::TypeKind::FunctionType:
|
|
return type->cast<c10::FunctionType>()->name()->qualifiedName();
|
|
case c10::TypeKind::TupleType:
|
|
return type->cast<c10::TupleType>()->name()->qualifiedName();
|
|
case c10::TypeKind::ClassType:
|
|
return type->cast<c10::ClassType>()->name()->qualifiedName();
|
|
case c10::TypeKind::InterfaceType:
|
|
return type->cast<c10::InterfaceType>()->name()->qualifiedName();
|
|
default:
|
|
return type->str();
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
// Out-of-class definition of RRefContext's static atomic counter; it starts
// at 0.
// NOTE(review): presumably this feeds the local part of the ids produced by
// genGloballyUniqueId() -- confirm in rref_context.
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
|
|
|
|
////////////////////////// RRefForkData /////////////////////////////////
|
|
|
|
RRefForkData::RRefForkData(
|
|
worker_id_t ownerId,
|
|
const RRefId& rrefId,
|
|
const ForkId& forkId,
|
|
worker_id_t parent,
|
|
std::string typeStr)
|
|
: ownerId_(ownerId),
|
|
rrefId_(rrefId),
|
|
forkId_(forkId),
|
|
parent_(parent),
|
|
typeStr_(std::move(typeStr)) {}
|
|
|
|
////////////////////////////// RRef /////////////////////////////////////
|
|
|
|
RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
|
|
: RRefInterface(),
|
|
ownerId_(ownerId),
|
|
rrefId_(rrefId),
|
|
type_(std::move(type)) {}
|
|
|
|
// Builds the RRefForkData describing a new fork of this RRef: it keeps this
// RRef's owner id and RRef id, takes a freshly generated globally unique id
// as the fork id, uses the current worker id as the parent, and carries the
// type string of the referenced value.
RRefForkData RRef::fork() const {
  auto& context = RRefContext::getInstance();
  return RRefForkData(
      ownerId_,
      rrefId_,
      context.genGloballyUniqueId(),
      context.getWorkerId(),
      getTypeStr(type_));
}
|
|
|
|
////////////////////////// UserRRef /////////////////////////////////////
|
|
|
|
UserRRef::UserRRef(
|
|
worker_id_t ownerId,
|
|
const RRefId& rrefId,
|
|
const ForkId& forkId,
|
|
TypePtr type)
|
|
: RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
|
|
// Do nothing,
|
|
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
|
|
// a RREF_FORK_REQUEST message to the owner.
|
|
// (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will
|
|
// properly notify the owner.
|
|
}
|
|
|
|
// Asks the RRefContext to delete this user (identified by owner id, RRef id
// and fork id). Any exception raised while doing so is logged and swallowed,
// so nothing escapes the destructor.
UserRRef::~UserRRef() {
  // Shared by both handlers; only the trailing reason differs.
  const auto reportFailure = [this](const char* reason) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << reason;
  };

  try {
    RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
  } catch (const std::exception& ex) {
    reportFailure(ex.what());
  } catch (...) {
    reportFailure("unknown error");
  }
}
|
|
|
|
// Accessor for the fork id given to this UserRRef at construction. The
// returned reference refers to the member, so it is valid only while this
// object is alive.
const ForkId& UserRRef::forkId() const {
  return this->forkId_;
}
|
|
|
|
// Fetches the value behind this RRef from its owner and blocks until the
// response arrives.
//
// Returns a tuple wrapping the returned values when the RRef holds a Python
// object, otherwise the first returned value.
IValue UserRRef::toHere() {
  auto rpcAgent = RpcAgent::getCurrentRpcAgent();

  // Pick the fetch call matching what the RRef holds.
  const auto buildFetchRequest = [this]() -> Message {
    if (isPyObj()) {
      return PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
    }
    return ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  };
  Message fetchRequest = buildFetchRequest();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response would
  // potentially contain tensors.
  auto pendingReply = autograd::sendMessageWithAutograd(
      *rpcAgent,
      rpcAgent->getWorkerInfo(ownerId_),
      std::move(fetchRequest),
      true /* forceGradRecording */);

  const Message& reply = pendingReply->wait();
  // The name `msgType` and the condition below are kept verbatim because the
  // assert macro embeds the condition text in its failure message.
  MessageType msgType = reply.type();
  auto decodedReply = deserializeResponse(reply, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");

  RpcCommandBase& command = *decodedReply;
  auto& fetched = static_cast<RRefFetchRet&>(command);

  if (!isPyObj()) {
    return fetched.values().front();
  }
  // Wrap the python serialized vector of ivalues into a tuple so that the C++
  // toHere interface returns a single IValue.
  return ivalue::Tuple::create(fetched.values());
}
|
|
|
|
////////////////////////// OwnerRRef /////////////////////////////////////
|
|
|
|
// Blocks on the condition variable until a value has been stored, then
// returns a reference to it. The mutex is held while waiting and checking.
const IValue& OwnerRRef::getValue() const {
  std::unique_lock<std::mutex> guard(mutex_);
  // Explicit loop form of a predicate wait: re-check after every wake-up.
  while (!value_.has_value()) {
    valueCV_.wait(guard);
  }
  return value_.value();
}
|
|
|
|
bool OwnerRRef::hasValue() const {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
return value_.has_value();
|
|
}
|
|
|
|
std::shared_ptr<FutureMessage> OwnerRRef::getFuture() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (future_.get()) {
|
|
return future_;
|
|
}
|
|
future_ = std::make_shared<FutureMessage>();
|
|
std::shared_ptr<FutureMessage> ret = future_;
|
|
if (value_.has_value()) {
|
|
lock.unlock();
|
|
ret->markCompleted(Message());
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
// Stores the value, wakes every thread waiting on the condition variable,
// and completes the pending future (if one exists and is not yet completed)
// with an empty Message. The future is detached from the member under the
// lock; notification and completion happen after the lock is released.
void OwnerRRef::setValue(IValue&& value) {
  std::shared_ptr<FutureMessage> pending;
  {
    const std::lock_guard<std::mutex> guard(mutex_);
    value_ = std::move(value);
    pending.swap(future_);
  }

  valueCV_.notify_all();
  if (pending != nullptr && !pending->completed()) {
    pending->markCompleted(Message());
  }
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|