mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Pickler C++ API (#23241)
Summary: This PR adds functions to wrap the Pickler and exposes them to the C++ API ](https://our.intern.facebook.com/intern/diff/16675418/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/23241 Pulled By: driazati Differential Revision: D16675418 fbshipit-source-id: 76543c81ac67c3e20a75ebc2073191bcbd6573bf
This commit is contained in:
committed by
Facebook Github Bot
parent
e81f296807
commit
01d98c7cfb
@ -371,6 +371,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/graph_executor.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/import_source.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/import.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/pickle.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/import_export_helpers.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/jit.h>
|
||||
#include <torch/script.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <string>
|
||||
@ -110,3 +111,18 @@ TEST(TorchScriptTest, TestOptionalArgMatching) {
|
||||
0, module->run_method("optional_tuple_op", torch::jit::IValue()).toInt());
|
||||
|
||||
}
|
||||
|
||||
TEST(TorchScriptTest, TestPickle) {
|
||||
torch::IValue float_value(2.3);
|
||||
|
||||
// TODO: when tensors are stored in the pickle, delete this
|
||||
std::vector<at::Tensor> tensor_table;
|
||||
auto data = torch::jit::pickle(float_value, &tensor_table);
|
||||
|
||||
std::vector<torch::IValue> ivalues =
|
||||
torch::jit::unpickle(data.data(), data.size());
|
||||
|
||||
double diff = ivalues.at(0).toDouble() - float_value.toDouble();
|
||||
double eps = 0.0001;
|
||||
ASSERT_TRUE(diff < eps && diff > -eps);
|
||||
}
|
||||
|
||||
@ -64,6 +64,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/pickler.cpp",
|
||||
"torch/csrc/jit/graph_executor.cpp",
|
||||
"torch/csrc/jit/import.cpp",
|
||||
"torch/csrc/jit/pickle.cpp",
|
||||
"torch/csrc/jit/import_export_helpers.cpp",
|
||||
"torch/csrc/jit/interpreter.cpp",
|
||||
"torch/csrc/jit/ir.cpp",
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#include <torch/csrc/jit/import_export_helpers.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/python_print.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
|
||||
#include <caffe2/core/types.h>
|
||||
@ -784,16 +784,8 @@ void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
|
||||
void ScriptModuleSerializer::writePickleArchive(
|
||||
const std::string& name,
|
||||
const std::vector<IValue>& ivalues) {
|
||||
Pickler pickler(&tensor_table_);
|
||||
pickler.protocol();
|
||||
pickler.startTuple();
|
||||
for (const IValue& ivalue : ivalues) {
|
||||
pickler.pushIValue(ivalue);
|
||||
}
|
||||
pickler.endTuple();
|
||||
pickler.stop();
|
||||
writer_.writeRecord(name, pickler.stack().data(), pickler.stack().size(),
|
||||
/*compress=*/true);
|
||||
auto data = pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table_);
|
||||
writer_.writeRecord(name, data.data(), data.size(), /*compress=*/true);
|
||||
}
|
||||
|
||||
void ScriptModuleSerializer::convertModule(
|
||||
@ -878,8 +870,7 @@ void ScriptModuleSerializer::convertModule(
|
||||
module_def->mutable_torchscript_debug_arena();
|
||||
|
||||
SourceRangePickler source_range_pickler;
|
||||
source_range_pickler.pickle(source_ranges);
|
||||
const auto& range_data = source_range_pickler.get_data();
|
||||
const auto& range_data = source_range_pickler.pickle(source_ranges);
|
||||
std::stringstream debug_filename;
|
||||
debug_filename << "debug/" << module_name.str() << ".pkl";
|
||||
writer_.writeRecord(
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
#include <torch/csrc/jit/import_export_helpers.h>
|
||||
#include <torch/csrc/jit/import_source.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
#include <torch/csrc/jit/script/script_type_parser.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
#include <torch/csrc/jit/source_range_serialization_impl.h>
|
||||
@ -75,7 +75,7 @@ class ScriptModuleDeserializer final {
|
||||
script::Module convertModule(const torch::ModuleDef& module_def);
|
||||
|
||||
void loadTensorTable(torch::ModelDef* model_def);
|
||||
std::vector<IValue> loadPickleArchive(const std::string& name);
|
||||
IValue loadPickleArchive(const std::string& name);
|
||||
void importCallback(const std::string& qualifier);
|
||||
void moduleSetState(const script::Module& module, IValue state);
|
||||
|
||||
@ -142,8 +142,12 @@ script::Module ScriptModuleDeserializer::deserialize(
|
||||
}
|
||||
|
||||
loadTensorTable(&model_def);
|
||||
if (model_def.proto_version() >= 2) {
|
||||
pickled_ivalues_ = loadPickleArchive("attributes.pkl");
|
||||
if (model_def.proto_version() == 2) {
|
||||
auto list = loadPickleArchive("attributes.pkl").toGenericList();
|
||||
pickled_ivalues_.insert(pickled_ivalues_.end(), list.begin(), list.end());
|
||||
} else if (model_def.proto_version() >= 3) {
|
||||
pickled_ivalues_ =
|
||||
loadPickleArchive("attributes.pkl").toTuple()->elements();
|
||||
}
|
||||
|
||||
return convertModule(module_def);
|
||||
@ -156,12 +160,12 @@ void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
|
||||
IValue ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
|
||||
at::DataPtr attributes_ptr;
|
||||
size_t attributes_size;
|
||||
std::tie(attributes_ptr, attributes_size) = reader_->getRecord(name);
|
||||
Unpickler unpickler(
|
||||
attributes_ptr.get(),
|
||||
auto ivalue = unpickle(
|
||||
reinterpret_cast<const char*>(attributes_ptr.get()),
|
||||
attributes_size,
|
||||
&tensor_table_,
|
||||
[&](const c10::QualifiedName& qn) {
|
||||
@ -169,7 +173,7 @@ std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::strin
|
||||
return c10::StrongTypePtr(
|
||||
compilation_unit_, compilation_unit_->get_class(qn));
|
||||
});
|
||||
return unpickler.parse_ivalue_list();
|
||||
return ivalue;
|
||||
}
|
||||
|
||||
at::Tensor ScriptModuleDeserializer::loadTensor(
|
||||
|
||||
82
torch/csrc/jit/pickle.cpp
Normal file
82
torch/csrc/jit/pickle.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void pickle(
|
||||
std::function<void(const char*, size_t)> writer,
|
||||
const IValue& ivalue,
|
||||
std::vector<at::Tensor>* tensor_table) {
|
||||
Pickler pickler(std::move(writer), tensor_table);
|
||||
|
||||
if (tensor_table == nullptr) {
|
||||
// No tensor table provided, so tensors will be stored directly in the blob.
|
||||
// Add torch.save metadata so these tensors can be de-serialized later
|
||||
pickler.torchSaveStart();
|
||||
}
|
||||
|
||||
pickler.protocol();
|
||||
pickler.pushIValue(ivalue);
|
||||
pickler.stop();
|
||||
|
||||
if (tensor_table == nullptr) {
|
||||
// No tensor table provided, so tensors will be stored directly in the blob.
|
||||
// Add torch.save metadata so these tensors can be de-serialized later
|
||||
pickler.torchSaveStop();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char> pickle(
|
||||
const IValue& ivalue,
|
||||
std::vector<at::Tensor>* tensor_table) {
|
||||
std::vector<char> data;
|
||||
|
||||
pickle(
|
||||
[&](const char* bytes, size_t len) {
|
||||
data.insert(data.end(), bytes, bytes + len);
|
||||
},
|
||||
ivalue,
|
||||
tensor_table);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
IValue unpickle(
|
||||
std::function<void(char*, size_t)> reader,
|
||||
std::function<bool()> bounds_checker,
|
||||
std::vector<at::Tensor>* tensor_table,
|
||||
ClassResolver class_resolver) {
|
||||
Unpickler unpickler(
|
||||
std::move(reader),
|
||||
std::move(bounds_checker),
|
||||
tensor_table,
|
||||
std::move(class_resolver));
|
||||
return unpickler.parse_ivalue();
|
||||
}
|
||||
|
||||
IValue unpickle(
|
||||
const char* data,
|
||||
size_t size,
|
||||
std::vector<at::Tensor>* tensor_table,
|
||||
ClassResolver class_resolver) {
|
||||
size_t bytes_read = 0;
|
||||
return unpickle(
|
||||
[&](char* buffer, size_t len) {
|
||||
// Copy len bytes into buffer
|
||||
const char* start = data + bytes_read;
|
||||
std::memcpy(buffer, start, len);
|
||||
bytes_read += len;
|
||||
},
|
||||
[&]() {
|
||||
return bytes_read < size;
|
||||
},
|
||||
tensor_table,
|
||||
std::move(class_resolver));
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
79
torch/csrc/jit/pickle.h
Normal file
79
torch/csrc/jit/pickle.h
Normal file
@ -0,0 +1,79 @@
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
/// Save a `torch::IValue` in a format compatible with Python's `pickle` module
|
||||
///
|
||||
/// If present, `tensor_table` is a pointer to a table in which tensors that
|
||||
/// are contained within `ivalue` are stored, and the bytes returned by the
|
||||
/// pickler will only include references to these tensors in the table. This can
|
||||
/// be used to keep the binary blob size small.
|
||||
/// If not provided, tensors are stored in the same byte stream as the pickle
|
||||
/// data, similar to `torch.save()` in eager Python.
|
||||
///
|
||||
/// Pickled values can be loaded in Python and C++:
|
||||
/// \rst
|
||||
/// .. code-block:: cpp
|
||||
///
|
||||
/// torch::IValue float_value(2.3);
|
||||
///
|
||||
/// // TODO: when tensors are stored in the pickle, delete this
|
||||
/// std::vector<at::Tensor> tensor_table;
|
||||
/// auto data = torch::jit::pickle(float_value, &tensor_table);
|
||||
///
|
||||
/// std::vector<torch::IValue> ivalues =
|
||||
/// torch::jit::unpickle(data.data(), data.size());
|
||||
///
|
||||
/// .. code-block:: python
|
||||
///
|
||||
/// values = torch.load('data.pkl')
|
||||
/// print(values)
|
||||
///
|
||||
/// \endrst
|
||||
TORCH_API std::vector<char> pickle(
|
||||
const IValue& ivalue,
|
||||
std::vector<at::Tensor>* tensor_table = nullptr);
|
||||
|
||||
/// Pickle an IValue by calling a function to handle writing the data.
|
||||
///
|
||||
/// `writer` is a function that takes in a pointer to a chunk of memory and its
|
||||
/// size and consumes it.
|
||||
///
|
||||
/// See `jit::pickle` for more details.
|
||||
TORCH_API void pickle(
|
||||
std::function<void(const char* data_start, size_t data_len)> writer,
|
||||
const IValue& ivalue,
|
||||
std::vector<at::Tensor>* tensor_table = nullptr);
|
||||
|
||||
/// `reader` is a function that takes in a size to read from some pickled
|
||||
/// binary. `reader` should remember where it last read.
|
||||
///
|
||||
/// `bounds_checker` is a function that returns `true` if the reader can read
|
||||
/// more data, and `false` if it cannot (i.e. if a stream has hit its end of
|
||||
/// file)
|
||||
///
|
||||
/// See `torch::pickle` for details.
|
||||
TORCH_API IValue unpickle(
|
||||
std::function<const char*(size_t)> reader,
|
||||
std::function<bool()> bounds_chcker,
|
||||
std::vector<at::Tensor>* tensor_table = nullptr,
|
||||
ClassResolver class_resolver = nullptr);
|
||||
|
||||
/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
|
||||
///
|
||||
/// If any `torch::IValue`s in the pickled data are `Object`s, then a
|
||||
/// `class_resolver` function must be provided.
|
||||
///
|
||||
/// See `torch::pickle` for details.
|
||||
TORCH_API IValue unpickle(
|
||||
const char* data,
|
||||
size_t size,
|
||||
std::vector<at::Tensor>* tensor_table = nullptr,
|
||||
ClassResolver class_resolver = nullptr);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
@ -52,10 +52,6 @@ const char* getClassName(PicklerClass cls) {
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<char>& Pickler::stack() {
|
||||
return stack_;
|
||||
}
|
||||
|
||||
void Pickler::protocol() {
|
||||
push<OpCode>(OpCode::PROTO);
|
||||
push<uint8_t>(PROTOCOL_VERSION);
|
||||
@ -89,6 +85,7 @@ void Pickler::torchSaveStop() {
|
||||
push<uint32_t>(key.size());
|
||||
pushBytes(key);
|
||||
}
|
||||
|
||||
push<OpCode>(OpCode::TUPLE);
|
||||
stop();
|
||||
|
||||
@ -96,7 +93,7 @@ void Pickler::torchSaveStop() {
|
||||
for (const auto& data : tensor_data_) {
|
||||
// first dump size
|
||||
push<size_t>(data.numel());
|
||||
stack_.insert(stack_.end(), data.data(), data.data() + data.sizeInBytes());
|
||||
writer_(data.data(), data.sizeInBytes());
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,7 +301,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
|
||||
}
|
||||
|
||||
void Pickler::pushBytes(const std::string& string) {
|
||||
stack_.insert(stack_.end(), string.begin(), string.end());
|
||||
writer_(string.data(), string.size());
|
||||
}
|
||||
|
||||
void Pickler::pushGlobal(
|
||||
@ -492,49 +489,45 @@ void Pickler::pushTuple(const IValue& ivalue) {
|
||||
push<OpCode>(OpCode::TUPLE);
|
||||
}
|
||||
|
||||
std::vector<IValue> Unpickler::parse_ivalue_list() {
|
||||
IValue Unpickler::parse_ivalue() {
|
||||
run();
|
||||
TORCH_CHECK(
|
||||
stack_.size() == 1,
|
||||
"Unpickler expected 1 element on the stack, but found ",
|
||||
stack_.size());
|
||||
|
||||
auto value = stack_[0];
|
||||
if (value.isGenericList()) {
|
||||
// TODO [unpickler refactor]
|
||||
return value.toGenericListRef().vec();
|
||||
}
|
||||
return value.toTuple()->elements();
|
||||
return stack_[0];
|
||||
}
|
||||
|
||||
double Unpickler::readFloat() {
|
||||
AT_ASSERT(sizeof(double) == 8);
|
||||
AT_ASSERT(bytes_ + 8 < end_ptr_);
|
||||
double result;
|
||||
double big_endian = read<double>();
|
||||
double little_endian;
|
||||
|
||||
// Pickle floats are big endian, so reverse the bytes
|
||||
auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
|
||||
std::reverse_copy(
|
||||
reinterpret_cast<const char*>(bytes_),
|
||||
reinterpret_cast<const char*>(bytes_ + 8),
|
||||
reinterpret_cast<char*>(&result));
|
||||
big_endian_ptr,
|
||||
big_endian_ptr + sizeof(big_endian),
|
||||
reinterpret_cast<char*>(&little_endian));
|
||||
|
||||
bytes_ += 8;
|
||||
return result;
|
||||
return little_endian;
|
||||
}
|
||||
|
||||
void Unpickler::run() {
|
||||
// Expect a PROTO opcode and protocol number at the start of blob
|
||||
auto opcode = readOpCode();
|
||||
TORCH_CHECK(
|
||||
readOpCode() == OpCode::PROTO,
|
||||
opcode == OpCode::PROTO,
|
||||
"Expected PROTO opcode at the start"
|
||||
" of pickle archive");
|
||||
" of pickle archive, found ", int(static_cast<uint8_t>(opcode)));
|
||||
uint8_t protocol = read<uint8_t>();
|
||||
TORCH_CHECK(
|
||||
protocol == 2,
|
||||
"Only Pickle protocol 2 is supported, found protocol = ",
|
||||
protocol);
|
||||
|
||||
while (bytes_ < end_ptr_) {
|
||||
while (bounds_checker_()) {
|
||||
OpCode opcode = readInstruction();
|
||||
if (opcode == OpCode::STOP) {
|
||||
return;
|
||||
@ -627,15 +620,12 @@ OpCode Unpickler::readInstruction() {
|
||||
case OpCode::LONG1: {
|
||||
// Only read LONG1s with 8 as the length
|
||||
uint8_t length = read<uint8_t>();
|
||||
AT_ASSERT(length == 8);
|
||||
TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
|
||||
stack_.emplace_back(int64_t(read<int64_t>()));
|
||||
} break;
|
||||
case OpCode::BINUNICODE: {
|
||||
uint32_t length = read<uint32_t>();
|
||||
const char* characters = reinterpret_cast<const char*>(bytes_);
|
||||
AT_ASSERT(bytes_ + length < end_ptr_);
|
||||
bytes_ += length;
|
||||
stack_.emplace_back(std::string(characters, /*n=*/length));
|
||||
stack_.emplace_back(readBytes(length));
|
||||
} break;
|
||||
case OpCode::BINFLOAT:
|
||||
stack_.emplace_back(readFloat());
|
||||
@ -705,6 +695,10 @@ OpCode Unpickler::readInstruction() {
|
||||
stack_.pop_back();
|
||||
switch (pickler_class) {
|
||||
case PicklerClass::TENSOR:
|
||||
TORCH_CHECK(
|
||||
tensor_table_,
|
||||
"Found a tensor table reference but Pickler"
|
||||
" has no tensor table\n");
|
||||
stack_.emplace_back(tensor_table_->at(data.toInt()));
|
||||
break;
|
||||
case PicklerClass::INTLIST:
|
||||
@ -772,13 +766,21 @@ OpCode Unpickler::readInstruction() {
|
||||
"Unknown opcode for unpickling at ",
|
||||
reinterpret_cast<void*>(opcode),
|
||||
": ",
|
||||
static_cast<uint8_t>(opcode));
|
||||
int(static_cast<uint8_t>(opcode)));
|
||||
}
|
||||
return opcode;
|
||||
}
|
||||
|
||||
// Pop all the list items off of the stack and append them to the list at the
|
||||
// corresponding MARK
|
||||
// Read a number of bytes from the input stream
|
||||
std::string Unpickler::readBytes(size_t length) {
|
||||
std::string data(length, 0);
|
||||
// This is fine since C++11 has contiguous strings
|
||||
reader_(&data[0], length);
|
||||
return data;
|
||||
}
|
||||
|
||||
// Pop all the list items off of the stack and append them to the list at
|
||||
// the corresponding MARK
|
||||
void Unpickler::readList() {
|
||||
size_t start = marks_.back();
|
||||
marks_.pop_back();
|
||||
@ -829,33 +831,24 @@ inline bool is_valid_python_id_char(char c) {
|
||||
|
||||
// Read a newline terminated string
|
||||
std::string Unpickler::readString() {
|
||||
const char* chars = reinterpret_cast<const char*>(bytes_);
|
||||
const char* char_end_ptr = reinterpret_cast<const char*>(end_ptr_);
|
||||
size_t n = 0;
|
||||
std::stringstream ss;
|
||||
while (true) {
|
||||
char c = chars[n];
|
||||
char c = read<char>();
|
||||
if (c == '\n') {
|
||||
break;
|
||||
}
|
||||
|
||||
ss << c;
|
||||
|
||||
// Simple check just in case there is no terminating '\n'
|
||||
TORCH_CHECK(
|
||||
is_valid_python_id_char(c),
|
||||
"Found character '",
|
||||
uint8_t(c),
|
||||
"' in string, "
|
||||
int(uint8_t(c)),
|
||||
"' in string, ",
|
||||
"strings must be qualified Python identifiers");
|
||||
|
||||
// Increment after to exclude newline from string
|
||||
++n;
|
||||
TORCH_CHECK(
|
||||
chars + n < char_end_ptr,
|
||||
"Unpickler overran buffer while reading a string (expected a newline)");
|
||||
}
|
||||
|
||||
// Increment by string length + newline char
|
||||
bytes_ += n + 1;
|
||||
return std::string(chars, n);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
OpCode Unpickler::readOpCode() {
|
||||
|
||||
@ -11,6 +11,9 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using ClassResolver =
|
||||
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
|
||||
|
||||
// See Python's pickletools.py for a detailed description of each of these codes
|
||||
enum class OpCode : char {
|
||||
MARK = '(',
|
||||
@ -122,10 +125,10 @@ class Pickler {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
|
||||
|
||||
public:
|
||||
Pickler(std::vector<at::Tensor>* tensor_table = nullptr)
|
||||
: tensor_table_(tensor_table) {}
|
||||
|
||||
const std::vector<char>& stack();
|
||||
Pickler(
|
||||
std::function<void(const char*, size_t)> writer,
|
||||
std::vector<at::Tensor>* tensor_table = nullptr)
|
||||
: writer_(writer), tensor_table_(tensor_table) {}
|
||||
|
||||
// Push protocol onto the stack
|
||||
void protocol();
|
||||
@ -186,9 +189,12 @@ class Pickler {
|
||||
template <typename T>
|
||||
void push(typename std::common_type<T>::type value) {
|
||||
const char* begin = reinterpret_cast<const char*>(&value);
|
||||
stack_.insert(stack_.end(), begin, begin + sizeof(T));
|
||||
writer_(begin, sizeof(T));
|
||||
}
|
||||
|
||||
// Stream to write binary data to
|
||||
std::function<void(const char*, size_t)> writer_;
|
||||
|
||||
// Stack of opcodes/data
|
||||
std::vector<char> stack_;
|
||||
|
||||
@ -228,32 +234,29 @@ class Unpickler {
|
||||
|
||||
public:
|
||||
Unpickler(
|
||||
const void* data,
|
||||
size_t size,
|
||||
std::function<void(char*, size_t)> reader,
|
||||
std::function<bool()> bounds_checker,
|
||||
const std::vector<at::Tensor>* tensor_table,
|
||||
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>
|
||||
class_resolver)
|
||||
: bytes_(static_cast<const uint8_t*>(data)),
|
||||
end_ptr_(bytes_ + size),
|
||||
ClassResolver class_resolver)
|
||||
: reader_(reader),
|
||||
bounds_checker_(bounds_checker),
|
||||
tensor_table_(tensor_table),
|
||||
class_resolver_(class_resolver) {}
|
||||
class_resolver_(std::move(class_resolver)) {}
|
||||
|
||||
std::vector<IValue> parse_ivalue_list();
|
||||
IValue parse_ivalue();
|
||||
|
||||
private:
|
||||
// No arguments ensures that a template arugment must be specified
|
||||
// so that the number of bytes read / type read is explicit
|
||||
template <typename T>
|
||||
T read() {
|
||||
TORCH_CHECK(
|
||||
bytes_ + sizeof(T) <= end_ptr_,
|
||||
"Unpickler overran buffer while reading a value");
|
||||
T item;
|
||||
std::memcpy(&item, bytes_, sizeof(T));
|
||||
bytes_ += sizeof(T);
|
||||
reader_(reinterpret_cast<char*>(&item), sizeof(item));
|
||||
return item;
|
||||
}
|
||||
|
||||
std::string readBytes(size_t num_bytes);
|
||||
|
||||
double readFloat();
|
||||
OpCode readInstruction();
|
||||
OpCode readOpCode();
|
||||
@ -262,18 +265,24 @@ class Unpickler {
|
||||
void setInput(size_t memo_id);
|
||||
void run();
|
||||
|
||||
// Returns a pointer to the number of bytes requested. This should state-fully
|
||||
// remember how many bytes have been read
|
||||
std::function<void(char*, size_t)> reader_;
|
||||
|
||||
// Check if the stream has gone past its size
|
||||
std::function<bool()> bounds_checker_;
|
||||
|
||||
std::vector<IValue> stack_;
|
||||
|
||||
// globals are represented on the stack as IValue integer indices
|
||||
// into this list
|
||||
std::vector<std::function<void(void)>> globals_;
|
||||
std::vector<IValue> memo_table_;
|
||||
std::vector<size_t> marks_;
|
||||
const uint8_t* bytes_;
|
||||
const uint8_t* end_ptr_;
|
||||
const std::vector<at::Tensor>* tensor_table_;
|
||||
|
||||
// optionally nullptr, needs to be present for creating classes
|
||||
std::function<c10::StrongTypePtr(const c10::QualifiedName&)> class_resolver_;
|
||||
ClassResolver class_resolver_;
|
||||
IValue empty_tuple_;
|
||||
};
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
#include <torch/csrc/jit/graph_executor.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
#include <torch/csrc/jit/print_handler.h>
|
||||
#include <torch/csrc/jit/profiling_record.h>
|
||||
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||
@ -625,19 +625,14 @@ RegisterOperators reg(
|
||||
"aten::save(t item, str filename) -> ()",
|
||||
[](Stack& stack) {
|
||||
auto filename = pop(stack).toStringRef();
|
||||
auto value = pop(stack);
|
||||
auto ivalue = pop(stack);
|
||||
|
||||
// Pickle the tensor
|
||||
Pickler p;
|
||||
p.torchSaveStart();
|
||||
p.protocol();
|
||||
p.pushIValue(value);
|
||||
p.stop();
|
||||
p.torchSaveStop();
|
||||
auto data = pickle({ivalue});
|
||||
|
||||
// Write file
|
||||
std::fstream output(filename, std::ios::out | std::ios::binary);
|
||||
output.write(p.stack().data(), p.stack().size());
|
||||
output.write(data.data(), data.size());
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
|
||||
@ -240,6 +240,7 @@ static std::vector<at::Tensor> loadTensors(const std::vector<Slot>& slots) {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_graph() {
|
||||
auto result = lower_graph(owner().module_object(), *graph());
|
||||
return std::make_pair(result.first, loadTensors(result.second));
|
||||
|
||||
@ -151,7 +151,7 @@ struct CAFFE2_API SourceRange {
|
||||
bool operator!=(const SourceRange& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
|
||||
c10::optional<SourceRange> findSourceRangeThatGenerated() const {
|
||||
if (!source_) {
|
||||
return c10::nullopt;
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
#include <torch/csrc/jit/source_range_serialization_impl.h>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -83,23 +83,22 @@ c10::IValue SourceRangeSerializer::serialize_source(
|
||||
}
|
||||
|
||||
SourceRangePickler::SourceRangePickler()
|
||||
: p(new Pickler()), srs(new SourceRangeSerializer()) {}
|
||||
: srs(new SourceRangeSerializer()) {}
|
||||
|
||||
void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
|
||||
p->protocol();
|
||||
p->startTuple();
|
||||
std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
|
||||
std::vector<c10::IValue> ivalues;
|
||||
for (const auto& range : ranges) {
|
||||
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
|
||||
srs->serialize(range.range)};
|
||||
p->pushIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
|
||||
ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems)));
|
||||
}
|
||||
p->endTuple();
|
||||
p->stop();
|
||||
std::vector<at::Tensor> table;
|
||||
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
|
||||
auto result = jit::pickle(ivalue, &table);
|
||||
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::vector<char>& SourceRangePickler::get_data() {
|
||||
return p->stack();
|
||||
}
|
||||
|
||||
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
|
||||
at::DataPtr&& data,
|
||||
@ -114,8 +113,9 @@ void ConcreteSourceRangeUnpickler::unpickle() {
|
||||
return;
|
||||
}
|
||||
|
||||
Unpickler up(data.get(), size, nullptr, nullptr);
|
||||
auto ivalues = up.parse_ivalue_list();
|
||||
auto ivalues = jit::unpickle(reinterpret_cast<const char*>(data.get()), size)
|
||||
.toTuple()
|
||||
->elements();
|
||||
|
||||
unpickled_records = std::make_shared<SourceRangeRecords>();
|
||||
for (auto& val : ivalues) {
|
||||
@ -149,4 +149,4 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
} // namespace torch
|
||||
|
||||
@ -21,12 +21,9 @@ class SourceRangePickler {
|
||||
public:
|
||||
SourceRangePickler();
|
||||
|
||||
void pickle(const SourceRangeRecords& ranges);
|
||||
|
||||
const std::vector<char>& get_data();
|
||||
std::vector<char> pickle(const SourceRangeRecords& ranges);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Pickler> p;
|
||||
std::shared_ptr<SourceRangeSerializer> srs;
|
||||
};
|
||||
|
||||
@ -39,4 +36,4 @@ class SourceRangeUnpickler {
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
} // namespace torch
|
||||
|
||||
@ -4,5 +4,6 @@
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/custom_operator.h>
|
||||
#include <torch/csrc/jit/import.h>
|
||||
#include <torch/csrc/jit/pickle.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
Reference in New Issue
Block a user