Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove Python dependency (toPyTuple/fromPyTuple, jitCompilationUnit, deserialize) in rref_impl.h/cpp (#32753)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32753

Functions to be bound as ATen operators cannot have a Python dependency. This change refactors the RPC internals to remove that dependency: the pickle helpers (toPyTuple/fromPyTuple) move out of RRefForkData into the Python-aware py_rref.cpp, the jitCompilationUnit() schema lookup moves up into the pybind layer, and deserialize() is no longer called from rref_impl.h/cpp.

ghstack-source-id: 97485800

Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_script_functions_not_supported
buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_script_functions_not_supported
```
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork
buck build mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork
buck-out/gen/caffe2/test/distributed/rpc/dist_autograd_fork\#binary.par -r test_backward_simple_script_call
```

Differential Revision: D5741675

fbshipit-source-id: 31ee60955be8d815d0773f3699e3ff2f1f9d8849
Committed by: Facebook Github Bot
Parent: 29fabb1fbc
Commit: 12bcfa7c77
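In short, the refactor replaces the Python-side schema lookup with an explicit c10::FunctionSchema parameter at the pybind boundary. A minimal before/after sketch assembled from the hunks below (not a compilable unit on its own; all names are taken from the diff):

```cpp
// Before: core RPC code reached into the Python layer for the schema, which
// made rpcTorchscript un-bindable as an ATen/JIT operator.
auto returns = PythonRpcHandler::getInstance()
                   .jitCompilationUnit()
                   ->get_function(qualifiedName)
                   .getSchema()
                   .returns();

// After: the pybind layer resolves the schema once and passes it down, so the
// core entry points depend only on c10/JIT types.
auto fut = rpcTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
```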
caffe2/CMakeLists.txt — the now Python-free RPC sources join the core torch sources (rpc_agent.cpp also moves to keep the list sorted):

```diff
@@ -503,9 +503,12 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_resp.cpp
-    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_impl.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/torchscript_functions.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
```
The shared libtorch_sources build list gets the same three files:

```diff
@@ -66,9 +66,12 @@ libtorch_sources = [
    "torch/csrc/distributed/rpc/python_call.cpp",
    "torch/csrc/distributed/rpc/python_remote_call.cpp",
    "torch/csrc/distributed/rpc/python_resp.cpp",
-   "torch/csrc/distributed/rpc/rpc_agent.cpp",
    "torch/csrc/distributed/rpc/request_callback.cpp",
+   "torch/csrc/distributed/rpc/rpc_agent.cpp",
+   "torch/csrc/distributed/rpc/rref_context.cpp",
    "torch/csrc/distributed/rpc/rref_proto.cpp",
+   "torch/csrc/distributed/rpc/rref_impl.cpp",
+   "torch/csrc/distributed/rpc/torchscript_functions.cpp",
    "torch/csrc/distributed/rpc/script_call.cpp",
    "torch/csrc/distributed/rpc/script_remote_call.cpp",
    "torch/csrc/distributed/rpc/script_resp.cpp",
```
The Python-extension target (add_torch_libs) drops them:

```diff
@@ -307,9 +310,6 @@ def add_torch_libs():
        "torch/csrc/distributed/rpc/python_functions.cpp",
        "torch/csrc/distributed/rpc/python_rpc_handler.cpp",
        "torch/csrc/distributed/rpc/request_callback_impl.cpp",
-       "torch/csrc/distributed/rpc/rref_context.cpp",
-       "torch/csrc/distributed/rpc/rref_impl.cpp",
-       "torch/csrc/distributed/rpc/torchscript_functions.cpp",
        "torch/csrc/jit/init.cpp",
        "torch/csrc/jit/passes/inline_fork_wait.cpp",
        "torch/csrc/jit/passes/onnx.cpp",
```
torch/CMakeLists.txt — and so does the TORCH_PYTHON source list:

```diff
@@ -245,9 +245,6 @@ if (USE_DISTRIBUTED)
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
-     ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
-     ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_impl.cpp
-     ${TORCH_SRC_DIR}/csrc/distributed/rpc/torchscript_functions.cpp
      )
   list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
   list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
```
torch/csrc/distributed/rpc/init.cpp — the pybind layer resolves the schema and passes it down:

```diff
@@ -342,21 +342,22 @@ If the future completes with an error, an exception is thrown.
   module.def(
       "_invoke_rpc_torchscript",
       [](const std::string& dstWorkerName,
-         const std::string& qualifiedName,
+         const std::string& qualifiedNameStr,
          const py::args& args,
          const py::kwargs& kwargs) {
         // No need to catch exception here, if function can not be found,
         // exception will be thrown in get_function() call; if args do not match
         // with function schema, exception will be thrown in
         // createStackForSchema() call.
-        auto name = c10::QualifiedName(qualifiedName);
-        auto fnSchema = PythonRpcHandler::getInstance()
-                            .jitCompilationUnit()
-                            ->get_function(name)
-                            .getSchema();
+        auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
+        auto functionSchema = PythonRpcHandler::getInstance()
+                                  .jitCompilationUnit()
+                                  ->get_function(qualifiedName)
+                                  .getSchema();
         auto stack = torch::jit::createStackForSchema(
-            fnSchema, args, kwargs, c10::nullopt);
-        auto fut = rpcTorchscript(dstWorkerName, name, stack);
+            functionSchema, args, kwargs, c10::nullopt);
+        auto fut =
+            rpcTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
         return PythonFutureWrapper(fut);
       },
       py::call_guard<py::gil_scoped_release>());
```
Same pattern for the remote variant:

```diff
@@ -374,17 +375,18 @@ If the future completes with an error, an exception is thrown.
   module.def(
       "_invoke_remote_torchscript",
       [](const std::string& dstWorkerName,
-         const std::string& qualifiedName,
+         const std::string& qualifiedNameStr,
          const py::args& args,
          const py::kwargs& kwargs) {
-        auto name = c10::QualifiedName(qualifiedName);
-        auto fnSchema = PythonRpcHandler::getInstance()
-                            .jitCompilationUnit()
-                            ->get_function(name)
-                            .getSchema();
+        auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
+        auto functionSchema = PythonRpcHandler::getInstance()
+                                  .jitCompilationUnit()
+                                  ->get_function(qualifiedName)
+                                  .getSchema();
         auto stack = torch::jit::createStackForSchema(
-            fnSchema, args, kwargs, c10::nullopt);
-        auto userRRefPtr = remoteTorchscript(dstWorkerName, name, stack);
+            functionSchema, args, kwargs, c10::nullopt);
+        auto userRRefPtr = remoteTorchscript(
+            dstWorkerName, qualifiedName, functionSchema, stack);
         return PyRRef(userRRefPtr);
       },
       py::call_guard<py::gil_scoped_release>());
```
torch/csrc/distributed/rpc/py_rref.cpp — the pickle helpers move here, out of RRefForkData:

```diff
@@ -8,6 +8,55 @@
 namespace torch {
 namespace distributed {
 namespace rpc {

+/////////////////////  Pickle/Unpickle Helpers  ////////////////////////////
+
+namespace {
+
+constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
+constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
+constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
+constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
+constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
+constexpr int PARENT_IDX = 5; // index of parent in the tuple
+constexpr int TYPE_IDX = 6; // index of typeStr in the tuple
+
+// NB: if more fields are added, make sure this field is also bumped
+constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
+
+py::tuple toPyTuple(const RRefForkData& rrefForkData) {
+  // add GIL as it is constructing a py::object
+  pybind11::gil_scoped_acquire ag;
+  return py::make_tuple(
+      rrefForkData.ownerId_,
+      rrefForkData.rrefId_.createdOn_,
+      rrefForkData.rrefId_.localId_,
+      rrefForkData.forkId_.createdOn_,
+      rrefForkData.forkId_.localId_,
+      rrefForkData.parent_,
+      rrefForkData.typeStr_);
+}
+
+RRefForkData fromPyTuple(const py::tuple& pyTuple) {
+  // add GIL as it is accessing a py::object
+  pybind11::gil_scoped_acquire ag;
+  TORCH_INTERNAL_ASSERT(
+      pyTuple.size() == RFD_TUPLE_SIZE,
+      "Pickled RRefForkData must contain 7 items.");
+  worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
+  // const reference will extend the lifetime of the temporary variable
+  const RRefId& rrefId = RRefId(
+      pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
+      pyTuple[RREFID_ID_IDX].cast<local_id_t>());
+  const RRefId& forkId = RRefId(
+      pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
+      pyTuple[FORKID_ID_IDX].cast<local_id_t>());
+
+  worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
+  const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();
+
+  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
+}
+} // namespace
+
 ///////////////////////////  PyRRef  //////////////////////////////////

 PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
```
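A quick round-trip sketch of the two helpers above (error paths elided; the GIL handling lives inside the helpers, and `rref` is assumed to be a live std::shared_ptr<RRef> on this worker):

```cpp
// Fork the RRef, pack the 7 fields into a tuple in *_IDX order, then unpack.
auto& ctx = RRefContext::getInstance();
RRefForkData rrefForkData = ctx.prepareChildFork(rref);

py::tuple pyTuple = toPyTuple(rrefForkData);   // 7 items, see RFD_TUPLE_SIZE
RRefForkData restored = fromPyTuple(pyTuple);  // asserts the tuple size
TORCH_INTERNAL_ASSERT(restored.ownerId_ == rrefForkData.ownerId_);
TORCH_INTERNAL_ASSERT(restored.typeStr_ == rrefForkData.typeStr_);
```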
PyRRef::toHere() now does the Python deserialization that used to live in UserRRef::toHere():

```diff
@@ -38,7 +87,17 @@ py::object PyRRef::toHere() {
   } else {
     // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
     // a python object
-    IValue value = std::static_pointer_cast<UserRRef>(rref_)->toHere();
+    std::vector<IValue> rawValues =
+        std::static_pointer_cast<UserRRef>(rref_)->toHere();
+    IValue value;
+    if (rref_->isPyObj()) {
+      value = jit::toIValue(
+          PythonRpcHandler::getInstance().deserialize(
+              SerializedPyObj::fromIValues(std::move(rawValues))),
+          PyObjectType::get());
+    } else {
+      value = std::move(rawValues).front();
+    }
     {
       // acquiring GIL as torch::jit::toPyObject creates new py::object
       // without grabbing the GIL.
```
pickle()/unpickle() switch to the free helper functions:

```diff
@@ -85,18 +144,19 @@ py::tuple PyRRef::pickle() const {
   // install the dispatch table only when there are indeed RPC activities. As
   // a counterexample, checkpointing a model with RRefs should not trigger
   // forks to be added as a fork or a child.
-  auto rfd = ctx.prepareChildFork(rref_);
-  return rfd.toPyTuple();
+  auto rrefForkData = ctx.prepareChildFork(rref_);
+  return toPyTuple(rrefForkData);
 }

-PyRRef PyRRef::unpickle(const py::tuple& t) {
+PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
   auto& ctx = RRefContext::getInstance();
-  auto rfd = RRefForkData::fromPyTuple(t.cast<py::tuple>());
+  auto rrefForkData = fromPyTuple(pyTuple);
   std::shared_ptr<RRef> rref = nullptr;
-  TypePtr rref_type =
-      PythonRpcHandler::getInstance().parseTypeFromStr(rfd.type_str_);
-  rref = ctx.getOrCreateRRef(rfd, rref_type);
-  ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
+  TypePtr rrefType =
+      PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
+  rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
+  ctx.notifyOwnerAndParentOfFork(
+      rrefForkData.forkId_, rrefForkData.parent_, rref);
   return PyRRef(std::move(rref));
 }
```
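For orientation, pickle()/unpickle() are presumably reached from Python through pybind11's pickle support. A hypothetical sketch of that wiring (the actual registration lives in the pybind module setup, not in this diff; shown only to illustrate how Python's pickle protocol lands on these methods):

```cpp
// Hypothetical binding: __getstate__ delegates to PyRRef::pickle(), which
// forks the RRef; __setstate__ rebuilds it from the 7-element tuple.
py::class_<PyRRef>(module, "RRef")
    .def(py::pickle(
        [](const PyRRef& self) { return self.pickle(); },
        [](py::tuple pyTuple) { return PyRRef::unpickle(pyTuple); }));
```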
torch/csrc/distributed/rpc/rref_context.cpp — rfd is renamed to rrefForkData throughout:

```diff
@@ -157,11 +157,11 @@ void RRefContext::delUser(
 }

 std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
-    const RRefForkData& rfd,
+    const RRefForkData& rrefForkData,
    const TypePtr& type) {
-  auto& ownerId = rfd.ownerId_;
-  auto& rrefId = rfd.rrefId_;
-  auto& forkId = rfd.forkId_;
+  auto& ownerId = rrefForkData.ownerId_;
+  auto& rrefId = rrefForkData.rrefId_;
+  auto& forkId = rrefForkData.forkId_;
   if (ownerId == getWorkerId()) {
     auto ownerRRef = getOwnerRRef(rrefId);
     TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
```
```diff
@@ -217,7 +217,7 @@ std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
 }

 RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
-  auto rfd = rref->fork();
+  auto rrefForkData = rref->fork();
   if (rref->isOwner()) {
     // Note [Early Fork Registration]
     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
```diff
@@ -229,7 +229,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
     // the ACK does not make any difference but only adds complexity.
     // TODO: When adding failure retries and timeout, this fork needs to be
     // deleted if the owner does not receive the ACK within the timeout.
-    addForkOfOwner(rfd.rrefId_, rfd.forkId_);
+    addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
     // ensure that this RRef is in the owners_ list to keep it alive.
     // this is needed for OwnerRRefs that were created locally.
     {
```
```diff
@@ -240,17 +240,17 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
     // Note [Useful Phantom Fork ID for User to Owner Call]
     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     // If the callee of dist.remote or dist.rpc is the owner of this RRef, the
-    // callee will not create a fork using this rfd.forkId_, because the owner
-    // will only keep one `OwnerRRef` instance and will not create any
-    // `UserRRef` instances. However, this rfd.forkId_ is still necessary, as
-    // the caller user needs to keep this `UserRRef` alive until it gets the
-    // ACK from the callee owner. Otherwise, the delete message could arrive
-    // at the owner before this dist.rpc or dist.remote call, which could
-    // potentially trigger the `OwnerRRef` to be deleted before running the
-    // user code.
-    addPendingChild(rfd.forkId_, rref);
+    // callee will not create a fork using this rrefForkData.forkId_, because
+    // the owner will only keep one `OwnerRRef` instance and will not create any
+    // `UserRRef` instances. However, this rrefForkData.forkId_ is still
+    // necessary, as the caller user needs to keep this `UserRRef` alive until
+    // it gets the ACK from the callee owner. Otherwise, the delete message
+    // could arrive at the owner before this dist.rpc or dist.remote call, which
+    // could potentially trigger the `OwnerRRef` to be deleted before running
+    // the user code.
+    addPendingChild(rrefForkData.forkId_, rref);
   }
-  return rfd;
+  return rrefForkData;
 }

 void RRefContext::notifyOwnerAndParentOfFork(
```
torch/csrc/distributed/rpc/rref_context.h — export the symbols now called from the Python extension:

```diff
@@ -14,13 +14,13 @@ namespace rpc {

 namespace callback {
 // It's the callback for RemoteCall.
-void confirmPendingUser(
+void TORCH_API confirmPendingUser(
     const rpc::Message& message,
     const c10::optional<utils::FutureError>& futErr);
 } // namespace callback

 // Manages RRef lifetime and keeps track of RRef forks.
-class RRefContext {
+class TORCH_API RRefContext {
  public:
   static RRefContext& getInstance();
   // NB: This method must be called before destructing RRefContext singleton.
```
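TORCH_API appears throughout this commit because these classes and functions now compile into libtorch but are still called from the separately linked Python extension, so their symbols must be exported. Roughly, simplified from c10/macros/Export.h (the real definition also handles the import side and build flags):

```cpp
// Simplified sketch of the export macro the headers gain in this commit.
#if defined(_WIN32)
#define TORCH_API __declspec(dllexport) // dllimport when consuming the DLL
#else
#define TORCH_API __attribute__((visibility("default")))
#endif
```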
torch/csrc/distributed/rpc/rref_impl.cpp — drop the Python includes and the tuple-index constants (moved to py_rref.cpp):

```diff
@@ -2,30 +2,14 @@

 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/autograd/utils.h>
-#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
 #include <torch/csrc/distributed/rpc/rref_context.h>
 #include <torch/csrc/distributed/rpc/rref_proto.h>
 #include <torch/csrc/distributed/rpc/utils.h>
-#include <torch/csrc/jit/pybind_utils.h>

 namespace torch {
 namespace distributed {
 namespace rpc {

-namespace {
-
-constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
-constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
-constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
-constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
-constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
-constexpr int PARENT_IDX = 5; // index of parent in the tuple
-constexpr int TYPE_IDX = 6; // index of parent in the tuple
-
-// NB: if more fields are added, make sure this field is also bumped
-constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
-} // namespace
-
 std::atomic<local_id_t> RRefContext::nextLocalId_{0};

 //////////////////////////  RRefForkData  /////////////////////////////////
```
RRefForkData loses toPyTuple()/fromPyTuple(), and type_str is renamed typeStr:

```diff
@@ -35,46 +19,12 @@ RRefForkData::RRefForkData(
     const RRefId& rrefId,
     const ForkId& forkId,
     worker_id_t parent,
-    std::string type_str)
+    std::string typeStr)
     : ownerId_(ownerId),
       rrefId_(rrefId),
       forkId_(forkId),
       parent_(parent),
-      type_str_(std::move(type_str)) {}
-
-py::tuple RRefForkData::toPyTuple() const {
-  // add GIL as it is contructing a py::object
-  pybind11::gil_scoped_acquire ag;
-  return py::make_tuple(
-      ownerId_,
-      rrefId_.createdOn_,
-      rrefId_.localId_,
-      forkId_.createdOn_,
-      forkId_.localId_,
-      parent_,
-      type_str_);
-}
-
-RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
-  // add GIL as it is accessing a py::object
-  pybind11::gil_scoped_acquire ag;
-  TORCH_INTERNAL_ASSERT(
-      t.size() == RFD_TUPLE_SIZE,
-      "Pickled RRefForkData must contain 6 numbers.");
-  worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
-  // const reference will extend the lifetime of the temporary variable
-  const RRefId& rrefId = RRefId(
-      t[RREFID_ON_IDX].cast<worker_id_t>(),
-      t[RREFID_ID_IDX].cast<local_id_t>());
-  const RRefId& forkId = RRefId(
-      t[FORKID_ON_IDX].cast<worker_id_t>(),
-      t[FORKID_ID_IDX].cast<local_id_t>());
-
-  worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
-  const std::string& typeStr = t[TYPE_IDX].cast<std::string>();
-
-  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
-}
+      typeStr_(std::move(typeStr)) {}

 //////////////////////////////  RRef  /////////////////////////////////////

```
UserRRef::toHere() now returns the raw IValues instead of a deserialized Python object:

```diff
@@ -127,7 +77,7 @@ const ForkId& UserRRef::forkId() const {
   return forkId_;
 }

-IValue UserRRef::toHere() {
+std::vector<IValue> UserRRef::toHere() {
   auto agent = RpcAgent::getCurrentRpcAgent();

   // ScriptRRefFetchCall message always carries autograd context id even if
```
```diff
@@ -156,16 +106,8 @@ IValue UserRRef::toHere() {
       "Message type should either be SCRIPT_RREF_FETCH_RET "
       "or PYTHON_RREF_FETCH_RET");
   RpcCommandBase& rpc = *response;
-  if (isPyObj()) {
-    auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
-    return jit::toIValue(
-        PythonRpcHandler::getInstance().deserialize(
-            SerializedPyObj::fromIValues(rfr.values())),
-        PyObjectType::get());
-  } else {
-    auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
-    return rfr.values().front();
-  }
+  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
+  return rrefFetchRet.values();
 }

 //////////////////////////  OwnerRRef  /////////////////////////////////////
```
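The single static_cast works for both reply types because, as the removed branch implies, ScriptRRefFetchRet and PythonRRefFetchRet share the RRefFetchRet base that carries the values() payload. A simplified sketch of that hierarchy (from rref_proto.h; exact signatures and class bodies elided):

```cpp
// Both fetch-reply messages expose their payload through the common base,
// so toHere() no longer needs to branch on isPyObj().
class RRefFetchRet : public RpcCommandBase {
 public:
  std::vector<at::IValue> values(); // one IValue for script values,
                                    // serialized tuple parts for Python objects
};
class ScriptRRefFetchRet final : public RRefFetchRet { /* ... */ };
class PythonRRefFetchRet final : public RRefFetchRet { /* ... */ };
```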
torch/csrc/distributed/rpc/rref_impl.h — no more pybind include:

```diff
@@ -6,7 +6,6 @@
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
 #include <torch/csrc/distributed/rpc/rref_interface.h>
 #include <torch/csrc/distributed/rpc/types.h>
-#include <torch/csrc/utils/pybind.h>

 #include <atomic>

```
RRefForkData becomes a plain exported struct with a public constructor (the helper functions in py_rref.cpp need to build it):

```diff
@@ -19,27 +18,19 @@ class RRefContext;
 class UserRRef;

 // Represents fork of an RRef to be sent over the wire.
-struct RRefForkData {
-  py::tuple toPyTuple() const;
-  static RRefForkData fromPyTuple(const py::tuple& obj);
-
+struct TORCH_API RRefForkData {
   const worker_id_t ownerId_;
   const RRefId rrefId_;
   const ForkId forkId_;
   const worker_id_t parent_;
-  const std::string type_str_;
-
- private:
-  friend class RRef;
-  friend class RRefContext;
-  friend class UserRRef;
+  const std::string typeStr_;

   RRefForkData(
       worker_id_t ownerId,
       const RRefId& rrefId_,
       const ForkId& forkId_,
       worker_id_t parent,
-      std::string type_str);
+      std::string typeStr);
 };

 // Note [RRef Protocol]
```
TORCH_API goes on the RRef class hierarchy as well:

```diff
@@ -184,7 +175,7 @@ struct RRefForkData {
 //
 // ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
 // Each ``RRef`` has a globally unique ``RRefId``.
-class RRef : public RRefInterface {
+class TORCH_API RRef : public RRefInterface {
  public:
   // RRef is made NOT copyable NOT movable to prevent messing up reference
   // counting.
```
```diff
@@ -230,7 +221,7 @@ class RRef : public RRefInterface {
 // also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
 // never owns the real value, the only way to get the value of the ``RRef`` is
 // to call ``to_here()`` and get a copy.
-class UserRRef final : public RRef {
+class TORCH_API UserRRef final : public RRef {
  public:
   UserRRef(const UserRRef& other) = delete;
   UserRRef(UserRRef&& other) = delete;
```
```diff
@@ -246,7 +237,7 @@ class UserRRef final : public RRef {

   // Get a copy of the value from the ``OwnerRRef``. If the value is not ready
   // yet, this call will block.
-  IValue toHere();
+  std::vector<IValue> toHere();

   // Upon destruction, this ``UserRRef`` will tell the owner to deref.
   ~UserRRef() override;
```
```diff
@@ -265,7 +256,7 @@ class UserRRef final : public RRef {

 // Keep the template only on the derived class because ``RRefContext`` needs to
 // erase the type on ``RRef`` and keep them in one map.
-class OwnerRRef final : public RRef {
+class TORCH_API OwnerRRef final : public RRef {
  public:
   OwnerRRef(const OwnerRRef& other) = delete;
   OwnerRRef(OwnerRRef&& other) = delete;
```
torch/csrc/distributed/rpc/torchscript_functions.cpp — drop the Python include:

```diff
@@ -2,7 +2,6 @@

 #include <torch/csrc/distributed/autograd/utils.h>
 #include <torch/csrc/distributed/rpc/message.h>
-#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
 #include <torch/csrc/distributed/rpc/script_call.h>
 #include <torch/csrc/distributed/rpc/utils.h>
```
rpcTorchscript now takes the schema from the caller:

```diff
@@ -14,6 +13,7 @@ namespace rpc {
 c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
     const std::string& dstWorkerName,
     const c10::QualifiedName& qualifiedName,
+    const c10::FunctionSchema& functionSchema,
     std::vector<c10::IValue>& stack) {
   auto scriptCall =
       std::make_unique<ScriptCall>(qualifiedName, std::move(stack));
```
```diff
@@ -24,11 +24,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
       std::move(*scriptCall).toMessage());

   // Get function return type to construct c10::ivalue::Future.
-  auto returns = PythonRpcHandler::getInstance()
-                     .jitCompilationUnit()
-                     ->get_function(qualifiedName)
-                     .getSchema()
-                     .returns();
+  auto returns = functionSchema.returns();
   // Script call only allows single IValue returned.
   TORCH_INTERNAL_ASSERT(
       returns.size() == 1,
```
Same for remoteTorchscript:

```diff
@@ -56,6 +52,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
 std::shared_ptr<UserRRef> remoteTorchscript(
     const std::string& dstWorkerName,
     const c10::QualifiedName& qualifiedName,
+    const c10::FunctionSchema& functionSchema,
     std::vector<c10::IValue>& stack) {
   auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
   auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
```
```diff
@@ -66,11 +63,7 @@ std::shared_ptr<UserRRef> remoteTorchscript(
       "Does not support creating RRef on self yet.");

   // Get function return type to construct UserRRef.
-  auto returns = PythonRpcHandler::getInstance()
-                     .jitCompilationUnit()
-                     ->get_function(qualifiedName)
-                     .getSchema()
-                     .returns();
+  auto returns = functionSchema.returns();
   // Script call only allows single IValue returned.
   TORCH_INTERNAL_ASSERT(
       returns.size() == 1,
```
torch/csrc/distributed/rpc/torchscript_functions.h — the Python-side includes go away:

```diff
@@ -2,8 +2,6 @@

 #include <ATen/core/ivalue.h>
 #include <torch/csrc/distributed/autograd/utils.h>
-#include <torch/csrc/distributed/rpc/py_rref.h>
-#include <torch/csrc/distributed/rpc/python_functions.h>
 #include <torch/csrc/distributed/rpc/rref_context.h>
 #include <torch/csrc/distributed/rpc/script_remote_call.h>
```
The declarations gain the schema parameter and TORCH_API:

```diff
@@ -21,14 +19,16 @@ namespace rpc {
 // "dist_autograd_test::my_py_add"
 // stack: a bag of IValue args passed to torchscriptFunctionName
 // It returns c10::intrusive_ptr<ivalue::Future>
-c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
+c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
     const std::string& dstWorkerName,
     const c10::QualifiedName& qualifiedName,
+    const c10::FunctionSchema& functionSchema,
     std::vector<c10::IValue>& stack);

-std::shared_ptr<UserRRef> remoteTorchscript(
+std::shared_ptr<UserRRef> TORCH_API remoteTorchscript(
     const std::string& dstWorkerName,
     const c10::QualifiedName& qualifiedName,
+    const c10::FunctionSchema& functionSchema,
     std::vector<c10::IValue>& stack);

 } // namespace rpc
```
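With the new signatures, a C++ caller supplies the schema alongside the qualified name and never touches Python. A minimal usage sketch (assumes an initialized RPC agent and a compilation unit that already holds a TorchScript function "myns::my_fn"; the worker and function names are illustrative):

```cpp
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/torch.h>

void callRemotely(std::shared_ptr<torch::jit::CompilationUnit> cu) {
  // Resolve the function and its schema once, up front.
  c10::QualifiedName qualifiedName("myns::my_fn");
  const c10::FunctionSchema& functionSchema =
      cu->get_function(qualifiedName).getSchema();

  // rpcTorchscript consumes the stack and returns a Future with the result;
  // remoteTorchscript follows the same pattern but returns a UserRRef.
  std::vector<c10::IValue> stack{torch::ones({2, 2})};
  auto fut = torch::distributed::rpc::rpcTorchscript(
      "worker1", qualifiedName, functionSchema, stack);
  fut->wait();
}
```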
In the RPC test suite, the no_args helper is replaced by one_arg:

```diff
@@ -248,8 +248,8 @@ def clear_global_rref():


 @torch.jit.script
-def no_args():
-    a = 1
+def one_arg(value):
+    return value + 1


 @torch.jit.script
```
and the tests are renamed and expanded accordingly:

```diff
@@ -832,19 +832,27 @@ class RpcTest(RpcAgentTestFixture):
         fut.wait()

     @dist_init
-    def test_script_function_exception(self):
-        dst_rank = (self.rank + 1) % self.world_size
+    def test_torchscript_function(self):
+        dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)

-        with self.assertRaisesRegex(Exception, "no_args"):
-            ret = rpc.rpc_sync("worker{}".format(dst_rank), no_args, args=(10,))
+        ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))

-        with self.assertRaisesRegex(
-            Exception, r"no_args\(\) expected at most 0 argument"
-        ):
-            rref = rpc.remote("worker{}".format(dst_rank), no_args, args=(10,))
+        rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))

     @dist_init
-    def test_script_functions_not_supported(self):
+    def test_torchscript_function_exception(self):
+        dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
+
+        with self.assertRaisesRegex(Exception, r"one_arg\(\) expected at most"):
+            ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))
+
+        with self.assertRaisesRegex(
+            Exception, r"one_arg\(\) expected at most"
+        ):
+            rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20))
+
+    @dist_init
+    def test_torchscript_functions_not_supported(self):
         # Right now _rpc_sync_torchscript does not accept annotated torchscript
         # class name or script module class name or their class method names.
         # But rpc_sync still accepts script class name and run it in
```