jit pickling rref (#32959)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32959

In the RPC TorchScript call path we need to pickle/unpickle RRefs. This diff makes the JIT pickler/unpickler able to pickle/unpickle an RRef, similar to what is implemented for PyRRef::pickle() and PyRRef::unpickle().
The pickling/unpickling design assumes it is always coupled with RPC calls; it is not meant for checkpointing a model that holds an RRef. Before checkpointing such a model, the user should call rref.to_here() to get the value held inside the RRef.
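For example, a minimal sketch of that checkpoint pattern (the helper name and the checkpoint path are illustrative, not part of this diff):

import torch

def checkpoint_rref_value(rref, path):
    # Materialize the remote value locally first; pickling the RRef itself
    # is only supported inside RPC messages.
    value = rref.to_here()
    torch.save(value, path)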

The pickling process is:
1. push the torch.distributed.rpc.rref global string
2. call rref.fork() to create an rrefForkData, which contains a few IDs and the type string of the value held inside the RRef; the IDs include the RRef id, the fork id, the owner worker id, and the parent (caller) worker id
3. push the rrefForkData as a tuple
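As a rough illustration of the resulting payload layout (a Python sketch; the argument names are descriptive stand-ins for the RRefForkData members pushed by the C++ pickler later in this diff):

def rref_fork_tuple(owner_id, rref_id_on, rref_id_local,
                    fork_id_on, fork_id_local, parent_id, type_str):
    # The pickler emits the "torch.distributed.rpc.rref" global, then this
    # 7-element tuple (MARK ... TUPLE), then REDUCE, which applies the
    # global to the tuple on the receiving side.
    return (
        owner_id,       # worker id of the RRef owner
        rref_id_on,     # RRefId.createdOn_
        rref_id_local,  # RRefId.localId_
        fork_id_on,     # ForkId.createdOn_
        fork_id_local,  # ForkId.localId_
        parent_id,      # worker id of the parent (caller) of this fork
        type_str,       # type string of the value held inside the RRef
    )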

The unpickling process is:
1. read the torch.distributed.rpc.rref global string and retrieve the cached global lambda function
2. the global lambda function receives the rrefForkData
3. if the callee is also the owner worker, get the owner RRef based on the IDs inside the rrefForkData and return the OwnerRRef
4. if the callee is not the owner worker, create a user RRef from the rrefForkData and return the UserRRef
5. meanwhile the owner RRef is notified of the fork so that reference counting is handled correctly
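A minimal Python-style sketch of steps 2-5 (hedged: ctx, get_or_create_rref, and notify_owner_and_parent_of_fork are illustrative stand-ins for the C++ RRefContext calls used in Unpickler::rebuildRRef below):

def rebuild_rref_sketch(ctx, fork_data, type_resolver):
    # Resolve the JIT type of the held value from its type string.
    rref_type = type_resolver(fork_data.type_str)
    # Returns the existing OwnerRRef when this worker is the owner,
    # otherwise creates a UserRRef from the fork data.
    rref = ctx.get_or_create_rref(fork_data, rref_type)
    # Tell the owner (and the parent of this fork) about the new fork so
    # distributed reference counting stays correct.
    ctx.notify_owner_and_parent_of_fork(fork_data.fork_id, fork_data.parent, rref)
    return rref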

During unpickling, a type_resolver is needed to parse the type string. This type_resolver has a Python dependency, so we get it from the rpc_agent and pass it to the unpickler during construction. Accordingly, this diff adds a type_resolver argument to the JIT unpickler constructor.
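The resolver contract is simply a mapping from a qualified type name string to a JIT type. A toy sketch of that contract (illustrative only; the real resolver, built in the Python bindings in this diff, goes through PythonRpcHandler::parseTypeFromStr and the JIT compilation unit):

def make_type_resolver(known_types):
    # known_types: dict mapping qualified type name strings to JIT types
    def resolve(qualified_name):
        if qualified_name not in known_types:
            raise RuntimeError("unknown type: {}".format(qualified_name))
        return known_types[qualified_name]
    return resolve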
ghstack-source-id: 98814793

Test Plan: unit test

Differential Revision: D19713293

fbshipit-source-id: 4fd776cdd4ce8f457c4034d79acdfb4cd095c52e
Authored by Yanli Zhao on 2020-02-24 11:14:00 -08:00, committed by Facebook Github Bot
parent 481e7f2e78
commit 4d9b649261
25 changed files with 417 additions and 94 deletions

View File

@ -885,6 +885,14 @@ if(USE_ROCM)
)
endif()
# Pass USE_DISTRIBUTED to torch_cpu, as some code in jit/pickler.cpp and
# jit/unpickler.cpp needs to be compiled only when USE_DISTRIBUTED is set
if (USE_DISTRIBUTED)
target_compile_definitions(torch_cpu PRIVATE
USE_DISTRIBUTED
)
endif()
if (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE)
caffe2_interface_library(caffe2_protos caffe2_protos_whole)
target_link_libraries(torch_cpu PRIVATE caffe2_protos_whole)

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
@ -28,8 +29,11 @@ std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq::
// unpickle and get the context_id we need to clean up
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue ivalue_context_id =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
IValue ivalue_context_id = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
// convert ivalue to int and construct request
int64_t context_id = ivalue_context_id.toInt();

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
@ -47,8 +48,11 @@ std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue tuple =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
IValue tuple = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
// Build PropagateGradientsReq.

View File

@ -1,5 +1,6 @@
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/utils/byte_order.h>
@ -117,7 +118,10 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
autogradPayLoadSize;
std::vector<torch::Tensor> tensorTable;
IValue tuple = jit::unpickle(
autogradPayLoadBegin, autogradPayLoadSize, nullptr, &tensorTable);
autogradPayLoadBegin,
autogradPayLoadSize,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&tensorTable);
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
// Gather all the fields.

View File

@ -259,8 +259,21 @@ If the future completes with an error, an exception is thrown.
"_set_and_start_rpc_agent",
[](const std::shared_ptr<RpcAgent>& rpcAgent) {
RpcAgent::setCurrentRpcAgent(rpcAgent);
// Initializing typeResolver inside the RpcAgent constructor would give
// RpcAgent a Python dependency. To avoid that, set the type resolver
// here via setTypeResolver().
std::shared_ptr<TypeResolver> typeResolver =
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
qn.qualifiedName());
return c10::StrongTypePtr(
PythonRpcHandler::getInstance().jitCompilationUnit(),
std::move(typePtr));
});
rpcAgent->setTypeResolver(typeResolver);
rpcAgent->start();
});
},
py::call_guard<py::gil_scoped_release>());
module.def("_reset_current_rpc_agent", []() {
RpcAgent::setCurrentRpcAgent(nullptr);

View File

@ -1,7 +1,7 @@
#pragma once
#include <torch/csrc/utils/future.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <vector>
namespace torch {

View File

@ -12,17 +12,6 @@ namespace rpc {
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
namespace {
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of parent in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
// add GIL as it is contructing a py::object
pybind11::gil_scoped_acquire ag;
@ -40,7 +29,9 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
pybind11::gil_scoped_acquire ag;
TORCH_INTERNAL_ASSERT(
pyTuple.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 6 numbers.");
"Pickled RRefForkData must contain ",
RFD_TUPLE_SIZE,
" numbers.");
worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
@ -33,8 +34,11 @@ std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
auto values = value.toTuple()->elements();
// remove the last element from values and convert it back to an RRef

View File

@ -207,6 +207,15 @@ void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
currentRpcAgent_ = std::move(rpcAgent);
}
void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
typeResolver_ = std::move(typeResolver);
}
std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() {
TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!");
return typeResolver_;
}
void RpcAgent::enableGILProfiling(bool flag) {
profilingEnabled_ = flag;
}

View File

@ -13,6 +13,11 @@ namespace rpc {
using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;
// Input is a qualified name string, output is a JIT StrongTypePtr.
// This is the same as jit::TypeResolver; jit::TypeResolver is not imported
// here because doing so could introduce cyclic dependencies.
using TypeResolver =
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
struct RpcBackendOptions {
RpcBackendOptions() = default;
@ -212,11 +217,19 @@ class TORCH_API RpcAgent {
// Retrieve whether we should profile GIL wait times or not.
bool isGILProfilingEnabled();
// Set the type resolver that will be passed to the JIT unpickler to resolve
// a TypePtr based on the type str.
void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver);
// Get the type resolver
std::shared_ptr<TypeResolver> getTypeResolver();
protected:
const WorkerInfo workerInfo_;
const std::unique_ptr<RequestCallback> cb_;
std::atomic<std::chrono::milliseconds> rpcTimeout_;
std::atomic<bool> profilingEnabled_;
std::shared_ptr<TypeResolver> typeResolver_;
private:
static std::shared_ptr<RpcAgent> currentRpcAgent_;

View File

@ -6,6 +6,25 @@
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace {
// If the type is a subtype of a named type, return its qualified name;
// otherwise return its type str.
std::string getTypeStr(const c10::TypePtr& type) {
switch (type->kind()) {
case c10::TypeKind::FunctionType:
return type->cast<c10::FunctionType>()->name()->qualifiedName();
case c10::TypeKind::TupleType:
return type->cast<c10::TupleType>()->name()->qualifiedName();
case c10::TypeKind::ClassType:
return type->cast<c10::ClassType>()->name()->qualifiedName();
case c10::TypeKind::InterfaceType:
return type->cast<c10::InterfaceType>()->name()->qualifiedName();
default:
return type->str();
}
}
} // namespace
namespace torch {
namespace distributed {
namespace rpc {
@ -41,7 +60,7 @@ RRefForkData RRef::fork() const {
rrefId_,
ctx.genGloballyUniqueId(),
ctx.getWorkerId(),
type_->str());
getTypeStr(type_));
}
////////////////////////// UserRRef /////////////////////////////////////

View File

@ -17,6 +17,17 @@ class RRef;
class RRefContext;
class UserRRef;
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of type str in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
// Represents fork of an RRef to be sent over the wire.
struct TORCH_API RRefForkData {
const worker_id_t ownerId_;

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
#include <limits>
@ -19,8 +20,11 @@ std::vector<IValue> toIValues(const Message& message, MessageType type) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
return value.toTuple()->elements();
}

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
@ -104,8 +105,8 @@ Message ScriptCall::toMessage() && {
toIValues(ivalues);
std::vector<torch::Tensor> tensor_table;
auto payload =
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
auto payload = jit::pickle(
c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
return Message(
std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL);
@ -114,8 +115,11 @@ Message ScriptCall::toMessage() && {
std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
auto values = value.toTuple()->elements();
return fromIValues(values);

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <c10/util/C++17.h>
#include <torch/csrc/jit/pickle.h>
@ -68,8 +69,11 @@ std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
auto values = value.toTuple()->elements();
return fromIValues(values);
}

View File

@ -1,6 +1,7 @@
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/unpickler.h>
@ -32,8 +33,11 @@ Message ScriptResp::toMessage() && {
std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
return std::make_unique<ScriptResp>(std::move(value));
}

View File

@ -331,6 +331,8 @@ std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
return dptr;
};
// No need to pass a typeResolver here, as this path only processes strings
// and tensors
torch::jit::Unpickler unpickler(
metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
auto ival = unpickler.parse_ivalue();

View File

@ -66,6 +66,7 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
/// binary. `reader` should remember where it last read, and return
/// the number of bytes read.
/// See `torch::pickle` for details.
/// type_resolver is used to resolve any JIT type based on its type string
TORCH_API IValue unpickle(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,

View File

@ -1,5 +1,8 @@
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/pickler.h>
#include <aten/src/ATen/quantized/Quantizer.h>
@ -126,6 +129,13 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
err << ". Please define serialization methods via torch::jit::pickle_ for "
"this class.";
AT_ERROR(err.str());
} else if (ivalue.isRRef()) {
#ifdef USE_DISTRIBUTED
pushRRef(ivalue);
#else
TORCH_CHECK(
false, "RRef pickling is only supported with the distributed package");
#endif
} else {
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
}
@ -146,6 +156,28 @@ void Pickler::pushDevice(const IValue& ivalue) {
}
}
#ifdef USE_DISTRIBUTED
void Pickler::pushRRef(const IValue& ivalue) {
// This is the same as how an RRef is pickled in Python; see PyRRef::pickle
auto rrefInterface = ivalue.toRRef();
auto rref =
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(rrefInterface);
pushGlobal("torch.distributed.rpc", "rref");
auto& ctx = distributed::rpc::RRefContext::getInstance();
auto rrefForkData = ctx.prepareChildFork(rref);
push<PickleOpCode>(PickleOpCode::MARK);
pushInt(rrefForkData.ownerId_);
pushInt(rrefForkData.rrefId_.createdOn_);
pushInt(rrefForkData.rrefId_.localId_);
pushInt(rrefForkData.forkId_.createdOn_);
pushInt(rrefForkData.forkId_.localId_);
pushInt(rrefForkData.parent_);
pushString(rrefForkData.typeStr_);
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
#endif
void Pickler::pushIValue(const IValue& ivalue) {
bool shouldMemoizeByPointer =
ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;

View File

@ -155,6 +155,9 @@ class Pickler {
void pushTuple(const IValue& ivalue);
void pushString(const std::string& string);
void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
void pushRRef(const IValue& ivalue);
#endif
// unmemoized version
void pushStringImpl(const std::string& string);
void pushStorageOfTensor(const at::Tensor& tensor);

View File

@ -749,6 +749,14 @@ inline py::object toPyObject(IValue ivalue) {
return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
} else if (ivalue.isCapsule()) {
return py::cast(ivalue.toCapsule());
} else if (ivalue.isRRef()) {
#ifdef USE_DISTRIBUTED
return py::cast(torch::distributed::rpc::PyRRef(
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
ivalue.toRRef())));
#else
TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",

View File

@ -1,9 +1,12 @@
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/pickler.h>
#include "unpickler.h"
#include <string>
#include "unpickler.h"
namespace torch {
namespace jit {
@ -133,7 +136,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
case ClassType::Kind: {
auto obj = w.value.toObject();
auto typ = obj->type(); // note: intentionally using the dynamic type,
// the static type is potentially less accurate
// the static type is potentially less accurate
for (size_t i = 0; i < typ->numAttributes(); ++i) {
Work elem = {typ->getAttribute(i), obj->getSlot(i)};
to_process.emplace_back(std::move(elem));
@ -174,7 +177,8 @@ void Unpickler::run() {
TORCH_CHECK(
opcode == PickleOpCode::PROTO,
"Expected PROTO opcode at the start"
" of pickle archive, found ", int(static_cast<uint8_t>(opcode)));
" of pickle archive, found ",
int(static_cast<uint8_t>(opcode)));
uint8_t protocol = read<uint8_t>();
TORCH_CHECK(
protocol == 2,
@ -220,8 +224,7 @@ static std::vector<int64_t> tupleToIntList(const IValue& v) {
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
return fmap(
v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}
PickleOpCode Unpickler::readInstruction() {
@ -453,11 +456,11 @@ void Unpickler::readGlobal(
// Pop reduce arg off the stack
auto data = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Unpickler"
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Unpickler"
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
});
} else {
TypePtr elem_type = nullptr;
@ -498,13 +501,21 @@ void Unpickler::readGlobal(
stack_.back() = IValue();
});
} else if (module_name == "torch" && class_name == "device") {
globals_.emplace_back([this] {
auto device_string = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
stack_.emplace_back(c10::Device(device_string.toStringRef()));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
globals_.emplace_back([this] {
auto device_string = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
stack_.emplace_back(c10::Device(device_string.toStringRef()));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
} else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_DISTRIBUTED
return rebuildRRef();
#else
TORCH_INTERNAL_ASSERT(
false,
"RRef unpickling is only supported with the distributed package");
#endif
} else if (module_name == "torch") {
// Try to manually resolve several global enums
// NOTE: this does not put a global into the global table,
@ -577,7 +588,7 @@ void Unpickler::rebuildTensor(bool quantized) {
const auto& scales = qparams.at(1).toTensor();
const auto& zero_points = qparams.at(2).toTensor();
int64_t axis = qparams.at(3).toInt();
result = _empty_per_channel_affine_quantized(
result = at::_empty_per_channel_affine_quantized(
{0},
scales,
zero_points,
@ -605,7 +616,48 @@ void Unpickler::rebuildTensor(bool quantized) {
});
}
void Unpickler::readSlowWithBuffer(char *dest, size_t sz) {
#ifdef USE_DISTRIBUTED
void Unpickler::rebuildRRef() {
globals_.emplace_back([this] {
// This is the same as how an RRef is unpickled in Python;
// see PyRRef::unpickle
auto args = stack_.back().toTuple()->elements();
stack_.pop_back();
TORCH_INTERNAL_ASSERT(
args.size() == distributed::rpc::RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 7 numbers.");
auto ownerId =
static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
// const reference will extend the lifetime of the temporary variable
const auto& rrefId = distributed::rpc::RRefId(
static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
const auto& forkId = distributed::rpc::RRefId(
static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
auto parent =
static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
const auto& typeStr = static_cast<std::string>(
args.at(distributed::rpc::TYPE_IDX).toStringRef());
auto rrefForkData = distributed::rpc::RRefForkData(
ownerId, rrefId, forkId, parent, typeStr);
auto& ctx = distributed::rpc::RRefContext::getInstance();
c10::intrusive_ptr<distributed::rpc::RRef> rref;
TORCH_INTERNAL_ASSERT(
type_resolver_ != nullptr, "type_resolver_ is nullptr.");
at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
stack_.emplace_back(
c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
}
#endif
void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
// First, read any partial from buffer (may be 0).
// We explicitly assume that sz > buffer_remaining_,
// and that sz is never bigger than buffer_.size().
@ -623,7 +675,7 @@ void Unpickler::readSlowWithBuffer(char *dest, size_t sz) {
AT_ERROR("Unexpected end of pickler archive.");
}
memcpy(dest + from_old_buf, buffer_.data(), needed);
buffer_pos_ = needed; // assignment (0'ed from read)
buffer_pos_ = needed; // assignment (0'ed from read)
buffer_remaining_ -= needed;
}

View File

@ -19,7 +19,11 @@ class Unpickler {
TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
public:
// tensors inside the pickle are references to the tensor_table
// tensors inside the pickle are references to the tensor_table.
// class_resolver resolves strong class types, while type_resolver_ resolves
// any JIT type. class_resolver and type_resolver are not merged because
// some use cases need a strong class type that type_resolver_ cannot return.
Unpickler(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
@ -76,6 +80,9 @@ class Unpickler {
const std::string& module_name,
const std::string& class_name);
void rebuildTensor(bool quantized);
#ifdef USE_DISTRIBUTED
void rebuildRRef();
#endif
PickleOpCode readInstruction();
PickleOpCode readOpCode() {
return static_cast<PickleOpCode>(read<uint8_t>());

View File

@ -59,13 +59,6 @@ def _check_rpc_done(rank_distance):
def _torch_ones(sizes, requires_grad=False):
return torch.ones(sizes, requires_grad=requires_grad)
# creates an owner rref on the given dst, and the rref holds a torch.ones tensor
# of the given size.
def _create_ones_rref_on(dst, sizes):
return rpc.remote(dst, _torch_ones, args=(sizes,), kwargs={"requires_grad": True})
# This method must be called on the rref owner, and verifies that the grad of
# rref tensor equals to the given grad.
def _compare_owner_value(context_id, rref, grad):
@ -73,6 +66,16 @@ def _compare_owner_value(context_id, rref, grad):
return torch.equal(grads[rref.local_value()], grad)
def create_tensor():
return torch.ones((3, 3), requires_grad=True)
@torch.jit.script
def create_torchscript_tensor():
# type: () -> Tensor
return torch.ones((3, 3)).requires_grad_()
def my_py_add(t1, t2):
return torch.add(t1, t2)
@ -91,6 +94,13 @@ def my_script_add(t1, t2):
return torch.add(t1, t2)
@torch.jit.script
def my_script_ref_add(ref_t1, t2):
# type: (RRef[Tensor], Tensor) -> Tensor
t1 = ref_t1.to_here()
return torch.add(t1, t2)
def my_nested_rref_add(dst, rref_t1, t2):
return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
@ -142,6 +152,16 @@ def _run_trainer(rref_t1, t2, ps, rank_diff):
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
# This function is the same as _run_trainer, except that rpc calls the
# TorchScript function "my_script_ref_add" instead of the Python function
# "my_rref_add"
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
dist_autograd.backward([ret.sum()])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
class SimulateBackwardError(Function):
@staticmethod
@ -913,8 +933,7 @@ class DistAutogradTest(RpcAgentTestFixture):
#
# These four test ps-trainer groups run on completely separate autograd
# graphs, but they share the same set of underlying RpcAgents.
@dist_init
def test_trainer_ps(self):
def _test_trainer_ps(self, create_ref_fn, trainer_fn):
local_grads = None
t1 = torch.ones((3, 3), requires_grad=True)
t2 = torch.zeros((3, 3), requires_grad=True)
@ -923,13 +942,10 @@ class DistAutogradTest(RpcAgentTestFixture):
local_ret.sum().backward()
# create rref on self
# TODO: simplify this once we support rpc to self
self_name = "worker{}".format(self.rank)
rref_t1 = rpc.rpc_sync(
"worker{}".format(self._next_rank()),
_create_ones_rref_on,
args=(self_name, (3, 3)),
)
rref_t1 = rpc.remote(
"worker{}".format(self.rank),
create_ref_fn,
args=())
# kick off forward and backward pass on three other workers (trainers)
rank_diffs = [1, 2, 3]
@ -938,8 +954,8 @@ class DistAutogradTest(RpcAgentTestFixture):
futures.append(
rpc.rpc_async(
"worker{}".format((self.rank + rank_diff) % self.world_size),
_run_trainer,
args=(rref_t1, t2, self_name, rank_diff),
trainer_fn,
args=(rref_t1, t2, "worker{}".format(self.rank), rank_diff),
)
)
@ -954,7 +970,7 @@ class DistAutogradTest(RpcAgentTestFixture):
# are all correct
ctx_id = ctx_ids[rank_diff]
grads = dist_autograd.get_gradients(ctx_id)
local_t1 = rref_t1.local_value()
local_t1 = rref_t1.to_here()
self.assertIn(local_t1, grads)
self.assertEqual(grads[local_t1], t1.grad)
@ -965,6 +981,21 @@ class DistAutogradTest(RpcAgentTestFixture):
for fut in futures:
fut.wait()
@dist_init
def test_trainer_ps(self):
self._test_trainer_ps(create_tensor, _run_trainer)
@dist_init
def test_trainer_ps_torchscript_functions(self):
# TODO: needs more investigation.
# There is an RRef leak when shutting down; we suspect it is because the
# RRef passed as an arg crosses the pybind boundary and is not garbage
# collected by Python when shutdown() is called.
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True
self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript)
@dist_init
def test_backward_multiple_round_trips(self):
local_grads = None
@ -1261,8 +1292,8 @@ class DistAutogradTest(RpcAgentTestFixture):
ExecMode.REMOTE,
]:
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, my_script_add, t1, t2)
loss = ret.sum()
forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
loss = forward_ret.sum()
ret = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)

View File

@ -237,6 +237,12 @@ def heavy_rpc(tensor):
tensor /= i + 1
return 0
@torch.jit.script
def heavy_rpc_torchscript(tensor):
for i in range(1, 100):
tensor *= i
tensor /= i + 1
return 0
def raise_func():
raise ValueError("Expected error")
@ -257,6 +263,7 @@ def clear_global_rref():
@torch.jit.script
def one_arg(value):
# type: (Tensor) -> Tensor
return value + 1
@ -265,16 +272,21 @@ class MyScriptClass:
def __init__(self):
self.a = 10
@torch.jit.interface
class MyModuleInterface(torch.nn.Module):
def forward(self):
# type: () -> Tensor
pass
class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
def __init__(self, rank):
super().__init__()
self.a = 10
self.a = torch.ones(rank)
@torch.jit.script_method
def my_method(self):
self.a = 11
def forward(self):
# type: () -> Tensor
return self.a
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -877,7 +889,7 @@ class RpcTest(RpcAgentTestFixture):
RuntimeError, "attempted to get undefined function"
):
ret = rpc._rpc_sync_torchscript(
"worker{}".format(dst_rank), _qualified_name(MyScriptModule), args=()
"worker{}".format(dst_rank), _qualified_name(MyScriptModule), args=(self.rank, )
)
with self.assertRaisesRegex(
@ -885,15 +897,17 @@ class RpcTest(RpcAgentTestFixture):
):
ret = rpc._rpc_sync_torchscript(
"worker{}".format(dst_rank),
_qualified_name(MyScriptModule().my_method),
_qualified_name(MyScriptModule(self.rank).forward),
args=(),
)
# Python 3.5 and Python 3.6 throw different error messages; the only
# common word that can be grepped is "pickle".
with self.assertRaisesRegex(Exception, "pickle"):
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), MyScriptModule().my_method, args=()
)
'worker{}'.format(dst_rank),
MyScriptModule(self.rank).forward,
args=())
@dist_init
def test_nested_rpc(self):
@ -919,8 +933,8 @@ class RpcTest(RpcAgentTestFixture):
self.assertEqual(fut.wait(), 0)
tok = time.time()
print(
"Rank {} finished testing {} {} times in {} seconds.".format(
self.rank, f.__name__, repeat, tok - tik
"Rank {} finished testing {} times in {} seconds.".format(
self.rank, repeat, tok - tik
)
)
@ -932,6 +946,10 @@ class RpcTest(RpcAgentTestFixture):
def test_stress_heavy_rpc(self):
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
@dist_init
def test_stress_heavy_rpc_torchscript(self):
self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))
@dist_init
def test_builtin_remote_ret(self):
n = self.rank + 1
@ -1716,25 +1734,98 @@ class RpcTest(RpcAgentTestFixture):
rpc.rpc_sync(callee_worker, foo_add, args=())
self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)
@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcJitTest(RpcAgentTestFixture):
@dist_init
def test_rref_as_arg(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_var = rpc_return_rref("worker{}".format(dst_rank))
def test_rref_as_arg_and_return(self):
@torch.jit.script
def rref_tensor_to_here(rref_var):
def rref_to_here(rref_var):
# type: (RRef[Tensor]) -> Tensor
return rref_var.to_here()
res = rref_tensor_to_here(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
@torch.jit.script
def return_rref(rref_var):
# type: (RRef[Tensor]) -> RRef[Tensor]
return rref_var
n = self.rank + 1
dst_rank = n % self.world_size
local_ret = one_arg(torch.ones(2, 2))
# create rref on current rank
rref = rpc.remote("worker{}".format(self.rank), one_arg, args=(torch.ones(2, 2),))
# pass rref to another user in rpc call
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
rref_to_here,
args=(rref,))
self.assertEqual(ret, local_ret)
# return rref in rpc call
rref1 = rpc.rpc_sync(
"worker{}".format(dst_rank),
return_rref,
args=(rref,))
self.assertEqual(rref1.to_here(), local_ret)
# pass rref to another user in remote call
rref2 = rpc.remote(
"worker{}".format(dst_rank),
rref_to_here,
args=(rref,))
self.assertEqual(rref2.to_here(), local_ret)
# return rref in remote call
rref3 = rpc.remote(
"worker{}".format(dst_rank),
return_rref,
args=(rref,))
self.assertEqual(rref3.to_here().to_here(), local_ret)
@dist_init
def test_remote_script_module(self):
@torch.jit.ignore
def my_script_module_init(rank):
# type: (int) -> MyModuleInterface
return MyScriptModule(rank)
@torch.jit.script
def construct_my_script_module(rank):
# type: (int) -> MyModuleInterface
return my_script_module_init(rank)
@torch.jit.script
def run_ref_script_module(ref_script_module, t):
# type: (RRef[MyModuleInterface], Tensor) -> Tensor
module = ref_script_module.to_here()
return module.forward() + t
# TODO: needs more investigation.
# There is an RRef leak when shutting down; we suspect it is because the
# RRef passed as an arg crosses the pybind boundary and is not garbage
# collected by Python when shutdown() is called.
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True
local_ret = MyScriptModule(self.rank).forward() + torch.ones(self.rank)
n = self.rank + 1
dst_rank = n % self.world_size
remote_ref = rpc.remote(
"worker{}".format(dst_rank),
construct_my_script_module,
args=(self.rank, ))
# pass rref arg to owner
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
run_ref_script_module,
args=(remote_ref, torch.ones(self.rank)))
self.assertEqual(ret, local_ret)
@dist_init
def test_rref_is_owner(self):