Remove torch.save-related logic from pickler (#25502)

Summary:
The Pickler previously distinguished between tensors that are inlined in a single pickle binary (matching the format of `torch.save()`) and tensors that are saved elsewhere, with only a reference stored in the binary. This PR moves that distinction out to `torch::pickle_save` to match the eager Python interface.

The change can be seen in `register_prim_ops.cpp`, where the call to `jit::pickle` is now `torch::pickle_save`.
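As a rough sketch of what changes at call sites (`ivalue` is an illustrative placeholder, not a name from this diff):

// Before: jit::pickle with no tensor table implicitly wrapped the pickle
// program in torch.save() framing.
std::vector<char> old_blob = torch::jit::pickle(ivalue);

// After: the torch.save() framing lives in torch::pickle_save, matching the
// eager Python interface.
std::vector<char> new_blob = torch::pickle_save(ivalue);

// jit::pickle with an explicit tensor table keeps its old behavior: tensors
// go into the table and only references are stored in the binary.
std::vector<at::Tensor> tensor_table;
std::vector<char> refs = torch::jit::pickle(ivalue, &tensor_table);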
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25502

Pulled By: driazati

Differential Revision: D17175215

fbshipit-source-id: 8c9a21327cc79eaf6a0e488ea99e305be52f82b1
davidriazati
2019-09-17 20:36:37 -07:00
committed by Facebook Github Bot
parent acb300fd6b
commit 61197e94b3
9 changed files with 144 additions and 102 deletions

View File

@@ -521,6 +521,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp
${TORCH_SRC_DIR}/csrc/api/src/serialize.cpp
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp

View File

@@ -184,6 +184,7 @@ def add_torch_libs():
"torch/csrc/api/src/data/samplers/sequential.cpp",
"torch/csrc/api/src/data/samplers/stream.cpp",
"torch/csrc/api/src/jit.cpp",
"torch/csrc/api/src/serialize.cpp",
"torch/csrc/api/src/nn/init.cpp",
"torch/csrc/api/src/nn/module.cpp",
"torch/csrc/api/src/nn/modules/batchnorm.cpp",

View File

@@ -2,6 +1,7 @@
#include <torch/serialize/archive.h>
#include <torch/serialize/tensor.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <utility>
@@ -72,6 +73,8 @@ void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
archive.save_to(std::forward<SaveToArgs>(args)...);
}
TORCH_API std::vector<char> pickle_save(const torch::IValue& ivalue);
/// Deserializes the given `value`.
/// There must be an overload of `operator>>` between `serialize::InputArchive`
/// and `Value` for this method to be well-formed. Currently, such an overload

View File

@@ -0,0 +1,13 @@
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/serialize.h>
#include <vector>
namespace torch {
std::vector<char> pickle_save(const at::IValue& ivalue) {
return jit::pickle_save(ivalue);
}
} // namespace torch
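A minimal usage sketch (the helper name and file handling are hypothetical; assumes a full libtorch build), mirroring the prim save op change at the bottom of this diff:

#include <torch/serialize.h>

#include <fstream>
#include <string>
#include <vector>

void save_ivalue(const torch::IValue& ivalue, const std::string& filename) {
  // Serialize to a torch.save()-compatible byte buffer
  std::vector<char> data = torch::pickle_save(ivalue);
  // Write the buffer out as a binary file
  std::fstream output(filename, std::ios::out | std::ios::binary);
  output.write(data.data(), static_cast<std::streamsize>(data.size()));
}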

View File

@@ -1,33 +1,24 @@
#include <ATen/core/ivalue.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace jit {
// These are both defined in `torch/serialization.py`
const char* torch_save_magic_number =
"\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19";
uint16_t protocol_version = 1001;
void pickle(
std::function<void(const char* data_start, size_t data_len)> writer,
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table) {
Pickler pickler(std::move(writer), tensor_table);
if (tensor_table == nullptr) {
// No tensor table provided, so tensors will be stored directly in the blob.
// Add torch.save metadata so these tensors can be de-serialized later
pickler.torchSaveStart();
}
pickler.protocol();
pickler.pushIValue(ivalue);
pickler.stop();
if (tensor_table == nullptr) {
// No tensor table provided, so tensors will be stored directly in the blob.
// Add torch.save metadata so these tensors can be de-serialized later
pickler.torchSaveStop();
}
}
std::vector<char> pickle(
@@ -45,6 +36,63 @@ std::vector<char> pickle(
return data;
}
// The torch.save-mirroring logic has to live here instead of in the C++ API
// since the mobile build excludes the C++ API
std::vector<char> pickle_save(const at::IValue& ivalue) {
std::vector<char> data;
auto writer = [&](const char* bytes, size_t len) {
data.insert(data.end(), bytes, bytes + len);
};
jit::Pickler pickler(writer, /*tensor_table=*/nullptr);
// Output data to match torch.save, see torch/serialization.py for details
// Magic number (0x1950a86a20f9469cfc6c)
pickler.protocol();
pickler.pushLong(torch_save_magic_number);
pickler.stop();
// Protocol Version
pickler.protocol();
pickler.pushInt(protocol_version);
pickler.stop();
// sys_info: this isn't actually used during deserialization, so we can leave
// it empty
pickler.protocol();
IValue dict_ivalue =
c10::impl::GenericDict(c10::impl::deprecatedUntypedDict());
pickler.pushDict(dict_ivalue);
pickler.stop();
jit::Pickler data_pickler(writer, /*tensor_table=*/nullptr);
data_pickler.protocol();
data_pickler.pushIValue(ivalue);
data_pickler.stop();
auto writeable_tensors = data_pickler.tensorData();
std::vector<at::IValue> keys;
keys.reserve(writeable_tensors.size());
std::vector<at::TypePtr> types(writeable_tensors.size(), at::StringType::get());
for (size_t i = 0; i < writeable_tensors.size(); i++) {
keys.emplace_back(std::to_string(i));
}
auto keys_tuple = at::ivalue::Tuple::create(keys, at::TupleType::create(types));
jit::pickle(writer, keys_tuple);
for (const auto& tensor_data : writeable_tensors) {
const char* addr = tensor_data.data();
size_t numel = tensor_data.numel();
writer(reinterpret_cast<const char*>(&numel), sizeof(numel));
writer(addr, tensor_data.sizeInBytes());
}
return data;
}
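The returned buffer is thus five pickle programs back to back (magic number, protocol version, sys_info, the data pickle, and the tensor-key tuple) followed by each tensor's length-prefixed raw bytes. A quick sanity check, as a sketch assuming a full libtorch build:

at::Tensor t = torch::ones({2, 3});
std::vector<char> buf = torch::pickle_save(t);
// The buffer must exceed the raw tensor payload alone (6 floats = 24 bytes),
// since it also carries the pickle framing described above.
TORCH_CHECK(buf.size() > 6 * sizeof(float));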
IValue unpickle(
std::function<bool(char*, size_t)> reader,
ClassResolver class_resolver,

View File

@@ -1,3 +1,5 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickler.h>
@@ -6,6 +8,17 @@
namespace torch {
namespace jit {
/// Pickle an IValue by calling a function to handle writing the data.
///
/// `writer` is a function that takes in a pointer to a chunk of memory and its
/// size and consumes it.
///
/// See `jit::pickle` for more details.
TORCH_API void pickle(
std::function<void(const char* data_start, size_t data_len)> writer,
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table = nullptr);
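For example, a writer that appends each chunk to a growing buffer, the same pattern `pickle_save` uses internally (`ivalue` is an illustrative placeholder):

std::vector<char> buffer;
auto writer = [&](const char* data_start, size_t data_len) {
  buffer.insert(buffer.end(), data_start, data_start + data_len);
};
std::vector<at::Tensor> tensor_table;
torch::jit::pickle(writer, ivalue, &tensor_table);
// buffer now holds the pickle program; the tensors live in tensor_table.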
/// Save a `torch::IValue` in a format compatible with Python's `pickle` module
///
/// If present, `tensor_table` is a pointer to a table in which tensors that
@@ -38,16 +51,9 @@ TORCH_API std::vector<char> pickle(
const IValue& ivalue,
std::vector<at::Tensor>* tensor_table = nullptr);
TORCH_API std::vector<char> pickle_save(const IValue& ivalue);
/// `reader` is a function that takes in a size to read from some pickled
/// binary. `reader` should remember where it last read, and return

View File

@@ -94,56 +94,6 @@ void Pickler::stop() {
push<PickleOpCode>(PickleOpCode::STOP);
}
void Pickler::torchSaveStop() {
// Add the binary data for all the tensors to be included in the same binary
// TODO: The pickler should be refactored to stream out to a stream directly
// instead of staging in the stack_ array
// As another pickle program in the same binary archive, add a list of
// keys for each tensor (see torch/serialization.py)
protocol();
push<PickleOpCode>(PickleOpCode::MARK);
for (size_t i = 0; i < tensor_data_.size(); ++i) {
std::string key = std::to_string(i);
push<PickleOpCode>(PickleOpCode::BINUNICODE);
push<uint32_t>(key.size());
pushBytes(key);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
stop();
// Now dump the tensor binary data
for (const auto& data : tensor_data_) {
// first dump size
push<size_t>(data.numel());
writer_(data.data(), data.sizeInBytes());
}
}
void Pickler::torchSaveStart() {
// Output data to match torch.save, see torch/serialization.py for details
// Magic number (0x1950a86a20f9469cfc6c)
protocol();
push<PickleOpCode>(PickleOpCode::LONG1);
// LONG1 size
pushBytes("\x0a");
// LONG1 data
pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
stop();
// Protocol Version (1001)
protocol();
push<PickleOpCode>(PickleOpCode::BININT2);
pushBytes("\xe9\x03");
stop();
// sys_info, this isn't actually used in de-serialization so we can leave this
// one empty
protocol();
push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
stop();
}
// unmemoized version called by pushIValue
void Pickler::pushIValueImpl(const IValue& ivalue) {
if (ivalue.isTensor()) {
@@ -332,6 +282,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::BINPERSID);
// TODO: Skip this if not writing tensors
memoized_storage_map_[addr] = pushNextBinPut();
tensor_data_.push_back(getWriteableTensorData(tensor));
}
@@ -426,19 +377,6 @@ void Pickler::pushClass(PicklerClass cls) {
pushGlobal("torch.jit._pickle", getClassName(cls));
}
void Pickler::pushSpecializedList(
const IValue& ivalue,
PicklerClass cls,
@@ -476,13 +414,46 @@ void Pickler::pushDouble(double value) {
}
}
void Pickler::pushLong(const std::string& data) {
uint64_t size = data.size();
if (size <= std::numeric_limits<uint8_t>::max()) {
push<PickleOpCode>(PickleOpCode::LONG1);
push<uint8_t>(size);
} else {
TORCH_INTERNAL_ASSERT(
data.size() <= std::numeric_limits<uint32_t>::max(),
"Cannot pickle a long with a size larger than 4 bytes");
push<PickleOpCode>(PickleOpCode::LONG4);
// LONG4's length prefix is 4 bytes, little-endian
push<uint32_t>(size);
}
pushBytes(data);
}
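For reference, `LONG1`/`LONG4` payloads use pickle's little-endian two's-complement encoding, so reading the 10-byte magic string least-significant byte first reproduces 0x1950a86a20f9469cfc6c. A decoding sketch (`decode_long` is hypothetical and assumes a compiler with `__uint128_t`, since the value needs 80 bits):

#include <cstdint>
#include <string>

__uint128_t decode_long(const std::string& bytes) {
  __uint128_t value = 0;
  // The first byte in the string is the least significant
  for (size_t i = bytes.size(); i-- > 0;) {
    value = (value << 8) | static_cast<uint8_t>(bytes[i]);
  }
  return value; // "\x6c\xfc...\x50\x19" -> 0x1950a86a20f9469cfc6c
}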
void Pickler::pushTensorReference(const IValue& ivalue) {
pushClass(PicklerClass::TENSOR);
tensor_table_->push_back(ivalue.toTensor());
int64_t tensor_id = tensor_table_->size() - 1;
// Reduce arguments are spread (e.g. `*args`) before calling the global,
// so wrap in a tuple
push<PickleOpCode>(PickleOpCode::MARK);
pushIValue(tensor_id);
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushDict(const IValue& ivalue) {
push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
auto dict_items = iterationOrder(ivalue.toGenericDict());
if (dict_items.size() == 0) {
return;
}
push<PickleOpCode>(PickleOpCode::MARK);
// Sort the dict for deterministic keys
for (const auto& pair : dict_items) {
pushIValue(pair.first);
pushIValue(pair.second);
@@ -760,6 +731,9 @@ PickleOpCode Unpickler::readInstruction() {
stack_.pop_back();
switch (pickler_class) {
case PicklerClass::TENSOR:
TORCH_INTERNAL_ASSERT(
tensor_table_,
"Pickler tried to write a tensor but had no tensor table to write to");
stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
break;
case PicklerClass::INTLIST:
@@ -779,7 +753,7 @@ PickleOpCode Unpickler::readInstruction() {
case PicklerClass::TENSOR:
TORCH_CHECK(
tensor_table_,
"Found a tensor table reference but Pickler"
"Found a tensor table reference but Unpickler"
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
break;

View File

@@ -127,7 +127,7 @@ class Pickler {
public:
Pickler(
std::function<void(const char*, size_t)> writer,
std::vector<at::Tensor>* tensor_table)
: writer_(writer), tensor_table_(tensor_table) {}
// Push protocol onto the stack
@@ -138,12 +138,6 @@
void pushIValue(const IValue& ivalue);
// See torch/serialization.py for details, pushes a magic number, torch
// serialization version, and system info to the pickle archive all as
// individual pickle programs
void torchSaveStart();
void torchSaveStop();
void startTuple();
void endTuple();
@@ -151,17 +145,19 @@
return tensor_data_;
}
void pushDict(const IValue& ivalue);
void pushInt(int64_t value);
void pushLong(const std::string& data);
private:
void pushIValueImpl(const IValue& ivalue);
void pushDouble(double value);
void pushGenericList(const IValue& ivalue);
void pushIntList(const IValue& ivalue);
void pushList(const IValue& ivalue);
void pushTensor(const IValue& ivalue);
void pushTensorReference(const IValue& ivalue);
void pushLiteralTensor(const IValue& ivalue);
void pushTuple(const IValue& ivalue);
void pushString(const std::string& string);
// unmemoized version

View File

@@ -745,7 +745,7 @@ RegisterOperators reg(
auto ivalue = pop(stack);
// Pickle the tensor
auto data = jit::pickle_save(ivalue);
// Write file
std::fstream output(filename, std::ios::out | std::ios::binary);