Files
pytorch/torch/csrc/distributed/rpc/python_rpc_handler.cpp
Pritam Damania fe4170bda8 Add send and recv backward functions for builtin operators RPC. (#25527)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527

Master GH issue: https://github.com/pytorch/pytorch/issues/23110.

This change builds upon https://github.com/pytorch/pytorch/pull/24876 and
provides all the autograd hooks needed for a forward pass with distributed rpc
for builtin operators. This change does not address distributed RPC for Python
UDFs; that will be addressed in follow-up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and
when a response is sent from the server.
2. Attach receive autograd functions when a request is received on the server
and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd
function pair to uniquely identify them.
ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
2019-10-03 01:18:46 -07:00

46 lines
1.5 KiB
C++

#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
namespace torch {
namespace distributed {
namespace rpc {
// Looks up the Python-side helpers used to run pickled UDFs and to load
// their results, and caches them on this handler so later calls do not
// have to re-import the module.
PythonRpcHandler::PythonRpcHandler() {
  // Importing a module and reading its attributes are Python API calls,
  // so the GIL is held for the whole constructor.
  AutoGIL gil;
  auto rpcUtils = py::module::import("torch.distributed.internal_rpc_utils");
  runUDFFunction_ = rpcUtils.attr("run_python_udf_internal");
  loadResultFunction_ = rpcUtils.attr("load_python_udf_result_internal");
}
// Returns the process-wide PythonRpcHandler, constructing it on first use.
// Initialization of a function-local static is thread-safe (C++11).
//
// The instance is deliberately heap-allocated and never deleted. The
// handler holds references to Python callables (assigned in its
// constructor). If it were a plain function-local static, its destructor
// would run during C++ static destruction at process exit. That can happen
// after the Python interpreter has been finalized, and the references would
// be released without holding the GIL, which can crash at shutdown.
// Leaking the object avoids any Python API call during static destruction.
PythonRpcHandler& PythonRpcHandler::getInstance() {
  static PythonRpcHandler* handler = new PythonRpcHandler();
  return *handler;
}
// Runs a pickled Python UDF through the cached `run_python_udf_internal`
// helper and returns its serialized result.
//
// pickledPayload: serialized request; handed to Python as a `bytes` object.
// requestTensorTable: tensors accompanying the request, passed to the helper
//   alongside the payload.
// responseTensorTable: [out] overwritten with element 1 of the helper's
//   result tuple, cast to a vector of tensors.
// Returns element 0 of the helper's result tuple, cast to a string and
// copied into a byte vector.
//
// Throws if the helper was not loaded, or if the Python call or any of the
// result casts fail.
std::vector<char> PythonRpcHandler::generatePythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& requestTensorTable,
    std::vector<torch::Tensor>& responseTensorTable) {
  AutoGIL ag;
  // Compare the raw PyObject* explicitly: py::handle's operator!= is
  // deprecated in pybind11. Check before doing any work so a missing helper
  // fails fast.
  TORCH_CHECK(
      runUDFFunction_.ptr() != nullptr, "runUDFFunction_ is nullptr");
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  py::tuple pres = runUDFFunction_(pargs, requestTensorTable);
  // Element 0 is cast before element 1 is assigned. If the first cast
  // throws, the caller's responseTensorTable is left untouched.
  const auto presStr = pres[0].cast<std::string>();
  responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
  return std::vector<char>(presStr.begin(), presStr.end());
}
// Deserializes a Python UDF result through the cached
// `load_python_udf_result_internal` helper.
//
// pickledPayload: serialized result; handed to Python as a `bytes` object.
// tensorTable: tensors accompanying the result, passed to the helper
//   alongside the payload.
// Returns whatever Python object the helper returns.
//
// Throws if the helper was not loaded or if the Python call fails.
py::object PythonRpcHandler::loadPythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& tensorTable) {
  AutoGIL ag;
  // Compare the raw PyObject* explicitly: py::handle's operator!= is
  // deprecated in pybind11. Check before building the payload.
  TORCH_CHECK(
      loadResultFunction_.ptr() != nullptr, "loadResultFunction_ is nullptr");
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  return loadResultFunction_(pargs, tensorTable);
}
} // namespace rpc
} // namespace distributed
} // namespace torch