Files
pytorch/torch/csrc/distributed/rpc/script_remote_call.cpp
Scott Wolchok 82f7f8d471 [PyTorch] Adopt IValue::toTupleRef() where obvious (#65505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65505

Generated with

`fastmod -m 'toTuple\(\)(\s*)->' 'toTupleRef()${1}.'`

, followed by

`fastmod '(std::move\(.*)toTupleRef\(\).' '${1}toTuple()->'`

to unbreak 2 callsites.
ghstack-source-id: 142065835

Test Plan: CI

Reviewed By: gchanan

Differential Revision: D31131025

fbshipit-source-id: 54457ae5bbeb38db9c7f196d469b98521c3d3f34
2021-11-02 10:22:18 -07:00

86 lines
2.6 KiB
C++

#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <c10/util/C++17.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace rpc {
ScriptRemoteCall::ScriptRemoteCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId)
: ScriptCall(std::move(op), std::move(stack)),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
ScriptRemoteCall::ScriptRemoteCall(
const c10::QualifiedName& qualifiedName,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId,
const bool isAsyncExecution)
: ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
// Rebuilds a ScriptRemoteCall from a flat list of IValues.
//
// The last two entries of `ivalues` are consumed here: the fork id is taken
// from the back first, then the RRef id. The remaining entries are passed to
// ScriptCall::fromIValues() and afterwards handed to the constructor as the
// call's stack.
//
// `ivalues` is modified in place (two pop_back() calls) and is passed on as an
// rvalue, so the caller must not rely on its contents afterwards.
// NOTE(review): `ivalues` is presumably expected to hold at least the two
// trailing ids; calling back() on an empty vector is undefined behavior and
// no size check is performed here — confirm callers guarantee this.
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
    std::vector<at::IValue>& ivalues) {
  // toMessageImpl() appends the RRef id and then the fork id, so they come
  // back off the end in the opposite order. Each id is decoded through the
  // alias that matches the variable it is stored in.
  auto retForkId = ForkId::fromIValue(ivalues.back());
  ivalues.pop_back();
  auto retRRefId = RRefId::fromIValue(ivalues.back());
  ivalues.pop_back();

  // Let the base class interpret what is left of the list.
  auto scriptCallPtr = ScriptCall::fromIValues(ivalues);

  if (scriptCallPtr->hasOp()) {
    // Operator-based call.
    return std::make_unique<ScriptRemoteCall>(
        scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
  }

  // Otherwise the call is identified by its qualified name.
  return std::make_unique<ScriptRemoteCall>(
      scriptCallPtr->qualifiedName(),
      std::move(ivalues),
      retRRefId,
      retForkId,
      scriptCallPtr->isAsyncExecution());
}
// Serializes this call into a Message of type SCRIPT_REMOTE_CALL.
//
// The base-class fields are written first via ScriptCall::toIValues(), then
// the RRef id and the fork id are appended. The whole list is wrapped in a
// tuple and pickled; tensors collected during pickling travel alongside the
// payload in the returned message.
c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
  std::vector<IValue> fields;
  ScriptCall::toIValues(fields);

  // Keep this order: fromIValues() reads the fork id from the back first and
  // the RRef id second.
  fields.push_back(retRRefId_.toIValue());
  fields.push_back(retForkId_.toIValue());

  std::vector<torch::Tensor> tensors;
  auto bytes =
      jit::pickle(c10::ivalue::Tuple::create(std::move(fields)), &tensors);

  return c10::make_intrusive<Message>(
      std::move(bytes), std::move(tensors), MessageType::SCRIPT_REMOTE_CALL);
}
// Reconstructs a ScriptRemoteCall from a received Message.
//
// The message payload is unpickled using the type resolver of the current
// RpcAgent together with the message's tensors. The result is read as a
// tuple, its elements are copied into a vector, and that vector is passed to
// fromIValues().
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
    const Message& message) {
  const auto& bytes = message.payload();
  auto unpickled = jit::unpickle(
      static_cast<const char*>(bytes.data()),
      bytes.size(),
      *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      message.tensors());

  // fromIValues() takes a mutable vector and pops entries off it, so work on
  // a copy of the tuple's elements.
  auto fields = unpickled.toTupleRef().elements().vec();
  return fromIValues(fields);
}
} // namespace rpc
} // namespace distributed
} // namespace torch