Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
jit pickling rref (#32959)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32959

In the RPC TorchScript call path we need to pickle/unpickle RRefs. This diff makes the JIT pickler/unpickler able to do so. It is similar to what is implemented for PyRRef::pickle() and PyRRef::unpickle().

The pickling/unpickling design assumes it is always coupled with RPC calls. It is not meant for checkpointing a model that holds an RRef; before checkpointing such a model, the user should call rref.to_here() to fetch the value inside the RRef.

The pickling process is:
1. Push the torch.distributed.rpc.rref global string.
2. Call rref.fork() to create an rrefForkData, which holds a few IDs and the type string of the value held inside the RRef. The IDs include the RRef id, fork id, caller worker id, callee worker id, and owner worker id.
3. Push the rrefForkData.

The unpickling process is:
1. Read the torch.distributed.rpc.rref global string and retrieve the cached global lambda function.
2. The global lambda function receives the rrefForkData.
3. If the callee is also the owner worker, look up the owner RRef using the IDs inside the rrefForkData and return the OwnerRRef.
4. If the callee is not the owner worker, create a user RRef from the rrefForkData and return the UserRRef.
5. In both cases, the owner RRef is notified so that reference counting is done correctly.

During unpickling, a type_resolver is needed to parse the type string. This type_resolver has a Python dependency, so we obtain it from the rpc_agent and pass it to the unpickler during construction. Accordingly, this diff adds a type_resolver argument to the JIT unpickler constructor.

ghstack-source-id: 98814793

Test Plan: unit test

Differential Revision: D19713293

fbshipit-source-id: 4fd776cdd4ce8f457c4034d79acdfb4cd095c52e
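To make the record described above concrete, here is a minimal sketch (not part of this diff) of the seven-slot tuple that the pickler writes after the torch.distributed.rpc.rref global. The slot order mirrors the OWNER_IDX ... TYPE_IDX and RFD_TUPLE_SIZE constants that this change moves into rref_impl.h; the struct name PickledRRefRecord is hypothetical and used only for illustration.

#include <cstdint>
#include <string>

// Hypothetical illustration only: the real code pickles these values as a
// tuple of IValues (MARK ... TUPLE, then REDUCE), not as a C++ struct.
struct PickledRRefRecord {
  int64_t ownerId;          // slot 0 (OWNER_IDX): owner worker id
  int64_t rrefIdCreatedOn;  // slot 1 (RREFID_ON_IDX): RRefId.createdOn_
  int64_t rrefIdLocalId;    // slot 2 (RREFID_ID_IDX): RRefId.localId_
  int64_t forkIdCreatedOn;  // slot 3 (FORKID_ON_IDX): ForkId.createdOn_
  int64_t forkIdLocalId;    // slot 4 (FORKID_ID_IDX): ForkId.localId_
  int64_t parent;           // slot 5 (PARENT_IDX): parent worker id
  std::string typeStr;      // slot 6 (TYPE_IDX): type string of the held value
};                          // RFD_TUPLE_SIZE == 7 slots in total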
Committed by: Facebook Github Bot
Parent: 481e7f2e78
Commit: 4d9b649261
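Before the diff, a minimal usage sketch of the pattern the message classes below switch to: the type resolver installed on the current RpcAgent (see _set_and_start_rpc_agent in the diff) is dereferenced and forwarded to jit::unpickle so that RRef type strings can be resolved. This assumes the signatures shown in the diff; the helper name unpickleRpcPayload is hypothetical.

#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

// Hypothetical helper mirroring the fromMessage() changes below.
c10::IValue unpickleRpcPayload(
    const char* payload,
    size_t payload_size,
    const std::vector<at::Tensor>& tensor_table) {
  // getTypeResolver() asserts that a resolver was installed via
  // setTypeResolver(); it is passed through to the JIT unpickler.
  return torch::jit::unpickle(
      payload,
      payload_size,
      *torch::distributed::rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      &tensor_table);
}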
@@ -885,6 +885,14 @@ if(USE_ROCM)
)
endif()

# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and
# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set
if (USE_DISTRIBUTED)
target_compile_definitions(torch_cpu PRIVATE
USE_DISTRIBUTED
)
endif()

if (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE)
caffe2_interface_library(caffe2_protos caffe2_protos_whole)
target_link_libraries(torch_cpu PRIVATE caffe2_protos_whole)
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

namespace torch {
@@ -28,8 +29,11 @@ std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq::
// unpickle and get the context_id we need to clean up
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue ivalue_context_id =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
IValue ivalue_context_id = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());

// convert ivalue to int and construct request
int64_t context_id = ivalue_context_id.toInt();
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

namespace torch {
@@ -47,8 +48,11 @@ std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue tuple =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
IValue tuple = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();

// Build PropagateGradientsReq.
@@ -1,5 +1,6 @@
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/utils/byte_order.h>
@@ -117,7 +118,10 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
autogradPayLoadSize;
std::vector<torch::Tensor> tensorTable;
IValue tuple = jit::unpickle(
autogradPayLoadBegin, autogradPayLoadSize, nullptr, &tensorTable);
autogradPayLoadBegin,
autogradPayLoadSize,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&tensorTable);
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();

// Gather all the fields.
@@ -259,8 +259,21 @@ If the future completes with an error, an exception is thrown.
"_set_and_start_rpc_agent",
[](const std::shared_ptr<RpcAgent>& rpcAgent) {
RpcAgent::setCurrentRpcAgent(rpcAgent);
// Initializing typeResolver inside RpcAgent constructor will make
// RpcAgent have python dependency. To avoid RpcAgent to have python
// dependency, setTypeResolver() here.
std::shared_ptr<TypeResolver> typeResolver =
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
qn.qualifiedName());
return c10::StrongTypePtr(
PythonRpcHandler::getInstance().jitCompilationUnit(),
std::move(typePtr));
});
rpcAgent->setTypeResolver(typeResolver);
rpcAgent->start();
});
},
py::call_guard<py::gil_scoped_release>());

module.def("_reset_current_rpc_agent", []() {
RpcAgent::setCurrentRpcAgent(nullptr);
@@ -1,7 +1,7 @@
#pragma once

#include <torch/csrc/utils/future.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <vector>

namespace torch {
@@ -12,17 +12,6 @@ namespace rpc {
///////////////////// Pickle/Unpickle Helplers ////////////////////////////

namespace {
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of parent in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple

py::tuple toPyTuple(const RRefForkData& rrefForkData) {
// add GIL as it is contructing a py::object
pybind11::gil_scoped_acquire ag;
@@ -40,7 +29,9 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
pybind11::gil_scoped_acquire ag;
TORCH_INTERNAL_ASSERT(
pyTuple.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 6 numbers.");
"Pickled RRefForkData must contain ",
RFD_TUPLE_SIZE,
" numbers.");
worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

namespace torch {
@@ -33,8 +34,11 @@ std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();

auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
auto values = value.toTuple()->elements();

// remove the last element from values and convert it back to an RRef
@@ -207,6 +207,15 @@ void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
currentRpcAgent_ = std::move(rpcAgent);
}

void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
typeResolver_ = std::move(typeResolver);
}

std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() {
TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!");
return typeResolver_;
}

void RpcAgent::enableGILProfiling(bool flag) {
profilingEnabled_ = flag;
}
@@ -13,6 +13,11 @@ namespace rpc {

using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;
// Input is qualified name string, output is JIT StrongTypePtr
// Same as jit::TypeResolver, did not import jit::TypeResolver to here
// because it could instroduce cyclic dependencies.
using TypeResolver =
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;

struct RpcBackendOptions {
RpcBackendOptions() = default;
@@ -212,11 +217,19 @@ class TORCH_API RpcAgent {
// Retrieve wheher we should profile GIL wait times or not.
bool isGILProfilingEnabled();

// Set type resolver that will be passed to JIT pickler to resolver type Ptr
// based on type str.
void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver);

// Get the type resolver
std::shared_ptr<TypeResolver> getTypeResolver();

protected:
const WorkerInfo workerInfo_;
const std::unique_ptr<RequestCallback> cb_;
std::atomic<std::chrono::milliseconds> rpcTimeout_;
std::atomic<bool> profilingEnabled_;
std::shared_ptr<TypeResolver> typeResolver_;

private:
static std::shared_ptr<RpcAgent> currentRpcAgent_;
@@ -6,6 +6,25 @@
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace {
// If the type is subtype of named type, return its qualifiedname, otherwise
// return its type str.
std::string getTypeStr(const c10::TypePtr& type) {
switch (type->kind()) {
case c10::TypeKind::FunctionType:
return type->cast<c10::FunctionType>()->name()->qualifiedName();
case c10::TypeKind::TupleType:
return type->cast<c10::TupleType>()->name()->qualifiedName();
case c10::TypeKind::ClassType:
return type->cast<c10::ClassType>()->name()->qualifiedName();
case c10::TypeKind::InterfaceType:
return type->cast<c10::InterfaceType>()->name()->qualifiedName();
default:
return type->str();
}
}
} // namespace

namespace torch {
namespace distributed {
namespace rpc {
@@ -41,7 +60,7 @@ RRefForkData RRef::fork() const {
rrefId_,
ctx.genGloballyUniqueId(),
ctx.getWorkerId(),
type_->str());
getTypeStr(type_));
}

////////////////////////// UserRRef /////////////////////////////////////
@@ -17,6 +17,17 @@ class RRef;
class RRefContext;
class UserRRef;

constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of parent in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple

// Represents fork of an RRef to be sent over the wire.
struct TORCH_API RRefForkData {
const worker_id_t ownerId_;
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

#include <limits>
@@ -19,8 +20,11 @@ std::vector<IValue> toIValues(const Message& message, MessageType type) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();

auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
return value.toTuple()->elements();
}
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>

namespace torch {
@@ -104,8 +105,8 @@ Message ScriptCall::toMessage() && {
toIValues(ivalues);

std::vector<torch::Tensor> tensor_table;
auto payload =
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
auto payload = jit::pickle(
c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);

return Message(
std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL);
@@ -114,8 +115,11 @@ Message ScriptCall::toMessage() && {
std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());

auto values = value.toTuple()->elements();
return fromIValues(values);
@@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

#include <c10/util/C++17.h>
#include <torch/csrc/jit/pickle.h>
@@ -68,8 +69,11 @@ std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();

auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
auto values = value.toTuple()->elements();
return fromIValues(values);
}
@@ -1,6 +1,7 @@
#include <torch/csrc/distributed/rpc/script_resp.h>

#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/unpickler.h>

@@ -32,8 +33,11 @@ Message ScriptResp::toMessage() && {
std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
return std::make_unique<ScriptResp>(std::move(value));
}
@@ -331,6 +331,8 @@ std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
return dptr;
};

// No need to pass typeResolver here, as it always processes string and
// tensors only
torch::jit::Unpickler unpickler(
metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
auto ival = unpickler.parse_ivalue();
@@ -66,6 +66,7 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
/// binary. `reader` should remember where it last read, and return
/// the number of bytes read.
/// See `torch::pickle` for details.
/// type_resolver is used to resolve any JIT type based on type str
TORCH_API IValue unpickle(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
@@ -1,5 +1,8 @@
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/pickler.h>
#include <aten/src/ATen/quantized/Quantizer.h>
@@ -126,6 +129,13 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
err << ". Please define serialization methods via torch::jit::pickle_ for "
"this class.";
AT_ERROR(err.str());
} else if (ivalue.isRRef()) {
#ifdef USE_DISTRIBUTED
pushRRef(ivalue);
#else
TORCH_CHECK(
false, "RRef pickling is only supported with the distributed package");
#endif
} else {
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
}
@@ -146,6 +156,28 @@ void Pickler::pushDevice(const IValue& ivalue) {
}
}

#ifdef USE_DISTRIBUTED
void Pickler::pushRRef(const IValue& ivalue) {
// It is the same as how rref is pickled in python, see PyRRef::pickle
auto rrefInterface = ivalue.toRRef();
auto rref =
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(rrefInterface);
pushGlobal("torch.distributed.rpc", "rref");
auto& ctx = distributed::rpc::RRefContext::getInstance();
auto rrefForkData = ctx.prepareChildFork(rref);
push<PickleOpCode>(PickleOpCode::MARK);
pushInt(rrefForkData.ownerId_);
pushInt(rrefForkData.rrefId_.createdOn_);
pushInt(rrefForkData.rrefId_.localId_);
pushInt(rrefForkData.forkId_.createdOn_);
pushInt(rrefForkData.forkId_.localId_);
pushInt(rrefForkData.parent_);
pushString(rrefForkData.typeStr_);
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
#endif

void Pickler::pushIValue(const IValue& ivalue) {
bool shouldMemoizeByPointer =
ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;
@@ -155,6 +155,9 @@ class Pickler {
void pushTuple(const IValue& ivalue);
void pushString(const std::string& string);
void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
void pushRRef(const IValue& ivalue);
#endif
// unmemoized version
void pushStringImpl(const std::string& string);
void pushStorageOfTensor(const at::Tensor& tensor);
@@ -749,6 +749,14 @@ inline py::object toPyObject(IValue ivalue) {
return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
} else if (ivalue.isCapsule()) {
return py::cast(ivalue.toCapsule());
} else if (ivalue.isRRef()) {
#ifdef USE_DISTRIBUTED
return py::cast(torch::distributed::rpc::PyRRef(
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
ivalue.toRRef())));
#else
TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",
@@ -1,9 +1,12 @@
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/pickler.h>
#include "unpickler.h"
#include <string>
#include "unpickler.h"

namespace torch {
namespace jit {
@@ -133,7 +136,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
case ClassType::Kind: {
auto obj = w.value.toObject();
auto typ = obj->type(); // note: intentionally using the dynamic type,
// the static type is potentially less accurate
// the static type is potentially less accurate
for (size_t i = 0; i < typ->numAttributes(); ++i) {
Work elem = {typ->getAttribute(i), obj->getSlot(i)};
to_process.emplace_back(std::move(elem));
@@ -174,7 +177,8 @@ void Unpickler::run() {
TORCH_CHECK(
opcode == PickleOpCode::PROTO,
"Expected PROTO opcode at the start"
" of pickle archive, found ", int(static_cast<uint8_t>(opcode)));
" of pickle archive, found ",
int(static_cast<uint8_t>(opcode)));
uint8_t protocol = read<uint8_t>();
TORCH_CHECK(
protocol == 2,
@@ -220,8 +224,7 @@ static std::vector<int64_t> tupleToIntList(const IValue& v) {
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
return fmap(
v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
@@ -453,11 +456,11 @@ void Unpickler::readGlobal(
// Pop reduce arg off the stack
auto data = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Unpickler"
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Unpickler"
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
});
} else {
TypePtr elem_type = nullptr;
@@ -498,13 +501,21 @@ void Unpickler::readGlobal(
stack_.back() = IValue();
});
} else if (module_name == "torch" && class_name == "device") {
globals_.emplace_back([this] {
auto device_string = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
stack_.emplace_back(c10::Device(device_string.toStringRef()));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
globals_.emplace_back([this] {
auto device_string = stack_.back().toTuple()->elements().at(0);
stack_.pop_back();
stack_.emplace_back(c10::Device(device_string.toStringRef()));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
} else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_DISTRIBUTED
return rebuildRRef();
#else
TORCH_INTERNAL_ASSERT(
false,
"RRef unpickling is only supported with the distributed package");
#endif
} else if (module_name == "torch") {
// Try to manually resolve several global enums
// NOTE: this does not put a global into the global table,
@@ -577,7 +588,7 @@ void Unpickler::rebuildTensor(bool quantized) {
const auto& scales = qparams.at(1).toTensor();
const auto& zero_points = qparams.at(2).toTensor();
int64_t axis = qparams.at(3).toInt();
result = _empty_per_channel_affine_quantized(
result = at::_empty_per_channel_affine_quantized(
{0},
scales,
zero_points,
@@ -605,7 +616,48 @@ void Unpickler::rebuildTensor(bool quantized) {
});
}

void Unpickler::readSlowWithBuffer(char *dest, size_t sz) {
#ifdef USE_DISTRIBUTED
void Unpickler::rebuildRRef() {
globals_.emplace_back([this] {
// It is the same as how rref is unpickled in python,
// see PyRRef::unpickle
auto args = stack_.back().toTuple()->elements();
stack_.pop_back();
TORCH_INTERNAL_ASSERT(
args.size() == distributed::rpc::RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 7 numbers.");
auto ownerId =
static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
// const reference will extend the lifetime of the temporary variable
const auto& rrefId = distributed::rpc::RRefId(
static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
const auto& forkId = distributed::rpc::RRefId(
static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
auto parent =
static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
const auto& typeStr = static_cast<std::string>(
args.at(distributed::rpc::TYPE_IDX).toStringRef());
auto rrefForkData = distributed::rpc::RRefForkData(
ownerId, rrefId, forkId, parent, typeStr);
auto& ctx = distributed::rpc::RRefContext::getInstance();
c10::intrusive_ptr<distributed::rpc::RRef> rref;
TORCH_INTERNAL_ASSERT(
type_resolver_ != nullptr, "type_resolver_ is nullptr.");
at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
stack_.emplace_back(
c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
});
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
// First, read any partial from buffer (may be 0).
// We explicitly assume that sz > buffer_remaining_,
// and that sz is never bigger than buffer_.size().
@@ -623,7 +675,7 @@ void Unpickler::readSlowWithBuffer(char *dest, size_t sz) {
AT_ERROR("Unexpected end of pickler archive.");
}
memcpy(dest + from_old_buf, buffer_.data(), needed);
buffer_pos_ = needed; // assignment (0'ed from read)
buffer_pos_ = needed; // assignment (0'ed from read)
buffer_remaining_ -= needed;
}
@@ -19,7 +19,11 @@ class Unpickler {
TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);

public:
// tensors inside the pickle are references to the tensor_table
// tensors inside the pickle are references to the tensor_table.
// class_resolver is to resolve strong class type, type_resolver_ is
// to resolve any JIT type. class_resolver and type_resolver are not merged
// here because some use cases need to get strong class type that
// type_resolver_ can not return.
Unpickler(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
@@ -76,6 +80,9 @@ class Unpickler {
const std::string& module_name,
const std::string& class_name);
void rebuildTensor(bool quantized);
#ifdef USE_DISTRIBUTED
void rebuildRRef();
#endif
PickleOpCode readInstruction();
PickleOpCode readOpCode() {
return static_cast<PickleOpCode>(read<uint8_t>());
@@ -59,13 +59,6 @@ def _check_rpc_done(rank_distance):
def _torch_ones(sizes, requires_grad=False):
return torch.ones(sizes, requires_grad=requires_grad)


# creates an owner rref on the given dst, and the rref holds a torch.ones tensor
# of the given size.
def _create_ones_rref_on(dst, sizes):
return rpc.remote(dst, _torch_ones, args=(sizes,), kwargs={"requires_grad": True})


# This method must be called on the rref owner, and verifies that the grad of
# rref tensor equals to the given grad.
def _compare_owner_value(context_id, rref, grad):
@@ -73,6 +66,16 @@ def _compare_owner_value(context_id, rref, grad):
return torch.equal(grads[rref.local_value()], grad)


def create_tensor():
return torch.ones((3, 3), requires_grad=True)


@torch.jit.script
def create_torchscript_tensor():
# type: () -> Tensor
return torch.ones((3, 3)).requires_grad_()


def my_py_add(t1, t2):
return torch.add(t1, t2)

@@ -91,6 +94,13 @@ def my_script_add(t1, t2):
return torch.add(t1, t2)


@torch.jit.script
def my_script_ref_add(ref_t1, t2):
# type: (RRef[Tensor], Tensor) -> Tensor
t1 = ref_t1.to_here()
return torch.add(t1, t2)


def my_nested_rref_add(dst, rref_t1, t2):
return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))

@@ -142,6 +152,16 @@ def _run_trainer(rref_t1, t2, ps, rank_diff):
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))

# This function is the same as _run_trainer, except rpc calls torchscript
# function "my_script_ref_add" instead of python funciton "my_rref_add"
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
with dist_autograd.context() as context_id:
ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
dist_autograd.backward([ret.sum()])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))


class SimulateBackwardError(Function):
@staticmethod
@@ -913,8 +933,7 @@ class DistAutogradTest(RpcAgentTestFixture):
#
# These four test ps-trainer groups run on completely separate autograd
# graphs, but they share the same set of underlying RpcAgents.
@dist_init
def test_trainer_ps(self):
def _test_trainer_ps(self, create_ref_fn, trainer_fn):
local_grads = None
t1 = torch.ones((3, 3), requires_grad=True)
t2 = torch.zeros((3, 3), requires_grad=True)
@@ -923,13 +942,10 @@ class DistAutogradTest(RpcAgentTestFixture):
local_ret.sum().backward()

# create rref on self
# TODO: simplify this once we support rpc to self
self_name = "worker{}".format(self.rank)
rref_t1 = rpc.rpc_sync(
"worker{}".format(self._next_rank()),
_create_ones_rref_on,
args=(self_name, (3, 3)),
)
rref_t1 = rpc.remote(
"worker{}".format(self.rank),
create_ref_fn,
args=())

# kick off forward and backward pass on three other workers (trainers)
rank_diffs = [1, 2, 3]
@@ -938,8 +954,8 @@ class DistAutogradTest(RpcAgentTestFixture):
futures.append(
rpc.rpc_async(
"worker{}".format((self.rank + rank_diff) % self.world_size),
_run_trainer,
args=(rref_t1, t2, self_name, rank_diff),
trainer_fn,
args=(rref_t1, t2, "worker{}".format(self.rank), rank_diff),
)
)

@@ -954,7 +970,7 @@ class DistAutogradTest(RpcAgentTestFixture):
# are all correct
ctx_id = ctx_ids[rank_diff]
grads = dist_autograd.get_gradients(ctx_id)
local_t1 = rref_t1.local_value()
local_t1 = rref_t1.to_here()
self.assertIn(local_t1, grads)
self.assertEqual(grads[local_t1], t1.grad)

@@ -965,6 +981,21 @@ class DistAutogradTest(RpcAgentTestFixture):
for fut in futures:
fut.wait()

@dist_init
def test_trainer_ps(self):
self._test_trainer_ps(create_tensor, _run_trainer)

@dist_init
def test_trainer_ps_torchscript_functions(self):
# TODO, need more investigation
# there is rref leak when shutting down, suspect it is because
# ref as arg is passed to pybind boundary, and the ref is not garbage
# collected by python when calling shutdown()
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True

self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript)

@dist_init
def test_backward_multiple_round_trips(self):
local_grads = None
@@ -1261,8 +1292,8 @@ class DistAutogradTest(RpcAgentTestFixture):
ExecMode.REMOTE,
]:
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, my_script_add, t1, t2)
loss = ret.sum()
forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
loss = forward_ret.sum()
ret = self._verify_backwards(
exec_mode, [loss], context_id, local_grads, t1, t2
)
@@ -237,6 +237,12 @@ def heavy_rpc(tensor):
tensor /= i + 1
return 0

@torch.jit.script
def heavy_rpc_torchscript(tensor):
for i in range(1, 100):
tensor *= i
tensor /= i + 1
return 0

def raise_func():
raise ValueError("Expected error")
@@ -257,6 +263,7 @@ def clear_global_rref():

@torch.jit.script
def one_arg(value):
# type: (Tensor) -> Tensor
return value + 1


@@ -265,16 +272,21 @@ class MyScriptClass:
def __init__(self):
self.a = 10

@torch.jit.interface
class MyModuleInterface(torch.nn.Module):
def forward(self):
# type: () -> Tensor
pass

class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
def __init__(self, rank):
super().__init__()
self.a = 10
self.a = torch.ones(rank)

@torch.jit.script_method
def my_method(self):
self.a = 11

def forward(self):
# type: () -> Tensor
return self.a

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -877,7 +889,7 @@ class RpcTest(RpcAgentTestFixture):
RuntimeError, "attempted to get undefined function"
):
ret = rpc._rpc_sync_torchscript(
"worker{}".format(dst_rank), _qualified_name(MyScriptModule), args=()
"worker{}".format(dst_rank), _qualified_name(MyScriptModule), args=(self.rank, )
)

with self.assertRaisesRegex(
@@ -885,15 +897,17 @@ class RpcTest(RpcAgentTestFixture):
):
ret = rpc._rpc_sync_torchscript(
"worker{}".format(dst_rank),
_qualified_name(MyScriptModule().my_method),
_qualified_name(MyScriptModule(self.rank).forward),
args=(),
)
# Python 3.5 and Python 3.6 throw different error message, the only
# common word can be greped is "pickle".
with self.assertRaisesRegex(Exception, "pickle"):
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), MyScriptModule().my_method, args=()
)
'worker{}'.format(dst_rank),
MyScriptModule(self.rank).forward,
args=())


@dist_init
def test_nested_rpc(self):
@@ -919,8 +933,8 @@ class RpcTest(RpcAgentTestFixture):
self.assertEqual(fut.wait(), 0)
tok = time.time()
print(
"Rank {} finished testing {} {} times in {} seconds.".format(
self.rank, f.__name__, repeat, tok - tik
"Rank {} finished testing {} times in {} seconds.".format(
self.rank, repeat, tok - tik
)
)

@@ -932,6 +946,10 @@ class RpcTest(RpcAgentTestFixture):
def test_stress_heavy_rpc(self):
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))

@dist_init
def test_stress_heavy_rpc_torchscript(self):
self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))

@dist_init
def test_builtin_remote_ret(self):
n = self.rank + 1
@@ -1716,25 +1734,98 @@ class RpcTest(RpcAgentTestFixture):
rpc.rpc_sync(callee_worker, foo_add, args=())
self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)


@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcJitTest(RpcAgentTestFixture):
@dist_init
def test_rref_as_arg(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_var = rpc_return_rref("worker{}".format(dst_rank))

def test_rref_as_arg_and_return(self):
@torch.jit.script
def rref_tensor_to_here(rref_var):
def rref_to_here(rref_var):
# type: (RRef[Tensor]) -> Tensor
return rref_var.to_here()

res = rref_tensor_to_here(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
@torch.jit.script
def return_rref(rref_var):
# type: (RRef[Tensor]) -> RRef[Tensor]
return rref_var

n = self.rank + 1
dst_rank = n % self.world_size
local_ret = one_arg(torch.ones(2, 2))

# create rref on current rank
rref = rpc.remote("worker{}".format(self.rank), one_arg, args=(torch.ones(2, 2),))

# pass rref to another user in rpc call
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
rref_to_here,
args=(rref,))
self.assertEqual(ret, local_ret)

# return rref in rpc call
rref1 = rpc.rpc_sync(
"worker{}".format(dst_rank),
return_rref,
args=(rref,))
self.assertEqual(rref1.to_here(), local_ret)

# pass rref to another user in remote call
rref2 = rpc.remote(
"worker{}".format(dst_rank),
rref_to_here,
args=(rref,))
self.assertEqual(rref2.to_here(), local_ret)

# return rref in remote call
rref3 = rpc.remote(
"worker{}".format(dst_rank),
return_rref,
args=(rref,))
self.assertEqual(rref3.to_here().to_here(), local_ret)

@dist_init
def test_remote_script_module(self):
@torch.jit.ignore
def my_script_module_init(rank):
# type: (int) -> MyModuleInterface
return MyScriptModule(rank)

@torch.jit.script
def construct_my_script_module(rank):
# type: (int) -> MyModuleInterface
return my_script_module_init(rank)

@torch.jit.script
def run_ref_script_module(ref_script_module, t):
# type: (RRef[MyModuleInterface], Tensor) -> Tensor
module = ref_script_module.to_here()
return module.forward() + t

# TODO, need more investigation
# there is rref leak when shutting down, suspect it is because
# ref as arg is passed to pybind boundary, and the ref is not garbage
# collected by python when calling shutdown()
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True

local_ret = MyScriptModule(self.rank).forward() + torch.ones(self.rank)

n = self.rank + 1
dst_rank = n % self.world_size
remote_ref = rpc.remote(
"worker{}".format(dst_rank),
construct_my_script_module,
args=(self.rank, ))

# pass rref arg to owner
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
run_ref_script_module,
args=(remote_ref, torch.ones(self.rank)))
self.assertEqual(ret, local_ret)

@dist_init
def test_rref_is_owner(self):