Remove Python dependency (toPyTuple/fromPyTuple, jitCompilationUnit, deserialize) in rref_impl.h/cpp (#32753)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32753

Functions to be bound as ATen operators cannot have a Python dependency.

This change refactors the code to remove that Python dependency: toPyTuple/fromPyTuple and the deserialize call move into the Python-aware py_rref.cpp, and the jitCompilationUnit() schema lookup is hoisted into the pybind layer, which now passes the FunctionSchema down as an argument.
ghstack-source-id: 97485800
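
The core pattern here is schema injection: rather than having the ATen-facing functions reach into PythonRpcHandler::getInstance().jitCompilationUnit() (which requires Python) to look up the FunctionSchema, the pybind layer resolves the schema while it still holds Python and passes it down as a plain argument. A minimal standalone sketch of that pattern, using simplified stand-in types rather than the real torch classes:

```
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for c10::FunctionSchema.
struct FunctionSchema {
  std::vector<std::string> returns;
};

// After this change: the callee takes the schema as an argument and has no
// Python dependency, so it can be bound as an operator.
void rpcTorchscript(const std::string& qualifiedName,
                    const FunctionSchema& functionSchema) {
  // Script calls only allow a single returned IValue.
  if (functionSchema.returns.size() != 1) {
    throw std::runtime_error(qualifiedName + " must return exactly one value");
  }
  std::cout << "dispatching " << qualifiedName << " -> "
            << functionSchema.returns.front() << "\n";
}

int main() {
  // In the real binding, this schema comes from
  // PythonRpcHandler::getInstance().jitCompilationUnit() while the
  // pybind layer still has access to Python.
  FunctionSchema schema{{"Tensor"}};
  rpcTorchscript("dist_autograd_test::my_py_add", schema);
}
```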

Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_script_functions_not_supported

buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork

buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_script_functions_not_supported
```

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork

buck build mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork

buck-out/gen/caffe2/test/distributed/rpc/dist_autograd_fork\#binary.par -r test_backward_simple_script_call
```

Differential Revision: D5741675

fbshipit-source-id: 31ee60955be8d815d0773f3699e3ff2f1f9d8849
Author: Shihao Xu
Date: 2020-01-30 17:51:05 -08:00
Committed by: Facebook Github Bot
Parent: 29fabb1fbc
Commit: 12bcfa7c77
12 changed files with 152 additions and 156 deletions

View File

@@ -503,9 +503,12 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/torchscript_functions.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp

View File

@@ -66,9 +66,12 @@ libtorch_sources = [
"torch/csrc/distributed/rpc/python_call.cpp",
"torch/csrc/distributed/rpc/python_remote_call.cpp",
"torch/csrc/distributed/rpc/python_resp.cpp",
"torch/csrc/distributed/rpc/rpc_agent.cpp",
"torch/csrc/distributed/rpc/request_callback.cpp",
"torch/csrc/distributed/rpc/rpc_agent.cpp",
"torch/csrc/distributed/rpc/rref_context.cpp",
"torch/csrc/distributed/rpc/rref_proto.cpp",
"torch/csrc/distributed/rpc/rref_impl.cpp",
"torch/csrc/distributed/rpc/torchscript_functions.cpp",
"torch/csrc/distributed/rpc/script_call.cpp",
"torch/csrc/distributed/rpc/script_remote_call.cpp",
"torch/csrc/distributed/rpc/script_resp.cpp",
@@ -307,9 +310,6 @@ def add_torch_libs():
"torch/csrc/distributed/rpc/python_functions.cpp",
"torch/csrc/distributed/rpc/python_rpc_handler.cpp",
"torch/csrc/distributed/rpc/request_callback_impl.cpp",
"torch/csrc/distributed/rpc/rref_context.cpp",
"torch/csrc/distributed/rpc/rref_impl.cpp",
"torch/csrc/distributed/rpc/torchscript_functions.cpp",
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@@ -245,9 +245,6 @@ if (USE_DISTRIBUTED)
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/torchscript_functions.cpp
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)

View File

@@ -342,21 +342,22 @@ If the future completes with an error, an exception is thrown.
module.def(
"_invoke_rpc_torchscript",
[](const std::string& dstWorkerName,
const std::string& qualifiedName,
const std::string& qualifiedNameStr,
const py::args& args,
const py::kwargs& kwargs) {
// No need to catch exceptions here; if the function cannot be found,
// an exception will be thrown in the get_function() call, and if the
// args do not match the function schema, an exception will be thrown
// in the createStackForSchema() call.
auto name = c10::QualifiedName(qualifiedName);
auto fnSchema = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(name)
.getSchema();
auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
auto functionSchema = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(qualifiedName)
.getSchema();
auto stack = torch::jit::createStackForSchema(
fnSchema, args, kwargs, c10::nullopt);
auto fut = rpcTorchscript(dstWorkerName, name, stack);
functionSchema, args, kwargs, c10::nullopt);
auto fut =
rpcTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
return PythonFutureWrapper(fut);
},
py::call_guard<py::gil_scoped_release>());
@@ -374,17 +375,18 @@ If the future completes with an error, an exception is thrown.
module.def(
"_invoke_remote_torchscript",
[](const std::string& dstWorkerName,
const std::string& qualifiedName,
const std::string& qualifiedNameStr,
const py::args& args,
const py::kwargs& kwargs) {
auto name = c10::QualifiedName(qualifiedName);
auto fnSchema = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(name)
.getSchema();
auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
auto functionSchema = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(qualifiedName)
.getSchema();
auto stack = torch::jit::createStackForSchema(
fnSchema, args, kwargs, c10::nullopt);
auto userRRefPtr = remoteTorchscript(dstWorkerName, name, stack);
functionSchema, args, kwargs, c10::nullopt);
auto userRRefPtr = remoteTorchscript(
dstWorkerName, qualifiedName, functionSchema, stack);
return PyRRef(userRRefPtr);
},
py::call_guard<py::gil_scoped_release>());

View File

@@ -8,6 +8,55 @@
namespace torch {
namespace distributed {
namespace rpc {
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
namespace {
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of typeStr_ in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
// add GIL as it is constructing a py::object
pybind11::gil_scoped_acquire ag;
return py::make_tuple(
rrefForkData.ownerId_,
rrefForkData.rrefId_.createdOn_,
rrefForkData.rrefId_.localId_,
rrefForkData.forkId_.createdOn_,
rrefForkData.forkId_.localId_,
rrefForkData.parent_,
rrefForkData.typeStr_);
}
RRefForkData fromPyTuple(const py::tuple& pyTuple) {
// add GIL as it is accessing a py::object
pybind11::gil_scoped_acquire ag;
TORCH_INTERNAL_ASSERT(
pyTuple.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 6 numbers.");
worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(
pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
pyTuple[RREFID_ID_IDX].cast<local_id_t>());
const RRefId& forkId = RRefId(
pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
pyTuple[FORKID_ID_IDX].cast<local_id_t>());
worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
} // namespace
/////////////////////////// PyRRef //////////////////////////////////
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
@@ -38,7 +87,17 @@ py::object PyRRef::toHere() {
} else {
// toHere() calls python_rpc_handler, which acquires the GIL when the
// UserRRef holds a Python object.
IValue value = std::static_pointer_cast<UserRRef>(rref_)->toHere();
std::vector<IValue> rawValues =
std::static_pointer_cast<UserRRef>(rref_)->toHere();
IValue value;
if (rref_->isPyObj()) {
value = jit::toIValue(
PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(std::move(rawValues))),
PyObjectType::get());
} else {
value = std::move(rawValues).front();
}
{
// acquiring the GIL as torch::jit::toPyObject creates a new py::object
// without grabbing the GIL.
@@ -85,18 +144,19 @@ py::tuple PyRRef::pickle() const {
// install the dispatch table only when there are indeed RPC activities. As
// a counter example, checkpointing a model with RRefs should not trigger
// forks to be added as a fork or a child.
auto rfd = ctx.prepareChildFork(rref_);
return rfd.toPyTuple();
auto rrefForkData = ctx.prepareChildFork(rref_);
return toPyTuple(rrefForkData);
}
PyRRef PyRRef::unpickle(const py::tuple& t) {
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
auto& ctx = RRefContext::getInstance();
auto rfd = RRefForkData::fromPyTuple(t.cast<py::tuple>());
auto rrefForkData = fromPyTuple(pyTuple);
std::shared_ptr<RRef> rref = nullptr;
TypePtr rref_type =
PythonRpcHandler::getInstance().parseTypeFromStr(rfd.type_str_);
rref = ctx.getOrCreateRRef(rfd, rref_type);
ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
TypePtr rrefType =
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
return PyRRef(std::move(rref));
}
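
For reference, the fork-data tuple built by toPyTuple above always carries seven entries in this fixed order. A minimal standalone sketch of the layout, using plain int/long stand-ins for the worker and local id types, with made-up values:

```
#include <string>
#include <tuple>

int main() {
  // Mirrors the OWNER_IDX ... TYPE_IDX layout above.
  auto forkData = std::make_tuple(
      /*ownerId_=*/0,                     // OWNER_IDX
      /*rrefId_.createdOn_=*/0,           // RREFID_ON_IDX
      /*rrefId_.localId_=*/1L,            // RREFID_ID_IDX
      /*forkId_.createdOn_=*/1,           // FORKID_ON_IDX
      /*forkId_.localId_=*/2L,            // FORKID_ID_IDX
      /*parent_=*/1,                      // PARENT_IDX
      /*typeStr_=*/std::string("Tensor")  // TYPE_IDX
  );
  static_assert(std::tuple_size<decltype(forkData)>::value == 7,
                "must match RFD_TUPLE_SIZE");
}
```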

View File

@@ -157,11 +157,11 @@ void RRefContext::delUser(
}
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
const RRefForkData& rfd,
const RRefForkData& rrefForkData,
const TypePtr& type) {
auto& ownerId = rfd.ownerId_;
auto& rrefId = rfd.rrefId_;
auto& forkId = rfd.forkId_;
auto& ownerId = rrefForkData.ownerId_;
auto& rrefId = rrefForkData.rrefId_;
auto& forkId = rrefForkData.forkId_;
if (ownerId == getWorkerId()) {
auto ownerRRef = getOwnerRRef(rrefId);
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
@@ -217,7 +217,7 @@ std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
}
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
auto rfd = rref->fork();
auto rrefForkData = rref->fork();
if (rref->isOwner()) {
// Note [Early Fork Registration]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -229,7 +229,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
// ACK does not make any difference but only adds complexity.
// TODO: When adding failure retries and timeout, this fork needs to be
// deleted if the owner does not receive the ACK within the timeout.
addForkOfOwner(rfd.rrefId_, rfd.forkId_);
addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
// ensure that this RRef is in the owners_ list to keep it alive.
// this is needed for OwnerRRefs that were created locally.
{
@@ -240,17 +240,17 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
// Note [Useful Phantom Fork ID for User to Owner Call]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If the callee of dist.remote or dist.rpc is the owner of this RRef, the
// callee will not create a fork using this rfd.forkId_, because the owner
// will only keep one `OwnerRRef` instance and will not create any
// `UserRRef` instances. However, this rfd.forkId_ is still necessary, as
// the caller user needs to keep this `UserRRef` alive until it gets the
// ACK from the callee owner. Otherwise, the delete message could arrive
// at the owner before this dist.rpc or dist.remote call, which could
// potentially trigger the `OwnerRRef` to be deleted before running the
// user code.
addPendingChild(rfd.forkId_, rref);
// callee will not create a fork using this rrefForkData.forkId_, because
// the owner will only keep one `OwnerRRef` instance and will not create any
// `UserRRef` instances. However, this rrefForkData.forkId_ is still
// necessary, as the caller user needs to keep this `UserRRef` alive until
// it gets the ACK from the callee owner. Otherwise, the delete message
// could arrive at the owner before this dist.rpc or dist.remote call, which
// could potentially trigger the `OwnerRRef` to be deleted before running
// the user code.
addPendingChild(rrefForkData.forkId_, rref);
}
return rfd;
return rrefForkData;
}
void RRefContext::notifyOwnerAndParentOfFork(
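
The phantom-fork note above boils down to a keep-alive protocol: prepareChildFork parks the UserRRef in a pending-children map so the shared pointer survives until the owner's ACK arrives, and only then is it released. A minimal standalone sketch of that mechanism, with simplified types rather than the real RRefContext:

```
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct RRef { std::string id; };

// Parked forks: holding the shared_ptr here keeps the RRef alive even if
// the caller drops its own reference before the owner's ACK arrives.
std::map<std::string, std::shared_ptr<RRef>> pendingChildren;

void addPendingChild(const std::string& forkId, std::shared_ptr<RRef> rref) {
  pendingChildren.emplace(forkId, std::move(rref));
}

// Called when the ACK from the owner arrives; the RRef may now be freed.
void confirmPendingChild(const std::string& forkId) {
  pendingChildren.erase(forkId);
}

int main() {
  auto rref = std::make_shared<RRef>(RRef{"fork-1"});
  std::weak_ptr<RRef> watch = rref;
  addPendingChild("fork-1", rref);
  rref.reset();                   // caller drops its reference...
  assert(!watch.expired());       // ...but the pending entry keeps it alive
  confirmPendingChild("fork-1");  // ACK received
  assert(watch.expired());        // now it is actually released
}
```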

View File

@@ -14,13 +14,13 @@ namespace rpc {
namespace callback {
// It's the callback for RemoteCall.
void confirmPendingUser(
void TORCH_API confirmPendingUser(
const rpc::Message& message,
const c10::optional<utils::FutureError>& futErr);
} // namespace callback
// Manages RRef lifetime and keeps track of RRef forks.
class RRefContext {
class TORCH_API RRefContext {
public:
static RRefContext& getInstance();
// NB: This method must be called before destructing RRefContext singleton.

View File

@@ -2,30 +2,14 @@
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of typeStr_ in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
} // namespace
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
////////////////////////// RRefForkData /////////////////////////////////
@@ -35,46 +19,12 @@ RRefForkData::RRefForkData(
const RRefId& rrefId,
const ForkId& forkId,
worker_id_t parent,
std::string type_str)
std::string typeStr)
: ownerId_(ownerId),
rrefId_(rrefId),
forkId_(forkId),
parent_(parent),
type_str_(std::move(type_str)) {}
py::tuple RRefForkData::toPyTuple() const {
// add GIL as it is constructing a py::object
pybind11::gil_scoped_acquire ag;
return py::make_tuple(
ownerId_,
rrefId_.createdOn_,
rrefId_.localId_,
forkId_.createdOn_,
forkId_.localId_,
parent_,
type_str_);
}
RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
// add GIL as it is accessing a py::object
pybind11::gil_scoped_acquire ag;
TORCH_INTERNAL_ASSERT(
t.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 6 numbers.");
worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(
t[RREFID_ON_IDX].cast<worker_id_t>(),
t[RREFID_ID_IDX].cast<local_id_t>());
const RRefId& forkId = RRefId(
t[FORKID_ON_IDX].cast<worker_id_t>(),
t[FORKID_ID_IDX].cast<local_id_t>());
worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
const std::string& typeStr = t[TYPE_IDX].cast<std::string>();
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
typeStr_(std::move(typeStr)) {}
////////////////////////////// RRef /////////////////////////////////////
@@ -127,7 +77,7 @@ const ForkId& UserRRef::forkId() const {
return forkId_;
}
IValue UserRRef::toHere() {
std::vector<IValue> UserRRef::toHere() {
auto agent = RpcAgent::getCurrentRpcAgent();
// ScriptRRefFetchCall message always carries autograd context id even if
@@ -156,16 +106,8 @@ IValue UserRRef::toHere() {
"Message type should either be SCRIPT_RREF_FETCH_RET "
"or PYTHON_RREF_FETCH_RET");
RpcCommandBase& rpc = *response;
if (isPyObj()) {
auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
return jit::toIValue(
PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr.values())),
PyObjectType::get());
} else {
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
return rfr.values().front();
}
auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
return rrefFetchRet.values();
}
////////////////////////// OwnerRRef /////////////////////////////////////
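
The toHere() change above applies the same layering: the core UserRRef::toHere() now returns the raw IValue vector, and only the Python-aware PyRRef::toHere() (earlier in this diff) decides whether to deserialize it into a Python object. A minimal standalone sketch of that split, with a string stand-in for IValue:

```
#include <iostream>
#include <string>
#include <vector>

// String stand-in for torch's IValue.
using IValue = std::string;

// Core layer (rref_impl.cpp): fetches the payload, no Python dependency.
std::vector<IValue> toHereRaw() {
  return {"raw-payload"};
}

// Python layer (py_rref.cpp): interprets the payload. In the real code the
// isPyObj() branch calls PythonRpcHandler::deserialize under the GIL.
IValue pyToHere(bool isPyObj) {
  std::vector<IValue> rawValues = toHereRaw();
  if (isPyObj) {
    return "deserialized(" + rawValues.front() + ")";
  }
  return std::move(rawValues.front());
}

int main() {
  std::cout << pyToHere(true) << "\n";   // Python-object RRef
  std::cout << pyToHere(false) << "\n";  // script-value RRef
}
```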

View File

@@ -6,7 +6,6 @@
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_interface.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/utils/pybind.h>
#include <atomic>
@@ -19,27 +18,19 @@ class RRefContext;
class UserRRef;
// Represents a fork of an RRef to be sent over the wire.
struct RRefForkData {
py::tuple toPyTuple() const;
static RRefForkData fromPyTuple(const py::tuple& obj);
struct TORCH_API RRefForkData {
const worker_id_t ownerId_;
const RRefId rrefId_;
const ForkId forkId_;
const worker_id_t parent_;
const std::string type_str_;
private:
friend class RRef;
friend class RRefContext;
friend class UserRRef;
const std::string typeStr_;
RRefForkData(
worker_id_t ownerId,
const RRefId& rrefId_,
const ForkId& forkId_,
worker_id_t parent,
std::string type_str);
std::string typeStr);
};
// Note [RRef Protocol]
@@ -184,7 +175,7 @@ struct RRefForkData {
//
// ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
// Each ``RRef`` has a globally unique ``RRefId``.
class RRef : public RRefInterface {
class TORCH_API RRef : public RRefInterface {
public:
// RRef is made NOT copyable NOT movable to prevent messing up reference
// counting.
@@ -230,7 +221,7 @@ class RRef : public RRefInterface {
// also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
// never owns the real value, the only way to get the value of the ``RRef`` is
// to call ``to_here()`` and get a copy.
class UserRRef final : public RRef {
class TORCH_API UserRRef final : public RRef {
public:
UserRRef(const UserRRef& other) = delete;
UserRRef(UserRRef&& other) = delete;
@@ -246,7 +237,7 @@ class UserRRef final : public RRef {
// Get a copy of the value from the ``OwnerRRef``. If the value is not ready
// yet, this call will block.
IValue toHere();
std::vector<IValue> toHere();
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
~UserRRef() override;
@@ -265,7 +256,7 @@ class OwnerRRef final : public RRef {
// Keep the template only on the derived class because ``RRefContext`` needs to
// erase the type on ``RRef`` and keep them in one map.
class OwnerRRef final : public RRef {
class TORCH_API OwnerRRef final : public RRef {
public:
OwnerRRef(const OwnerRRef& other) = delete;
OwnerRRef(OwnerRRef&& other) = delete;

View File

@@ -2,7 +2,6 @@
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/utils.h>
@@ -14,6 +13,7 @@ namespace rpc {
c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack) {
auto scriptCall =
std::make_unique<ScriptCall>(qualifiedName, std::move(stack));
@@ -24,11 +24,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
std::move(*scriptCall).toMessage());
// Get function return type to construct c10::ivalue::Future.
auto returns = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(qualifiedName)
.getSchema()
.returns();
auto returns = functionSchema.returns();
// Script call only allows single IValue returned.
TORCH_INTERNAL_ASSERT(
returns.size() == 1,
@@ -56,6 +52,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
std::shared_ptr<UserRRef> remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack) {
auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
@@ -66,11 +63,7 @@ std::shared_ptr<UserRRef> remoteTorchscript(
"Does not support creating RRef on self yet.");
// Get function return type to construct UserRRef.
auto returns = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(qualifiedName)
.getSchema()
.returns();
auto returns = functionSchema.returns();
// Script call only allows single IValue returned.
TORCH_INTERNAL_ASSERT(
returns.size() == 1,

View File

@@ -2,8 +2,6 @@
#include <ATen/core/ivalue.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
@@ -21,14 +19,16 @@ namespace rpc {
// "dist_autograd_test::my_py_add"
// stack: a bag of IValue args passed to torchscriptFunctionName
// It returns c10::intrusive_ptr<ivalue::Future>
c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack);
std::shared_ptr<UserRRef> remoteTorchscript(
std::shared_ptr<UserRRef> TORCH_API remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack);
} // namespace rpc

View File

@@ -248,8 +248,8 @@ def clear_global_rref():
@torch.jit.script
def no_args():
a = 1
def one_arg(value):
return value + 1
@torch.jit.script
@@ -832,19 +832,27 @@ class RpcTest(RpcAgentTestFixture):
fut.wait()
@dist_init
def test_script_function_exception(self):
dst_rank = (self.rank + 1) % self.world_size
def test_torchscript_function(self):
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
with self.assertRaisesRegex(Exception, "no_args"):
ret = rpc.rpc_sync("worker{}".format(dst_rank), no_args, args=(10,))
ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
with self.assertRaisesRegex(
Exception, r"no_args\(\) expected at most 0 argument"
):
rref = rpc.remote("worker{}".format(dst_rank), no_args, args=(10,))
rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
@dist_init
def test_script_functions_not_supported(self):
def test_torchscript_function_exception(self):
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
with self.assertRaisesRegex(Exception, r"one_arg\(\) expected at most"):
ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))
with self.assertRaisesRegex(
Exception, r"one_arg\(\) expected at most"
):
rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20))
@dist_init
def test_torchscript_functions_not_supported(self):
# Right now _rpc_sync_torchscript does not accept annotated torchscript
# class name or script module class name or their class method names.
# But rpc_sync still accepts script class name and run it in