Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 22:25:10 +08:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23281
Test Plan: Imported from OSS
Differential Revision: D16452815
Pulled By: zdevito
fbshipit-source-id: 918eef3ad444b598ab655c39037e4baafdcb51e1
39 lines · 973 B · C++
#include <torch/csrc/distributed/rpc/script_ret.h>
|
|
#include <torch/csrc/jit/pickle.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
namespace {
|
|
|
|
using torch::jit::Pickler;
|
|
using torch::jit::Unpickler;
|
|
|
|
} // namespace
|
|
|
|
// Wraps `value`, taking ownership of it.
// Inside the constructor, `value` is a named rvalue reference and is
// therefore an lvalue. Without std::move, value_ would be copy-constructed
// and the caller's move would be silently turned into a copy.
ScriptRet::ScriptRet(at::IValue&& value) : value_(std::move(value)) {}
|
|
|
|
// Read-only accessor for the wrapped IValue.
// The returned reference aliases this object's member, so it is only valid
// while this ScriptRet is alive.
const at::IValue& ScriptRet::value() { return this->value_; }
|
|
|
|
// Serializes the held IValue into an RPC Message of type SCRIPT_RET.
//
// jit::pickle receives a pointer to `tensor_table`. That table is then
// attached to the Message alongside the pickled byte payload (presumably
// pickle fills it with the tensors referenced by value_ -- confirm against
// jit::pickle). Both buffers are moved into the Message to avoid copies.
//
// @return a Message carrying the pickled payload, the tensor table, and
//         MessageType::SCRIPT_RET.
Message ScriptRet::toMessage() {
  std::vector<torch::Tensor> tensor_table;
  auto payload = jit::pickle(value_, &tensor_table);
  return Message(
      std::move(payload), std::move(tensor_table), MessageType::SCRIPT_RET);
}
|
|
|
|
// Rebuilds a ScriptRet from a received Message, intended as the counterpart
// of toMessage().
//
// The message's payload bytes are decoded with jit::unpickle, and a pointer
// to the message's attached tensors is passed along so that the unpickler
// can resolve tensor references. This function does not inspect the
// message's type.
//
// @param message the Message to decode; it is only read.
// @return a ScriptRet owning the unpickled IValue.
ScriptRet ScriptRet::fromMessage(const Message& message) {
  const auto& bytes = message.payload();
  return ScriptRet(jit::unpickle(
      static_cast<const char*>(bytes.data()),
      bytes.size(),
      nullptr,
      &message.tensors()));
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|