Fix pybind11 warnings in python_rpc_handler.cpp (#27284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27284

The warnings related to usage of the deprecated != operator. Instead
of checking the member field on every function call, we can check it
once, on construction of PythonRpcHandler.

Test Plan: Imported from OSS

Differential Revision: D17808213

Pulled By: pietern

fbshipit-source-id: 022c8f77f266942c49c55b1729e62dbb06262d77
This commit is contained in:
Pieter Noordhuis
2019-10-08 11:22:18 -07:00
committed by Facebook Github Bot
parent 0d22f3b170
commit c742918854

View File

@ -4,13 +4,27 @@ namespace torch {
namespace distributed {
namespace rpc {
namespace {
py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name);
TORCH_CHECK(
py::isinstance<py::function>(fn),
"attribute ",
name,
" is not a function");
return fn;
}
} // namespace
PythonRpcHandler::PythonRpcHandler() {
AutoGIL ag;
py::object module =
py::module::import("torch.distributed.internal_rpc_utils");
runUDFFunction_ = module.attr("run_python_udf_internal");
loadResultFunction_ = module.attr("load_python_udf_result_internal");
serializeFunction_ = module.attr("serialize");
runUDFFunction_ = getFunction(module, "run_python_udf_internal");
loadResultFunction_ = getFunction(module, "load_python_udf_result_internal");
serializeFunction_ = getFunction(module, "serialize");
}
PythonRpcHandler& PythonRpcHandler::getInstance() {
@ -24,7 +38,6 @@ std::vector<char> PythonRpcHandler::generatePythonUDFResult(
std::vector<torch::Tensor>& responseTensorTable) {
AutoGIL ag;
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
py::tuple pres =
serializeFunction_(runUDFFunction_(pargs, requestTensorTable));
const auto& presStr = pres[0].cast<std::string>();
@ -38,7 +51,6 @@ py::object PythonRpcHandler::loadPythonUDFResult(
const std::vector<torch::Tensor>& tensorTable) {
AutoGIL ag;
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr");
return loadResultFunction_(pargs, tensorTable);
}