mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove _load_return_value from RPC internal.py (#34492)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34492 Differential Revision: D20347468 Test Plan: Imported from OSS Pulled By: mrshenli fbshipit-source-id: 92388d0d50a08fb895bacacf94c7b5495b4ae2b6
This commit is contained in:
committed by
Facebook Github Bot
parent
6d1c4df660
commit
18ef09f5ac
@ -128,8 +128,11 @@ py::object PyRRef::toHere() {
|
||||
if (rref_->isPyObj()) {
|
||||
// python_rpc_handler deserialization will acquires GIL.
|
||||
auto rfr_values = value.toTuple()->elements();
|
||||
return PythonRpcHandler::getInstance().deserialize(
|
||||
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
||||
auto ret = pythonRpcHandler.deserialize(
|
||||
SerializedPyObj::fromIValues(rfr_values));
|
||||
pythonRpcHandler.handleException(ret);
|
||||
return ret;
|
||||
} else {
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object
|
||||
// without grabbing the GIL.
|
||||
|
||||
@ -25,6 +25,10 @@ std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
|
||||
return std::make_unique<PythonCall>(std::move(serializedPyObj));
|
||||
}
|
||||
|
||||
const SerializedPyObj& PythonCall::serializedPyObj() const {
|
||||
return serializedPyObj_;
|
||||
}
|
||||
|
||||
const std::string& PythonCall::pickledPayload() const {
|
||||
return serializedPyObj_.payload_;
|
||||
}
|
||||
|
||||
@ -16,6 +16,8 @@ class TORCH_API PythonCall final : public RpcCommandBase {
|
||||
|
||||
static std::unique_ptr<PythonCall> fromMessage(const Message& message);
|
||||
|
||||
const SerializedPyObj& serializedPyObj() const;
|
||||
|
||||
const std::string& pickledPayload() const;
|
||||
|
||||
const std::vector<torch::Tensor>& tensors() const;
|
||||
|
||||
@ -99,9 +99,10 @@ py::object toPyObjInternal(RpcCommandBase& rpc, MessageType messageType) {
|
||||
case MessageType::PYTHON_RET: {
|
||||
// TODO: Try to avoid a copy here.
|
||||
auto& resp = static_cast<PythonResp&>(rpc);
|
||||
|
||||
return PythonRpcHandler::getInstance().loadPythonUDFResult(
|
||||
resp.pickledPayload(), resp.tensors());
|
||||
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
||||
py::object ret = pythonRpcHandler.deserialize(resp.serializedPyObj());
|
||||
pythonRpcHandler.handleException(ret);
|
||||
return ret;
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(false, "Unrecognized response message type ", messageType);
|
||||
|
||||
@ -25,6 +25,10 @@ std::unique_ptr<PythonResp> PythonResp::fromMessage(const Message& message) {
|
||||
return std::make_unique<PythonResp>(std::move(serializedPyObj));
|
||||
}
|
||||
|
||||
const SerializedPyObj& PythonResp::serializedPyObj() const {
|
||||
return serializedPyObj_;
|
||||
}
|
||||
|
||||
const std::string& PythonResp::pickledPayload() const {
|
||||
return serializedPyObj_.payload_;
|
||||
}
|
||||
|
||||
@ -16,6 +16,8 @@ class TORCH_API PythonResp final : public RpcCommandBase {
|
||||
|
||||
static std::unique_ptr<PythonResp> fromMessage(const Message& message);
|
||||
|
||||
const SerializedPyObj& serializedPyObj() const;
|
||||
|
||||
const std::string& pickledPayload() const;
|
||||
|
||||
const std::vector<torch::Tensor>& tensors() const;
|
||||
|
||||
@ -62,8 +62,8 @@ PythonRpcHandler::PythonRpcHandler() {
|
||||
PROFILE_GIL_SCOPED_ACQUIRE;
|
||||
py::object module = py::module::import("torch.distributed.rpc.internal");
|
||||
pyRunFunction_ = getFunction(module, "_run_function");
|
||||
pyLoadReturnValue_ = getFunction(module, "_load_return_value");
|
||||
pySerialize_ = getFunction(module, "serialize");
|
||||
pyDeserialize_ = getFunction(module, "deserialize");
|
||||
pyHandleException_ = getFunction(module, "_handle_exception");
|
||||
jitCompilationUnit_ = torch::jit::get_python_cu();
|
||||
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
|
||||
@ -73,8 +73,8 @@ PythonRpcHandler::PythonRpcHandler() {
|
||||
void PythonRpcHandler::cleanup() {
|
||||
PROFILE_GIL_SCOPED_ACQUIRE;
|
||||
pyRunFunction_ = py::none();
|
||||
pyLoadReturnValue_ = py::none();
|
||||
pySerialize_ = py::none();
|
||||
pyDeserialize_ = py::none();
|
||||
pyHandleException_ = py::none();
|
||||
jitCompilationUnit_ = nullptr;
|
||||
typeParser_ = nullptr;
|
||||
@ -102,14 +102,6 @@ std::string PythonRpcHandler::generatePythonUDFResult(
|
||||
return pres[0].cast<std::string>();
|
||||
}
|
||||
|
||||
py::object PythonRpcHandler::loadPythonUDFResult(
|
||||
const std::string& pickledPayload,
|
||||
const std::vector<torch::Tensor>& tensorTable) {
|
||||
PROFILE_GIL_SCOPED_ACQUIRE;
|
||||
auto pargs = py::bytes(pickledPayload);
|
||||
return pyLoadReturnValue_(pargs, tensorTable);
|
||||
}
|
||||
|
||||
py::object PythonRpcHandler::runPythonUDF(
|
||||
const SerializedPyObj& serializedObj) {
|
||||
PROFILE_GIL_SCOPED_ACQUIRE;
|
||||
@ -126,7 +118,7 @@ SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
|
||||
|
||||
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
|
||||
PROFILE_GIL_SCOPED_ACQUIRE;
|
||||
return pyLoadReturnValue_(
|
||||
return pyDeserialize_(
|
||||
py::bytes(serializedObj.payload_), serializedObj.tensors_);
|
||||
}
|
||||
|
||||
|
||||
@ -25,12 +25,6 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
||||
const std::vector<torch::Tensor>& requestTensorTable,
|
||||
std::vector<torch::Tensor>& responseTensorTable);
|
||||
|
||||
// Returned python UDF result is pickled binary string, so run python
|
||||
// function to unpickle the python UDF result and return py::object to user
|
||||
py::object loadPythonUDFResult(
|
||||
const std::string& pickledPayload,
|
||||
const std::vector<torch::Tensor>& tensorTable);
|
||||
|
||||
// Run a pickled Python UDF and return the result py::object
|
||||
py::object runPythonUDF(const SerializedPyObj& serializedObj);
|
||||
|
||||
@ -85,12 +79,12 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
||||
// Ref to `torch.distributed.rpc.internal._run_function`.
|
||||
py::object pyRunFunction_;
|
||||
|
||||
// Ref to `torch.distributed.rpc.internal._load_return_value`.
|
||||
py::object pyLoadReturnValue_;
|
||||
|
||||
// Ref to `torch.distributed.rpc.internal.serialize`.
|
||||
py::object pySerialize_;
|
||||
|
||||
// Ref to `torch.distributed.rpc.internal.deserialize`.
|
||||
py::object pyDeserialize_;
|
||||
|
||||
// Ref to 'torch.distributed.rpc.internal._handle_exception'
|
||||
py::object pyHandleException_;
|
||||
|
||||
|
||||
@ -118,6 +118,10 @@ def serialize(obj):
|
||||
return _internal_rpc_pickler.serialize(obj)
|
||||
|
||||
|
||||
def deserialize(binary_data, tensor_table):
|
||||
return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
|
||||
|
||||
|
||||
def _run_function(binary_data, tensor_table):
|
||||
r"""
|
||||
This function is exclusively called from C++.
|
||||
@ -141,19 +145,6 @@ def _handle_exception(result):
|
||||
raise result.exception_type(result.msg)
|
||||
|
||||
|
||||
def _load_return_value(binary_data, tensor_table):
|
||||
r"""
|
||||
This function is exclusively called from C++.
|
||||
See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
|
||||
|
||||
Processes the return value of a Python function.
|
||||
Raises exception if the return value is a wrapped exception.
|
||||
"""
|
||||
result = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
|
||||
_handle_exception(result)
|
||||
return result
|
||||
|
||||
|
||||
def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
|
||||
"""
|
||||
This function should be called from RPC/RRef functions to create a
|
||||
|
||||
Reference in New Issue
Block a user