Add Pickler C++ API (#23241)

Summary:
This PR adds functions that wrap the Pickler and exposes them in the C++ API.
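
A minimal round-trip sketch using the new entry points (it mirrors the test added in this PR; the explicit tensor table is only needed until tensors are stored in the pickle itself):

    #include <torch/csrc/jit/pickle.h>

    torch::IValue value(2.3);
    std::vector<at::Tensor> tensor_table;
    std::vector<char> data = torch::jit::pickle(value, &tensor_table);

    torch::IValue out =
        torch::jit::unpickle(data.data(), data.size(), &tensor_table);
    // out.toDouble() == 2.3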
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23241

Pulled By: driazati

Differential Revision: D16675418

fbshipit-source-id: 76543c81ac67c3e20a75ebc2073191bcbd6573bf
This commit is contained in:
davidriazati
2019-08-09 12:16:10 -07:00
committed by Facebook Github Bot
parent e81f296807
commit 01d98c7cfb
15 changed files with 288 additions and 118 deletions

View File

@ -371,6 +371,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/graph_executor.cpp
${TORCH_SRC_DIR}/csrc/jit/import_source.cpp
${TORCH_SRC_DIR}/csrc/jit/import.cpp
${TORCH_SRC_DIR}/csrc/jit/pickle.cpp
${TORCH_SRC_DIR}/csrc/jit/import_export_helpers.cpp
${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/constants.cpp

View File

@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include <torch/jit.h>
#include <torch/script.h>
#include <torch/types.h>
#include <string>
@ -110,3 +111,18 @@ TEST(TorchScriptTest, TestOptionalArgMatching) {
0, module->run_method("optional_tuple_op", torch::jit::IValue()).toInt());
}
TEST(TorchScriptTest, TestPickle) {
torch::IValue float_value(2.3);
// TODO: when tensors are stored in the pickle, delete this
std::vector<at::Tensor> tensor_table;
auto data = torch::jit::pickle(float_value, &tensor_table);
torch::IValue ivalue =
torch::jit::unpickle(data.data(), data.size());
double diff = ivalue.toDouble() - float_value.toDouble();
double eps = 0.0001;
ASSERT_TRUE(diff < eps && diff > -eps);
}

View File

@ -64,6 +64,7 @@ libtorch_sources = [
"torch/csrc/jit/pickler.cpp",
"torch/csrc/jit/graph_executor.cpp",
"torch/csrc/jit/import.cpp",
"torch/csrc/jit/pickle.cpp",
"torch/csrc/jit/import_export_helpers.cpp",
"torch/csrc/jit/interpreter.cpp",
"torch/csrc/jit/ir.cpp",

View File

@ -10,7 +10,7 @@
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <caffe2/core/types.h>
@ -784,16 +784,8 @@ void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
void ScriptModuleSerializer::writePickleArchive(
const std::string& name,
const std::vector<IValue>& ivalues) {
Pickler pickler(&tensor_table_);
pickler.protocol();
pickler.startTuple();
for (const IValue& ivalue : ivalues) {
pickler.pushIValue(ivalue);
}
pickler.endTuple();
pickler.stop();
writer_.writeRecord(name, pickler.stack().data(), pickler.stack().size(),
/*compress=*/true);
auto data = pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table_);
writer_.writeRecord(name, data.data(), data.size(), /*compress=*/true);
}
void ScriptModuleSerializer::convertModule(
@ -878,8 +870,7 @@ void ScriptModuleSerializer::convertModule(
module_def->mutable_torchscript_debug_arena();
SourceRangePickler source_range_pickler;
source_range_pickler.pickle(source_ranges);
const auto& range_data = source_range_pickler.get_data();
const auto& range_data = source_range_pickler.pickle(source_ranges);
std::stringstream debug_filename;
debug_filename << "debug/" << module_name.str() << ".pkl";
writer_.writeRecord(

View File

@ -7,7 +7,7 @@
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/source_range_serialization_impl.h>
@ -75,7 +75,7 @@ class ScriptModuleDeserializer final {
script::Module convertModule(const torch::ModuleDef& module_def);
void loadTensorTable(torch::ModelDef* model_def);
std::vector<IValue> loadPickleArchive(const std::string& name);
IValue loadPickleArchive(const std::string& name);
void importCallback(const std::string& qualifier);
void moduleSetState(const script::Module& module, IValue state);
@ -142,8 +142,12 @@ script::Module ScriptModuleDeserializer::deserialize(
}
loadTensorTable(&model_def);
if (model_def.proto_version() >= 2) {
pickled_ivalues_ = loadPickleArchive("attributes.pkl");
if (model_def.proto_version() == 2) {
auto list = loadPickleArchive("attributes.pkl").toGenericList();
pickled_ivalues_.insert(pickled_ivalues_.end(), list.begin(), list.end());
} else if (model_def.proto_version() >= 3) {
pickled_ivalues_ =
loadPickleArchive("attributes.pkl").toTuple()->elements();
}
return convertModule(module_def);
@ -156,12 +160,12 @@ void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
}
}
std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
IValue ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
at::DataPtr attributes_ptr;
size_t attributes_size;
std::tie(attributes_ptr, attributes_size) = reader_->getRecord(name);
Unpickler unpickler(
attributes_ptr.get(),
auto ivalue = unpickle(
reinterpret_cast<const char*>(attributes_ptr.get()),
attributes_size,
&tensor_table_,
[&](const c10::QualifiedName& qn) {
@ -169,7 +173,7 @@ std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::strin
return c10::StrongTypePtr(
compilation_unit_, compilation_unit_->get_class(qn));
});
return unpickler.parse_ivalue_list();
return ivalue;
}
at::Tensor ScriptModuleDeserializer::loadTensor(

torch/csrc/jit/pickle.cpp (new file, +82 lines)
View File

@ -0,0 +1,82 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace jit {
void pickle(
std::function<void(const char*, size_t)> writer,
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table) {
Pickler pickler(std::move(writer), tensor_table);
if (tensor_table == nullptr) {
// No tensor table was provided, so tensors will be stored directly in the
// blob. Add torch.save metadata so these tensors can be deserialized later
pickler.torchSaveStart();
}
pickler.protocol();
pickler.pushIValue(ivalue);
pickler.stop();
if (tensor_table == nullptr) {
// Write the matching torch.save footer and the tensor data itself
pickler.torchSaveStop();
}
}
std::vector<char> pickle(
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table) {
std::vector<char> data;
pickle(
[&](const char* bytes, size_t len) {
data.insert(data.end(), bytes, bytes + len);
},
ivalue,
tensor_table);
return data;
}
IValue unpickle(
std::function<void(char*, size_t)> reader,
std::function<bool()> bounds_checker,
std::vector<at::Tensor>* tensor_table,
ClassResolver class_resolver) {
Unpickler unpickler(
std::move(reader),
std::move(bounds_checker),
tensor_table,
std::move(class_resolver));
return unpickler.parse_ivalue();
}
IValue unpickle(
const char* data,
size_t size,
std::vector<at::Tensor>* tensor_table,
ClassResolver class_resolver) {
size_t bytes_read = 0;
return unpickle(
[&](char* buffer, size_t len) {
// Copy len bytes into buffer
const char* start = data + bytes_read;
std::memcpy(buffer, start, len);
bytes_read += len;
},
[&]() {
return bytes_read < size;
},
tensor_table,
std::move(class_resolver));
}
} // namespace jit
} // namespace torch
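
The buffer-returning pickle above is a thin convenience over the writer-callback overload, which can stream the pickle bytes to any sink without buffering the whole blob first. A sketch that writes straight to a file (the std::ofstream sink is illustrative, not part of this PR):

    #include <fstream>
    #include <torch/csrc/jit/pickle.h>

    std::ofstream file("data.pkl", std::ios::out | std::ios::binary);
    torch::jit::pickle(
        [&](const char* bytes, size_t len) { file.write(bytes, len); },
        torch::IValue(2.3));
    // No tensor table was passed, so torch.save metadata is included and the
    // result should be readable with torch.load('data.pkl') from Python.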

torch/csrc/jit/pickle.h (new file, +79 lines)
View File

@ -0,0 +1,79 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace jit {
/// Save a `torch::IValue` in a format compatible with Python's `pickle` module
///
/// If present, `tensor_table` is a pointer to a table in which tensors that
/// are contained within `ivalue` are stored, and the bytes returned by the
/// pickler will only include references to these tensors in the table. This can
/// be used to keep the binary blob size small.
/// If not provided, tensors are stored in the same byte stream as the pickle
/// data, similar to `torch.save()` in eager Python.
///
/// Pickled values can be loaded in Python and C++:
/// \rst
/// .. code-block:: cpp
///
/// torch::IValue float_value(2.3);
///
/// // TODO: when tensors are stored in the pickle, delete this
/// std::vector<at::Tensor> tensor_table;
/// auto data = torch::jit::pickle(float_value, &tensor_table);
///
/// torch::IValue ivalue =
/// torch::jit::unpickle(data.data(), data.size());
///
/// .. code-block:: python
///
/// values = torch.load('data.pkl')
/// print(values)
///
/// \endrst
TORCH_API std::vector<char> pickle(
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table = nullptr);
/// Pickle an IValue by calling a function to handle writing the data.
///
/// `writer` is a function that takes in a pointer to a chunk of memory and its
/// size and consumes it.
///
/// See `jit::pickle` for more details.
TORCH_API void pickle(
std::function<void(const char* data_start, size_t data_len)> writer,
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table = nullptr);
/// `reader` is a function that copies the requested number of bytes from the
/// pickled binary into the given buffer. It should statefully track how far
/// into the data it has read.
///
/// `bounds_checker` is a function that returns `true` if the reader can read
/// more data, and `false` if it cannot (i.e. the stream has hit its end of
/// file).
///
/// See `jit::pickle` for details.
TORCH_API IValue unpickle(
std::function<void(char*, size_t)> reader,
std::function<bool()> bounds_checker,
std::vector<at::Tensor>* tensor_table = nullptr,
ClassResolver class_resolver = nullptr);
/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
///
/// If any `torch::IValue`s in the pickled data are `Object`s, then a
/// `class_resolver` function must be provided.
///
/// See `jit::pickle` for details.
TORCH_API IValue unpickle(
const char* data,
size_t size,
std::vector<at::Tensor>* tensor_table = nullptr,
ClassResolver class_resolver = nullptr);
} // namespace jit
} // namespace torch
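
On the read side, the two callbacks of the stream-based unpickle overload act together as an input stream: reader copies the next len bytes into the caller's buffer, and bounds_checker reports whether any bytes remain. A sketch that unpickles from a std::istream whose total length is known up front (the load helper here is hypothetical, not part of this PR):

    #include <istream>
    #include <torch/csrc/jit/pickle.h>

    torch::IValue load(std::istream& in, size_t size) {
      size_t bytes_read = 0;
      return torch::jit::unpickle(
          [&](char* buffer, size_t len) {
            // Copy the next len bytes of pickle data into buffer
            in.read(buffer, len);
            bytes_read += len;
          },
          [&]() { return bytes_read < size; });
    }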

View File

@ -52,10 +52,6 @@ const char* getClassName(PicklerClass cls) {
}
}
const std::vector<char>& Pickler::stack() {
return stack_;
}
void Pickler::protocol() {
push<OpCode>(OpCode::PROTO);
push<uint8_t>(PROTOCOL_VERSION);
@ -89,6 +85,7 @@ void Pickler::torchSaveStop() {
push<uint32_t>(key.size());
pushBytes(key);
}
push<OpCode>(OpCode::TUPLE);
stop();
@ -96,7 +93,7 @@ void Pickler::torchSaveStop() {
for (const auto& data : tensor_data_) {
// first dump size
push<size_t>(data.numel());
stack_.insert(stack_.end(), data.data(), data.data() + data.sizeInBytes());
writer_(data.data(), data.sizeInBytes());
}
}
@ -304,7 +301,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
}
void Pickler::pushBytes(const std::string& string) {
stack_.insert(stack_.end(), string.begin(), string.end());
writer_(string.data(), string.size());
}
void Pickler::pushGlobal(
@ -492,49 +489,45 @@ void Pickler::pushTuple(const IValue& ivalue) {
push<OpCode>(OpCode::TUPLE);
}
std::vector<IValue> Unpickler::parse_ivalue_list() {
IValue Unpickler::parse_ivalue() {
run();
TORCH_CHECK(
stack_.size() == 1,
"Unpickler expected 1 element on the stack, but found ",
stack_.size());
auto value = stack_[0];
if (value.isGenericList()) {
// TODO [unpickler refactor]
return value.toGenericListRef().vec();
}
return value.toTuple()->elements();
return stack_[0];
}
double Unpickler::readFloat() {
AT_ASSERT(sizeof(double) == 8);
AT_ASSERT(bytes_ + 8 < end_ptr_);
double result;
double big_endian = read<double>();
double little_endian;
// Pickle floats are big endian, so reverse the bytes
auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
std::reverse_copy(
reinterpret_cast<const char*>(bytes_),
reinterpret_cast<const char*>(bytes_ + 8),
reinterpret_cast<char*>(&result));
big_endian_ptr,
big_endian_ptr + sizeof(big_endian),
reinterpret_cast<char*>(&little_endian));
bytes_ += 8;
return result;
return little_endian;
}
void Unpickler::run() {
// Expect a PROTO opcode and protocol number at the start of the blob
auto opcode = readOpCode();
TORCH_CHECK(
readOpCode() == OpCode::PROTO,
opcode == OpCode::PROTO,
"Expected PROTO opcode at the start"
" of pickle archive");
" of pickle archive, found ", int(static_cast<uint8_t>(opcode)));
uint8_t protocol = read<uint8_t>();
TORCH_CHECK(
protocol == 2,
"Only Pickle protocol 2 is supported, found protocol = ",
protocol);
while (bytes_ < end_ptr_) {
while (bounds_checker_()) {
OpCode opcode = readInstruction();
if (opcode == OpCode::STOP) {
return;
@ -627,15 +620,12 @@ OpCode Unpickler::readInstruction() {
case OpCode::LONG1: {
// Only read LONG1s with 8 as the length
uint8_t length = read<uint8_t>();
AT_ASSERT(length == 8);
TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
stack_.emplace_back(int64_t(read<int64_t>()));
} break;
case OpCode::BINUNICODE: {
uint32_t length = read<uint32_t>();
const char* characters = reinterpret_cast<const char*>(bytes_);
AT_ASSERT(bytes_ + length < end_ptr_);
bytes_ += length;
stack_.emplace_back(std::string(characters, /*n=*/length));
stack_.emplace_back(readBytes(length));
} break;
case OpCode::BINFLOAT:
stack_.emplace_back(readFloat());
@ -705,6 +695,10 @@ OpCode Unpickler::readInstruction() {
stack_.pop_back();
switch (pickler_class) {
case PicklerClass::TENSOR:
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Unpickler"
" has no tensor table");
stack_.emplace_back(tensor_table_->at(data.toInt()));
break;
case PicklerClass::INTLIST:
@ -772,13 +766,21 @@ OpCode Unpickler::readInstruction() {
"Unknown opcode for unpickling at ",
reinterpret_cast<void*>(opcode),
": ",
static_cast<uint8_t>(opcode));
int(static_cast<uint8_t>(opcode)));
}
return opcode;
}
// Pop all the list items off of the stack and append them to the list at the
// corresponding MARK
// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
std::string data(length, 0);
// This is fine since C++11 has contiguous strings
reader_(&data[0], length);
return data;
}
// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList() {
size_t start = marks_.back();
marks_.pop_back();
@ -829,33 +831,24 @@ inline bool is_valid_python_id_char(char c) {
// Read a newline terminated string
std::string Unpickler::readString() {
const char* chars = reinterpret_cast<const char*>(bytes_);
const char* char_end_ptr = reinterpret_cast<const char*>(end_ptr_);
size_t n = 0;
std::stringstream ss;
while (true) {
char c = chars[n];
char c = read<char>();
if (c == '\n') {
break;
}
ss << c;
// Simple check just in case there is no terminating '\n'
TORCH_CHECK(
is_valid_python_id_char(c),
"Found character '",
uint8_t(c),
"' in string, "
int(uint8_t(c)),
"' in string, ",
"strings must be qualified Python identifiers");
// Increment after to exclude newline from string
++n;
TORCH_CHECK(
chars + n < char_end_ptr,
"Unpickler overran buffer while reading a string (expected a newline)");
}
// Increment by string length + newline char
bytes_ += n + 1;
return std::string(chars, n);
return ss.str();
}
OpCode Unpickler::readOpCode() {

View File

@ -11,6 +11,9 @@
namespace torch {
namespace jit {
using ClassResolver =
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
// See Python's pickletools.py for a detailed description of each of these codes
enum class OpCode : char {
MARK = '(',
@ -122,10 +125,10 @@ class Pickler {
TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
public:
Pickler(std::vector<at::Tensor>* tensor_table = nullptr)
: tensor_table_(tensor_table) {}
const std::vector<char>& stack();
Pickler(
std::function<void(const char*, size_t)> writer,
std::vector<at::Tensor>* tensor_table = nullptr)
: writer_(writer), tensor_table_(tensor_table) {}
// Push protocol onto the stack
void protocol();
@ -186,9 +189,12 @@ class Pickler {
template <typename T>
void push(typename std::common_type<T>::type value) {
const char* begin = reinterpret_cast<const char*>(&value);
stack_.insert(stack_.end(), begin, begin + sizeof(T));
writer_(begin, sizeof(T));
}
// Callback through which serialized binary data is written
std::function<void(const char*, size_t)> writer_;
// Stack of opcodes/data
std::vector<char> stack_;
@ -228,32 +234,29 @@ class Unpickler {
public:
Unpickler(
const void* data,
size_t size,
std::function<void(char*, size_t)> reader,
std::function<bool()> bounds_checker,
const std::vector<at::Tensor>* tensor_table,
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>
class_resolver)
: bytes_(static_cast<const uint8_t*>(data)),
end_ptr_(bytes_ + size),
ClassResolver class_resolver)
: reader_(reader),
bounds_checker_(bounds_checker),
tensor_table_(tensor_table),
class_resolver_(class_resolver) {}
class_resolver_(std::move(class_resolver)) {}
std::vector<IValue> parse_ivalue_list();
IValue parse_ivalue();
private:
// Taking no arguments ensures that a template argument must be specified, so
// the number of bytes read / type read is explicit
template <typename T>
T read() {
TORCH_CHECK(
bytes_ + sizeof(T) <= end_ptr_,
"Unpickler overran buffer while reading a value");
T item;
std::memcpy(&item, bytes_, sizeof(T));
bytes_ += sizeof(T);
reader_(reinterpret_cast<char*>(&item), sizeof(item));
return item;
}
std::string readBytes(size_t num_bytes);
double readFloat();
OpCode readInstruction();
OpCode readOpCode();
@ -262,18 +265,24 @@ class Unpickler {
void setInput(size_t memo_id);
void run();
// Reads the requested number of bytes into the given buffer. This should
// statefully track how many bytes have been read
std::function<void(char*, size_t)> reader_;
// Returns true while the stream still has bytes left to read
std::function<bool()> bounds_checker_;
std::vector<IValue> stack_;
// globals are represented on the stack as IValue integer indices
// into this list
std::vector<std::function<void(void)>> globals_;
std::vector<IValue> memo_table_;
std::vector<size_t> marks_;
const uint8_t* bytes_;
const uint8_t* end_ptr_;
const std::vector<at::Tensor>* tensor_table_;
// may be nullptr, but must be provided to unpickle classes
std::function<c10::StrongTypePtr(const c10::QualifiedName&)> class_resolver_;
ClassResolver class_resolver_;
IValue empty_tuple_;
};

View File

@ -9,7 +9,7 @@
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/print_handler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/script/compilation_unit.h>
@ -625,19 +625,14 @@ RegisterOperators reg(
"aten::save(t item, str filename) -> ()",
[](Stack& stack) {
auto filename = pop(stack).toStringRef();
auto value = pop(stack);
auto ivalue = pop(stack);
// Pickle the tensor
Pickler p;
p.torchSaveStart();
p.protocol();
p.pushIValue(value);
p.stop();
p.torchSaveStop();
auto data = pickle({ivalue});
// Write file
std::fstream output(filename, std::ios::out | std::ios::binary);
output.write(p.stack().data(), p.stack().size());
output.write(data.data(), data.size());
return 0;
},
aliasAnalysisFromSchema()),

View File

@ -240,6 +240,7 @@ static std::vector<at::Tensor> loadTensors(const std::vector<Slot>& slots) {
}
return result;
}
std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_graph() {
auto result = lower_graph(owner().module_object(), *graph());
return std::make_pair(result.first, loadTensors(result.second));

View File

@ -151,7 +151,7 @@ struct CAFFE2_API SourceRange {
bool operator!=(const SourceRange& rhs) const {
return !(*this == rhs);
}
c10::optional<SourceRange> findSourceRangeThatGenerated() const {
if (!source_) {
return c10::nullopt;

View File

@ -2,7 +2,7 @@
#include <torch/csrc/jit/source_range_serialization_impl.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
namespace jit {
@ -83,23 +83,22 @@ c10::IValue SourceRangeSerializer::serialize_source(
}
SourceRangePickler::SourceRangePickler()
: p(new Pickler()), srs(new SourceRangeSerializer()) {}
: srs(new SourceRangeSerializer()) {}
void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
p->protocol();
p->startTuple();
std::vector<char> SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
std::vector<c10::IValue> ivalues;
for (const auto& range : ranges) {
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
srs->serialize(range.range)};
p->pushIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems)));
}
p->endTuple();
p->stop();
std::vector<at::Tensor> table;
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
auto result = jit::pickle(ivalue, &table);
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
return result;
}
const std::vector<char>& SourceRangePickler::get_data() {
return p->stack();
}
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
at::DataPtr&& data,
@ -114,8 +113,9 @@ void ConcreteSourceRangeUnpickler::unpickle() {
return;
}
Unpickler up(data.get(), size, nullptr, nullptr);
auto ivalues = up.parse_ivalue_list();
auto ivalues = jit::unpickle(reinterpret_cast<const char*>(data.get()), size)
.toTuple()
->elements();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
@ -149,4 +149,4 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
}
} // namespace jit
} // namespace torch

View File

@ -21,12 +21,9 @@ class SourceRangePickler {
public:
SourceRangePickler();
void pickle(const SourceRangeRecords& ranges);
const std::vector<char>& get_data();
std::vector<char> pickle(const SourceRangeRecords& ranges);
private:
std::shared_ptr<Pickler> p;
std::shared_ptr<SourceRangeSerializer> srs;
};
@ -39,4 +36,4 @@ class SourceRangeUnpickler {
};
} // namespace jit
} // namespace torch

View File

@ -4,5 +4,6 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/pickle.h>
#include <ATen/ATen.h>