Remove _load_return_value from RPC internal.py (#34492)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34492

Differential Revision: D20347468

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: 92388d0d50a08fb895bacacf94c7b5495b4ae2b6
This commit is contained in:
Shen Li
2020-03-09 20:35:26 -07:00
committed by Facebook Github Bot
parent 6d1c4df660
commit 18ef09f5ac
9 changed files with 30 additions and 37 deletions

View File

@ -128,8 +128,11 @@ py::object PyRRef::toHere() {
if (rref_->isPyObj()) {
// python_rpc_handler deserialization will acquires GIL.
auto rfr_values = value.toTuple()->elements();
return PythonRpcHandler::getInstance().deserialize(
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto ret = pythonRpcHandler.deserialize(
SerializedPyObj::fromIValues(rfr_values));
pythonRpcHandler.handleException(ret);
return ret;
} else {
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.

View File

@ -25,6 +25,10 @@ std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
return std::make_unique<PythonCall>(std::move(serializedPyObj));
}
const SerializedPyObj& PythonCall::serializedPyObj() const {
return serializedPyObj_;
}
const std::string& PythonCall::pickledPayload() const {
return serializedPyObj_.payload_;
}

View File

@ -16,6 +16,8 @@ class TORCH_API PythonCall final : public RpcCommandBase {
static std::unique_ptr<PythonCall> fromMessage(const Message& message);
const SerializedPyObj& serializedPyObj() const;
const std::string& pickledPayload() const;
const std::vector<torch::Tensor>& tensors() const;

View File

@ -99,9 +99,10 @@ py::object toPyObjInternal(RpcCommandBase& rpc, MessageType messageType) {
case MessageType::PYTHON_RET: {
// TODO: Try to avoid a copy here.
auto& resp = static_cast<PythonResp&>(rpc);
return PythonRpcHandler::getInstance().loadPythonUDFResult(
resp.pickledPayload(), resp.tensors());
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
py::object ret = pythonRpcHandler.deserialize(resp.serializedPyObj());
pythonRpcHandler.handleException(ret);
return ret;
}
default: {
TORCH_CHECK(false, "Unrecognized response message type ", messageType);

View File

@ -25,6 +25,10 @@ std::unique_ptr<PythonResp> PythonResp::fromMessage(const Message& message) {
return std::make_unique<PythonResp>(std::move(serializedPyObj));
}
const SerializedPyObj& PythonResp::serializedPyObj() const {
return serializedPyObj_;
}
const std::string& PythonResp::pickledPayload() const {
return serializedPyObj_.payload_;
}

View File

@ -16,6 +16,8 @@ class TORCH_API PythonResp final : public RpcCommandBase {
static std::unique_ptr<PythonResp> fromMessage(const Message& message);
const SerializedPyObj& serializedPyObj() const;
const std::string& pickledPayload() const;
const std::vector<torch::Tensor>& tensors() const;

View File

@ -62,8 +62,8 @@ PythonRpcHandler::PythonRpcHandler() {
PROFILE_GIL_SCOPED_ACQUIRE;
py::object module = py::module::import("torch.distributed.rpc.internal");
pyRunFunction_ = getFunction(module, "_run_function");
pyLoadReturnValue_ = getFunction(module, "_load_return_value");
pySerialize_ = getFunction(module, "serialize");
pyDeserialize_ = getFunction(module, "deserialize");
pyHandleException_ = getFunction(module, "_handle_exception");
jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
@ -73,8 +73,8 @@ PythonRpcHandler::PythonRpcHandler() {
void PythonRpcHandler::cleanup() {
PROFILE_GIL_SCOPED_ACQUIRE;
pyRunFunction_ = py::none();
pyLoadReturnValue_ = py::none();
pySerialize_ = py::none();
pyDeserialize_ = py::none();
pyHandleException_ = py::none();
jitCompilationUnit_ = nullptr;
typeParser_ = nullptr;
@ -102,14 +102,6 @@ std::string PythonRpcHandler::generatePythonUDFResult(
return pres[0].cast<std::string>();
}
py::object PythonRpcHandler::loadPythonUDFResult(
const std::string& pickledPayload,
const std::vector<torch::Tensor>& tensorTable) {
PROFILE_GIL_SCOPED_ACQUIRE;
auto pargs = py::bytes(pickledPayload);
return pyLoadReturnValue_(pargs, tensorTable);
}
py::object PythonRpcHandler::runPythonUDF(
const SerializedPyObj& serializedObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
@ -126,7 +118,7 @@ SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
return pyLoadReturnValue_(
return pyDeserialize_(
py::bytes(serializedObj.payload_), serializedObj.tensors_);
}

View File

@ -25,12 +25,6 @@ class PYBIND11_EXPORT PythonRpcHandler {
const std::vector<torch::Tensor>& requestTensorTable,
std::vector<torch::Tensor>& responseTensorTable);
// Returned python UDF result is pickled binary string, so run python
// function to unpickle the python UDF result and return py::object to user
py::object loadPythonUDFResult(
const std::string& pickledPayload,
const std::vector<torch::Tensor>& tensorTable);
// Run a pickled Python UDF and return the result py::object
py::object runPythonUDF(const SerializedPyObj& serializedObj);
@ -85,12 +79,12 @@ class PYBIND11_EXPORT PythonRpcHandler {
// Ref to `torch.distributed.rpc.internal._run_function`.
py::object pyRunFunction_;
// Ref to `torch.distributed.rpc.internal._load_return_value`.
py::object pyLoadReturnValue_;
// Ref to `torch.distributed.rpc.internal.serialize`.
py::object pySerialize_;
// Ref to `torch.distributed.rpc.internal.deserialize`.
py::object pyDeserialize_;
// Ref to 'torch.distributed.rpc.internal._handle_exception'
py::object pyHandleException_;

View File

@ -118,6 +118,10 @@ def serialize(obj):
return _internal_rpc_pickler.serialize(obj)
def deserialize(binary_data, tensor_table):
return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
def _run_function(binary_data, tensor_table):
r"""
This function is exclusively called from C++.
@ -141,19 +145,6 @@ def _handle_exception(result):
raise result.exception_type(result.msg)
def _load_return_value(binary_data, tensor_table):
r"""
This function is exclusively called from C++.
See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
Processes the return value of a Python function.
Raises exception if the return value is a wrapped exception.
"""
result = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
_handle_exception(result)
return result
def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
"""
This function should be called from RPC/RRef functions to create a