Remove _load_return_value from RPC internal.py (#34492)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34492

Differential Revision: D20347468

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: 92388d0d50a08fb895bacacf94c7b5495b4ae2b6
This commit is contained in:
Shen Li
2020-03-09 20:35:26 -07:00
committed by Facebook Github Bot
parent 6d1c4df660
commit 18ef09f5ac
9 changed files with 30 additions and 37 deletions

View File

@ -62,8 +62,8 @@ PythonRpcHandler::PythonRpcHandler() {
PROFILE_GIL_SCOPED_ACQUIRE;
py::object module = py::module::import("torch.distributed.rpc.internal");
pyRunFunction_ = getFunction(module, "_run_function");
pyLoadReturnValue_ = getFunction(module, "_load_return_value");
pySerialize_ = getFunction(module, "serialize");
pyDeserialize_ = getFunction(module, "deserialize");
pyHandleException_ = getFunction(module, "_handle_exception");
jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
@ -73,8 +73,8 @@ PythonRpcHandler::PythonRpcHandler() {
void PythonRpcHandler::cleanup() {
PROFILE_GIL_SCOPED_ACQUIRE;
pyRunFunction_ = py::none();
pyLoadReturnValue_ = py::none();
pySerialize_ = py::none();
pyDeserialize_ = py::none();
pyHandleException_ = py::none();
jitCompilationUnit_ = nullptr;
typeParser_ = nullptr;
@ -102,14 +102,6 @@ std::string PythonRpcHandler::generatePythonUDFResult(
return pres[0].cast<std::string>();
}
py::object PythonRpcHandler::loadPythonUDFResult(
const std::string& pickledPayload,
const std::vector<torch::Tensor>& tensorTable) {
PROFILE_GIL_SCOPED_ACQUIRE;
auto pargs = py::bytes(pickledPayload);
return pyLoadReturnValue_(pargs, tensorTable);
}
py::object PythonRpcHandler::runPythonUDF(
const SerializedPyObj& serializedObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
@ -126,7 +118,7 @@ SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
return pyLoadReturnValue_(
return pyDeserialize_(
py::bytes(serializedObj.payload_), serializedObj.tensors_);
}