mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 06:34:55 +08:00
Fix pybind11 warnings in python_rpc_handler.cpp (#27284)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27284 The warnings related to usage of the deprecated != operator. Instead of checking the member field on every function call, we can check it once, on construction of PythonRpcHandler. Test Plan: Imported from OSS Differential Revision: D17808213 Pulled By: pietern fbshipit-source-id: 022c8f77f266942c49c55b1729e62dbb06262d77
This commit is contained in:
committed by
Facebook Github Bot
parent
0d22f3b170
commit
c742918854
@ -4,13 +4,27 @@ namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
namespace {
|
||||
|
||||
py::object getFunction(const py::object& module, const char* name) {
|
||||
py::object fn = module.attr(name);
|
||||
TORCH_CHECK(
|
||||
py::isinstance<py::function>(fn),
|
||||
"attribute ",
|
||||
name,
|
||||
" is not a function");
|
||||
return fn;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PythonRpcHandler::PythonRpcHandler() {
|
||||
AutoGIL ag;
|
||||
py::object module =
|
||||
py::module::import("torch.distributed.internal_rpc_utils");
|
||||
runUDFFunction_ = module.attr("run_python_udf_internal");
|
||||
loadResultFunction_ = module.attr("load_python_udf_result_internal");
|
||||
serializeFunction_ = module.attr("serialize");
|
||||
runUDFFunction_ = getFunction(module, "run_python_udf_internal");
|
||||
loadResultFunction_ = getFunction(module, "load_python_udf_result_internal");
|
||||
serializeFunction_ = getFunction(module, "serialize");
|
||||
}
|
||||
|
||||
PythonRpcHandler& PythonRpcHandler::getInstance() {
|
||||
@ -24,7 +38,6 @@ std::vector<char> PythonRpcHandler::generatePythonUDFResult(
|
||||
std::vector<torch::Tensor>& responseTensorTable) {
|
||||
AutoGIL ag;
|
||||
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
|
||||
TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
|
||||
py::tuple pres =
|
||||
serializeFunction_(runUDFFunction_(pargs, requestTensorTable));
|
||||
const auto& presStr = pres[0].cast<std::string>();
|
||||
@ -38,7 +51,6 @@ py::object PythonRpcHandler::loadPythonUDFResult(
|
||||
const std::vector<torch::Tensor>& tensorTable) {
|
||||
AutoGIL ag;
|
||||
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
|
||||
TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr");
|
||||
return loadResultFunction_(pargs, tensorTable);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user