mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30915 Since we now have C++14, we don't need these c10::guts helpers anymore ghstack-source-id: 95777609 Test Plan: waitforsandcastle Differential Revision: D18869639 fbshipit-source-id: 97716f932297c64c6e814410ac47b444c33d4e2e
54 lines
1.6 KiB
C++
54 lines
1.6 KiB
C++
#include <torch/csrc/distributed/rpc/python_remote_call.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/pickle.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
PythonRemoteCall::PythonRemoteCall(
|
|
SerializedPyObj&& serializedPyObj,
|
|
at::IValue retRRefId,
|
|
at::IValue retForkId)
|
|
: serializedPyObj_(std::move(serializedPyObj)),
|
|
retRRefId_(std::move(retRRefId)),
|
|
retForkId_(std::move(retForkId)) {}
|
|
|
|
Message PythonRemoteCall::toMessage() && {
|
|
std::vector<IValue> ivalues = serializedPyObj_.toIValues();
|
|
ivalues.emplace_back(retRRefId_);
|
|
ivalues.emplace_back(retForkId_);
|
|
|
|
std::vector<torch::Tensor> tensor_table;
|
|
auto payload =
|
|
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
|
|
|
|
return Message(
|
|
std::move(payload),
|
|
std::move(tensor_table),
|
|
MessageType::PYTHON_REMOTE_CALL);
|
|
}
|
|
|
|
std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
|
|
const Message& message) {
|
|
auto payload = static_cast<const char*>(message.payload().data());
|
|
auto payload_size = message.payload().size();
|
|
|
|
auto value =
|
|
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
|
|
auto values = value.toTuple()->elements();
|
|
|
|
// remove the last element from values and convert it back to an RRef
|
|
auto retForkId = std::move(values.back());
|
|
values.pop_back();
|
|
auto retRRefId = std::move(values.back());
|
|
values.pop_back();
|
|
auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
|
|
|
|
return std::make_unique<PythonRemoteCall>(
|
|
std::move(serializedPyObj), std::move(retRRefId), std::move(retForkId));
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|