Mirror of https://github.com/pytorch/pytorch.git
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527

Master GH issue: https://github.com/pytorch/pytorch/issues/23110.

This change builds upon https://github.com/pytorch/pytorch/pull/24876 and provides all the autograd hooks needed for a forward pass with distributed rpc for builtin operators. It does not address distributed rpc for Python UDFs; that will be addressed in follow-up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and a response is sent from the server.
2. Attach receive autograd functions when a request is received on the server and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd function pair to uniquely identify them (one possible id scheme is sketched after this summary).

ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
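Item 3 above relies on ids that never collide across workers in the group. Below is a minimal sketch of one plausible scheme, assuming a 16-bit worker id packed into the high bits of a 64-bit id and a process-local atomic counter in the low 48 bits; the class name, bit layout, and widths are illustrative assumptions, not code from this PR.

// Illustrative sketch only: combine the worker id (high 16 bits) with a
// monotonically increasing local counter (low 48 bits) so that ids generated
// on different workers can never collide.
#include <atomic>
#include <cstdint>

class AutogradMessageIdGenerator {
 public:
  explicit AutogradMessageIdGenerator(uint16_t workerId)
      : workerId_(workerId), nextLocalId_(0) {}

  // Returns a globally unique id: worker id in the top 16 bits, local
  // counter in the bottom 48 bits.
  int64_t next() {
    uint64_t localId = nextLocalId_.fetch_add(1, std::memory_order_relaxed);
    return static_cast<int64_t>(
        (static_cast<uint64_t>(workerId_) << 48) |
        (localId & ((1ULL << 48) - 1)));
  }

 private:
  const uint16_t workerId_;
  std::atomic<uint64_t> nextLocalId_;
};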
46 lines | 1.5 KiB | C++
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>

namespace torch {
namespace distributed {
namespace rpc {

// Import the Python-side RPC helpers once, under the GIL, and cache the
// callables used to run a pickled Python UDF and to unpickle its result.
PythonRpcHandler::PythonRpcHandler() {
  AutoGIL ag;
  py::object module =
      py::module::import("torch.distributed.internal_rpc_utils");
  runUDFFunction_ = module.attr("run_python_udf_internal");
  loadResultFunction_ = module.attr("load_python_udf_result_internal");
}

// Meyers singleton: one handler instance shared by the whole process.
PythonRpcHandler& PythonRpcHandler::getInstance() {
  static PythonRpcHandler handler;
  return handler;
}

// Server side: run the pickled UDF via the cached Python helper, fill in the
// response tensor table, and return the pickled result as raw bytes.
std::vector<char> PythonRpcHandler::generatePythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& requestTensorTable,
    std::vector<torch::Tensor>& responseTensorTable) {
  AutoGIL ag;
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
  // The helper returns a (pickled result string, tensor table) tuple.
  py::tuple pres = runUDFFunction_(pargs, requestTensorTable);
  const auto& presStr = pres[0].cast<std::string>();
  responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
  std::vector<char> payload(presStr.begin(), presStr.end());
  return payload;
}

// Client side: unpickle a UDF result payload back into a py::object.
py::object PythonRpcHandler::loadPythonUDFResult(
    const std::vector<char>& pickledPayload,
    const std::vector<torch::Tensor>& tensorTable) {
  AutoGIL ag;
  auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
  TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr");
  return loadResultFunction_(pargs, tensorTable);
}

} // namespace rpc
} // namespace distributed
} // namespace torch
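For context, a minimal sketch of how the RPC layer might drive this singleton: the server-side path runs the pickled UDF and returns the pickled result together with its tensor table, while the client-side path unpickles the response. The wrapper functions and their names below are illustrative assumptions, not part of this file or of the PR.

// Hedged usage sketch: handleUdfRequest / handleUdfResponse are hypothetical
// wrappers, not actual PyTorch RPC agent code.
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>

#include <vector>

namespace {

// Server side: execute the serialized Python UDF and produce a serialized
// result, filling responseTensors with any tensors the UDF returned.
std::vector<char> handleUdfRequest(
    const std::vector<char>& pickledRequest,
    const std::vector<torch::Tensor>& requestTensors,
    std::vector<torch::Tensor>& responseTensors) {
  auto& handler = torch::distributed::rpc::PythonRpcHandler::getInstance();
  return handler.generatePythonUDFResult(
      pickledRequest, requestTensors, responseTensors);
}

// Client side: turn the serialized result back into a live py::object.
py::object handleUdfResponse(
    const std::vector<char>& pickledResponse,
    const std::vector<torch::Tensor>& responseTensors) {
  auto& handler = torch::distributed::rpc::PythonRpcHandler::getInstance();
  return handler.loadPythonUDFResult(pickledResponse, responseTensors);
}

} // namespace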