pytorch/torch/csrc/distributed/rpc/python_rpc_handler.cpp
Yanli Zhao 58234c0254 support torch script call over rpc (#32197)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32197

This relands https://github.com/pytorch/pytorch/pull/30063. The main change is that the "test_script_functions_not_supported" unit test now matches a general exception and greps for the word "pickle" in the error message, because Python 3.5 and Python 3.6 raise different error types with different messages for the RPC call in that test.
[test all] This diff makes the following changes:
1. Provides a new set of private Python RPC APIs that accept an annotated TorchScript call, which can be serialized, deserialized, and executed in C++ without the GIL. These private APIs will be bound to JIT in the future. They differ from the public APIs in that the future JIT-bound private APIs will accept a qualified_name rather than a callable. These private APIs are subject to deprecation once JIT supports TorchScript functions as a JIT type.

Also, these APIs require the TorchScript function to be defined and annotated by users in Python; it cannot be a script class/module constructor or a class/module method.

2. Also allows the public RPC APIs to accept an annotated TorchScript call and execute the same code path as the private APIs above. Therefore, if users invoke an annotated TorchScript call over RPC, the call can likewise be serialized, deserialized, and executed in C++ without the GIL.

3. The above private APIs call a newly defined C++ function so that the RPC TorchScript call is serialized, deserialized, and executed in C++. This C++ function returns an ivalue::Future so that, in a follow-up diff, it can be called when these private APIs are bound to JIT.

4. script_call.cpp/.h and request_callback_impl.cpp are refactored accordingly so that TorchScript calls and builtin calls can share the same message type and code.

5. Refactors deserializeResponse() and adds a new utility to deserialize a response to an IValue.
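
Item 1 notes that a call identified by a qualified name, rather than by a Python callable, can be executed entirely in C++. The toy model below illustrates that idea only. ToyScriptCall and ToyCompilationUnit are hypothetical names invented for this sketch; they are not PyTorch types, and the real code path goes through the JIT compilation unit and IValues rather than a map of std::function.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a serialized script call: the function is
// identified by name, not by a Python callable, so the receiver needs no
// interpreter (and therefore no GIL) to find and run it.
struct ToyScriptCall {
  std::string qualifiedName;
  std::vector<double> args;
};

// Hypothetical registry playing the role of a compilation unit.
class ToyCompilationUnit {
 public:
  using Fn = std::function<double(const std::vector<double>&)>;

  // Register a function under its qualified name.
  void define(const std::string& qualifiedName, Fn fn) {
    functions_[qualifiedName] = std::move(fn);
  }

  // Look up the function by name and run it; throws if the name is unknown.
  double run(const ToyScriptCall& call) const {
    auto it = functions_.find(call.qualifiedName);
    if (it == functions_.end()) {
      throw std::runtime_error("unknown function: " + call.qualifiedName);
    }
    return it->second(call.args);
  }

 private:
  std::map<std::string, Fn> functions_;
};
```

A caller would register a function once with define() and then dispatch incoming calls with run(); only the name and the arguments have to travel over the wire.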

ghstack-source-id: 96879167

Test Plan: unit test

Differential Revision: D19402374

fbshipit-source-id: 04efcc7c167d08a6503f29efe55e76f2be4b2c5e
2020-01-18 09:24:17 -08:00

100 lines
3.1 KiB
C++

#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/jit/pybind_utils.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

py::object getFunction(const py::object& module, const char* name) {
  py::object fn = module.attr(name);
  TORCH_CHECK(
      py::isinstance<py::function>(fn),
      "attribute ",
      name,
      " is not a function");
  return fn;
}

} // namespace

PythonRpcHandler::PythonRpcHandler() {
  pybind11::gil_scoped_acquire ag;
  py::object module = py::module::import("torch.distributed.rpc.internal");
  pyRunFunction_ = getFunction(module, "_run_function");
  pyLoadReturnValue_ = getFunction(module, "_load_return_value");
  pySerialize_ = getFunction(module, "serialize");
  pyHandleException_ = getFunction(module, "_handle_exception");
  jitCompilationUnit_ = torch::jit::get_python_cu();
}

void PythonRpcHandler::cleanup() {
  pybind11::gil_scoped_acquire ag;
  pyRunFunction_ = py::none();
  pyLoadReturnValue_ = py::none();
  pySerialize_ = py::none();
  pyHandleException_ = py::none();
  jitCompilationUnit_ = nullptr;
}

PythonRpcHandler& PythonRpcHandler::getInstance() {
  static PythonRpcHandler handler;
  return handler;
}

std::shared_ptr<torch::jit::script::CompilationUnit> PythonRpcHandler::
    jitCompilationUnit() {
  return jitCompilationUnit_;
}

std::vector<char> PythonRpcHandler::generatePythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& requestTensorTable,
    std::vector<torch::Tensor>& responseTensorTable) {
  pybind11::gil_scoped_acquire ag;
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  py::tuple pres = pySerialize_(pyRunFunction_(pargs, requestTensorTable));
  const auto& presStr = pres[0].cast<std::string>();
  responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
  std::vector<char> payload(presStr.begin(), presStr.end());
  return payload;
}

py::object PythonRpcHandler::loadPythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& tensorTable) {
  pybind11::gil_scoped_acquire ag;
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  return pyLoadReturnValue_(pargs, tensorTable);
}

py::object PythonRpcHandler::runPythonUDF(
    const SerializedPyObj& serializedObj) {
  pybind11::gil_scoped_acquire ag;
  return pyRunFunction_(
      py::bytes(serializedObj.payload_), serializedObj.tensors_);
}

SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
  pybind11::gil_scoped_acquire ag;
  py::tuple t = pySerialize_(obj);
  return SerializedPyObj(
      t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
}

py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
  pybind11::gil_scoped_acquire ag;
  return pyLoadReturnValue_(
      py::bytes(serializedObj.payload_), serializedObj.tensors_);
}

void PythonRpcHandler::handleException(const py::object& obj) {
  pybind11::gil_scoped_acquire ag;
  pyHandleException_(obj);
}

} // namespace rpc
} // namespace distributed
} // namespace torch
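
The handler above combines a function-local static singleton with an explicit cleanup() that drops its cached Python objects, presumably so that they are released while the interpreter is still alive rather than at static destruction. The sketch below models that pattern without pybind11. ToyHandler is a hypothetical class: a std::mutex stands in for the GIL, a std::function stands in for a cached py::object, and throwing after cleanup() is an assumption of this sketch, not behavior of PythonRpcHandler.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for PythonRpcHandler (not PyTorch code).
class ToyHandler {
 public:
  // Function-local static: constructed once, on first use. Initialization
  // is thread-safe since C++11.
  static ToyHandler& getInstance() {
    static ToyHandler handler;
    return handler;
  }

  // Drop the cached callable on request, so that nothing is left for the
  // static's destructor to release at process exit.
  void cleanup() {
    std::lock_guard<std::mutex> guard(lock_);
    run_ = nullptr;
  }

  // Every entry point takes the lock first, mirroring the
  // gil_scoped_acquire at the top of each method in the file above.
  std::string run(const std::string& payload) {
    std::lock_guard<std::mutex> guard(lock_);
    if (!run_) {
      // Assumption of this sketch: fail loudly on use after cleanup().
      throw std::runtime_error("handler used after cleanup()");
    }
    return run_(payload);
  }

 private:
  ToyHandler() : run_([](const std::string& p) { return "ran:" + p; }) {}
  ToyHandler(const ToyHandler&) = delete;
  ToyHandler& operator=(const ToyHandler&) = delete;

  std::mutex lock_;
  std::function<std::string(const std::string&)> run_;
};
```

Callers would use ToyHandler::getInstance().run(...) during normal operation and call cleanup() once during shutdown, before the resources the cached callable depends on go away.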