pytorch/torch/csrc/distributed/rpc/python_functions.cpp
Shen Li 18ef09f5ac Remove _load_return_value from RPC internal.py (#34492)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34492

Differential Revision: D20347468

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: 92388d0d50a08fb895bacacf94c7b5495b4ae2b6
2020-03-09 20:40:50 -07:00

#include <torch/csrc/distributed/rpc/python_functions.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {

namespace {
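
// Matches the given operator name plus Python args/kwargs against the
// registered ATen overloads for that symbol. On success, fills `stack` with
// the inputs for the matched schema and returns the operator; otherwise
// raises via TORCH_CHECK.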
std::shared_ptr<Operator> matchBuiltinOp(
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    Stack& stack) {
  Symbol symbol = Symbol::fromQualString(opName);
  if (symbol.is_aten()) {
    for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
      try {
        // FIXME: This is a temporary solution. We should at least refactor
        // ``createStackForSchema`` to avoid throwing an error.
        stack = torch::jit::createStackForSchema(
            op->schema(), args, kwargs, c10::nullopt);
      } catch (std::runtime_error& e) {
        VLOG(1) << "Couldn't match schema: " << op->schema()
                << " to args: " << args << " and kwargs: " << kwargs
                << ", reason: " << e.what();
        continue;
      }
      // Found the right op!
      return op;
    }
  }

  TORCH_CHECK(
      false,
      "Failed to match operator name ",
      opName,
      " and arguments "
      "(args: ",
      args,
      ", kwargs: ",
      kwargs,
      ") to a builtin operator");
}
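
// Wraps the serialized Python UDF in a PythonRemoteCall message and sends it
// to `dst` through the current RpcAgent.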
std::shared_ptr<FutureMessage> sendPythonRemoteCall(
    const WorkerInfo& dst,
    SerializedPyObj serializedPyObj,
    const IValue& rrefId,
    const IValue& forkId,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
  auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
      std::move(serializedPyObj), rrefId, forkId);

  // Set forceGradRecording to true: even if the args do not contain any
  // tensors, the return value might still contain tensors.
  auto agent = RpcAgent::getCurrentRpcAgent();
  return torch::distributed::autograd::sendMessageWithAutograd(
      *agent,
      dst,
      std::move(*pythonRemoteCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
}
} // namespace

using namespace torch::distributed::autograd;
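
// Converts a deserialized RPC response into a py::object. SCRIPT_RET
// responses carry an IValue result; PYTHON_RET responses carry a pickled
// Python object, which is deserialized and checked for remote exceptions.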
py::object toPyObjInternal(RpcCommandBase& rpc, MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      Stack stack;
      stack.push_back(ret.value());
      {
        pybind11::gil_scoped_acquire ag;
        // createPyObjectForStack itself does not acquire the GIL, but
        // creating a new py::object requires holding it.
        return torch::jit::createPyObjectForStack(std::move(stack));
      }
    }
    case MessageType::PYTHON_RET: {
      // TODO: Try to avoid a copy here.
      auto& resp = static_cast<PythonResp&>(rpc);
      auto& pythonRpcHandler = PythonRpcHandler::getInstance();
      py::object ret = pythonRpcHandler.deserialize(resp.serializedPyObj());
      pythonRpcHandler.handleException(ret);
      return ret;
    }
    default: {
      TORCH_CHECK(false, "Unrecognized response message type ", messageType);
    }
  }
}
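
// Deserializes a response Message and converts its payload into a
// py::object.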
py::object toPyObj(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return toPyObjInternal(*response, msgType);
}
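
// Invokes a builtin (ATen) operator on `dst` via RPC and returns a
// FutureMessage for the response.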
std::shared_ptr<FutureMessage> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs) {
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
  auto agent = RpcAgent::getCurrentRpcAgent();
  return sendMessageWithAutograd(
      *agent, dst, std::move(*scriptCall).toMessage(), false, rf);
}
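
// The remote() code path for builtin operators: creates a UserRRef on this
// worker that refers to the result of running the operator on `dst`, and
// keeps the fork pending until the owner confirms it.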
PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs) {
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  TypePtr returnType = op->schema().returns()[0].type();

  auto& ctx = RRefContext::getInstance();
  // TODO: support creating RRefs on a local object.
  TORCH_INTERNAL_ASSERT(
      ctx.getWorkerId() != dst.id_,
      "Does not support creating RRef on self yet.");
  auto userRRef = ctx.createUserRRef(dst.id_, returnType);

  auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
      op, std::move(stack), userRRef->rrefId(), userRRef->forkId());
  auto agent = RpcAgent::getCurrentRpcAgent();
  auto fm = sendMessageWithAutograd(
      *agent, dst, std::move(*scriptRemoteCall).toMessage(), false, rf);

  ctx.addPendingUser(userRRef->forkId(), userRRef);
  fm->addCallback(callback::confirmPendingUser);
  return PyRRef(userRRef);
}
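
// Sends a pickled Python UDF (and the tensors it references) to `dst` for
// execution, forcing grad recording since the return value may contain
// tensors even when the arguments do not.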
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
  auto pythonCall = std::make_unique<PythonCall>(std::move(serializedPyObj));

  auto agent = RpcAgent::getCurrentRpcAgent();
  return sendMessageWithAutograd(
      *agent,
      dst,
      std::move(*pythonCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
}
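
// The remote() code path for a pickled Python UDF. When `dst` is a different
// worker, creates a UserRRef and registers it as a pending fork; when `dst`
// is this worker, creates an OwnerRRef instead (note rrefId == forkId in the
// message below) and adds a self fork so the owner is not deleted early.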
PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
  auto& ctx = RRefContext::getInstance();
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
  if (ctx.getWorkerId() != dst.id_) {
    auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    auto fm = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        userRRef->rrefId().toIValue(),
        userRRef->forkId().toIValue(),
        rf);

    fm->addCallback(callback::confirmPendingUser);
    return PyRRef(userRRef);
  } else {
    auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
    // Prevent this owner RRef from being deleted due to other forks.
    ctx.addSelfAsFork(ownerRRef);
    auto fm = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        ownerRRef->rrefId().toIValue(),
        ownerRRef->rrefId().toIValue(),
        rf);

    fm->addCallback([](const Message& message,
                       const c10::optional<utils::FutureError>& futErr) {
      auto deletedRRef = callback::finishCreatingOwnerRRef(message, futErr);
      if (deletedRRef && deletedRRef->isPyObj()) {
        // Acquire the GIL before destroying the Python object held by the
        // deleted RRef.
        pybind11::gil_scoped_acquire ag;
        deletedRRef.reset();
      }
    });
    return PyRRef(ownerRRef);
  }
}
} // namespace rpc
} // namespace distributed
} // namespace torch