Files
pytorch/torch/csrc/distributed/rpc/python_remote_call.cpp
Scott Wolchok 82f7f8d471 [PyTorch] Adopt IValue::toTupleRef() where obvious (#65505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65505

Generated with

`fastmod -m 'toTuple\(\)(\s*)->' 'toTupleRef()${1}.'`

followed by

`fastmod '(std::move\(.*)toTupleRef\(\).' '${1}toTuple()->'`

to unbreak 2 callsites.
ghstack-source-id: 142065835

Test Plan: CI

Reviewed By: gchanan

Differential Revision: D31131025

fbshipit-source-id: 54457ae5bbeb38db9c7f196d469b98521c3d3f34
2021-11-02 10:22:18 -07:00

70 lines
2.2 KiB
C++

#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace rpc {
// Builds a remote-call command for a pickled Python UDF.
//
// `serializedPyObj` is taken by rvalue reference and moved into the command.
// `retRRefId` and `retForkId` are kept as opaque IValues. `isAsyncExecution`
// is stored unchanged. toMessageImpl() appends all three, in this order, after
// the serialized Python object's IValues.
PythonRemoteCall::PythonRemoteCall(
    SerializedPyObj&& serializedPyObj,
    at::IValue retRRefId,
    at::IValue retForkId,
    const bool isAsyncExecution)
    // Parentheses (not braces) on purpose: brace-initializing an IValue from
    // another IValue could be resolved to an initializer_list constructor.
    : serializedPyObj_(std::move(serializedPyObj)),
      retRRefId_(std::move(retRRefId)),
      retForkId_(std::move(retForkId)),
      isAsyncExecution_(isAsyncExecution) {
  // All state is captured by the member initializers above.
}
// Serializes this command into a PYTHON_REMOTE_CALL message.
//
// The pickled payload is a single tuple laid out as:
//   [ <IValues from SerializedPyObj::toIValues()>...,
//     retRRefId, retForkId, isAsyncExecution ]
// fromMessage() relies on exactly this trailing layout: it pops the last three
// elements off the end. Tensors encountered while pickling are collected into
// the message's tensor table instead of being embedded in the payload.
//
// This member is rvalue-qualified, so it consumes *this. Member state is moved
// into the payload rather than copied. After this call the command is in a
// moved-from state and must not be reused.
c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {
  std::vector<IValue> ivalues = std::move(serializedPyObj_).toIValues();
  // Room for the three trailing fields appended below.
  ivalues.reserve(ivalues.size() + 3);
  // *this is being consumed, so move the ids instead of copying them.
  ivalues.emplace_back(std::move(retRRefId_));
  ivalues.emplace_back(std::move(retForkId_));
  ivalues.emplace_back(isAsyncExecution_);

  std::vector<torch::Tensor> tensor_table;
  // Move the vector into the tuple. Tuple::create takes its elements by
  // value, so passing an lvalue here would copy every IValue.
  auto payload = jit::pickle(
      c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
  return c10::make_intrusive<Message>(
      std::move(payload),
      std::move(tensor_table),
      MessageType::PYTHON_REMOTE_CALL);
}
std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
auto values = value.toTupleRef().elements().vec();
// remove the last elements from values and convert it back to an RRef
TORCH_INTERNAL_ASSERT(
values.size() >= 3,
"Expect at least 3 elements in the unpickled values, but got ",
values.size());
bool isAsyncExecution = values.back().toBool();
values.pop_back();
auto retForkId = std::move(values.back());
values.pop_back();
auto retRRefId = std::move(values.back());
values.pop_back();
auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
return std::make_unique<PythonRemoteCall>(
std::move(serializedPyObj),
std::move(retRRefId),
std::move(retForkId),
isAsyncExecution);
}
} // namespace rpc
} // namespace distributed
} // namespace torch