Remove torch.save-related logic from pickler (#25502)
Summary: The Pickler previously distinguished between tensors that are inlined in a single pickle binary (matching the format of `torch.save()`) and tensors that are saved elsewhere with only a reference stored in the binary. This PR moves that distinction out to `torch::pickle_save` to match the eager Python interface. The change can be seen in `register_prim_ops.cpp`, where the call to `jit::pickle` is now `torch::pickle_save`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25502

Pulled By: driazati

Differential Revision: [D17175215](https://our.intern.facebook.com/intern/diff/17175215/)

fbshipit-source-id: 8c9a21327cc79eaf6a0e488ea99e305be52f82b1
Committed by: Facebook Github Bot
Parent: acb300fd6b
Commit: 61197e94b3
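As a quick orientation before the diff, here is a minimal usage sketch of the new entry point (an illustration, not part of the commit; it assumes the post-change C++ API and standard LibTorch headers):

    #include <torch/torch.h>

    #include <fstream>
    #include <vector>

    int main() {
      // Any IValue can be saved; a tensor is the common case.
      at::Tensor t = torch::ones({2, 2});

      // torch::pickle_save now owns the torch.save-style framing (magic
      // number, protocol version, sys_info) and inlines the tensor bytes.
      std::vector<char> data = torch::pickle_save(t);

      std::ofstream out("tensor.pkl", std::ios::binary);
      out.write(data.data(), data.size());
    }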
@@ -521,6 +521,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp
+    ${TORCH_SRC_DIR}/csrc/api/src/serialize.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
@@ -184,6 +184,7 @@ def add_torch_libs():
         "torch/csrc/api/src/data/samplers/sequential.cpp",
         "torch/csrc/api/src/data/samplers/stream.cpp",
         "torch/csrc/api/src/jit.cpp",
+        "torch/csrc/api/src/serialize.cpp",
         "torch/csrc/api/src/nn/init.cpp",
         "torch/csrc/api/src/nn/module.cpp",
         "torch/csrc/api/src/nn/modules/batchnorm.cpp",
@@ -2,6 +2,7 @@

 #include <torch/serialize/archive.h>
 #include <torch/serialize/tensor.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>

 #include <utility>

@@ -72,6 +73,8 @@ void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
   archive.save_to(std::forward<SaveToArgs>(args)...);
 }

+TORCH_API std::vector<char> pickle_save(const torch::IValue& ivalue);
+
 /// Deserializes the given `value`.
 /// There must be an overload of `operator>>` between `serialize::InputArchive`
 /// and `Value` for this method to be well-formed. Currently, such an overload
torch/csrc/api/src/serialize.cpp (new file, 13 lines)
@@ -0,0 +1,13 @@
+#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
+#include <torch/serialize.h>
+
+#include <vector>
+
+namespace torch {
+
+std::vector<char> pickle_save(const at::IValue& ivalue) {
+  return jit::pickle_save(ivalue);
+}
+
+} // namespace torch
@@ -1,33 +1,24 @@
-#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/core/ivalue.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/pickle.h>
 #include <torch/csrc/jit/pickler.h>

 namespace torch {
 namespace jit {

+// These are both defined in `torch/serialization.py`
+const char* torch_save_magic_number =
+    "\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19";
+uint16_t protocol_version = 1001;
+
 void pickle(
-    std::function<void(const char*, size_t)> writer,
+    std::function<void(const char* data_start, size_t data_len)> writer,
     const IValue& ivalue,
     std::vector<at::Tensor>* tensor_table) {
   Pickler pickler(std::move(writer), tensor_table);

-  if (tensor_table == nullptr) {
-    // No tensor table provided, so tensors will be stored directly in the blob.
-    // Add torch.save metadata so these tensors can be de-serialized later
-    pickler.torchSaveStart();
-  }
-
   pickler.protocol();
   pickler.pushIValue(ivalue);
   pickler.stop();
-
-  if (tensor_table == nullptr) {
-    // No tensor table provided, so tensors will be stored directly in the blob.
-    // Add torch.save metadata so these tensors can be de-serialized later
-    pickler.torchSaveStop();
-  }
 }

 std::vector<char> pickle(
@@ -45,6 +36,63 @@ std::vector<char> pickle(
   return data;
 }

+// This has to live here instead of the C++ API to mirror torch.save since the
+// mobile build excludes the C++ API
+std::vector<char> pickle_save(const at::IValue& ivalue) {
+  std::vector<char> data;
+
+  auto writer = [&](const char* bytes, size_t len) {
+    data.insert(data.end(), bytes, bytes + len);
+  };
+
+  jit::Pickler pickler(writer, /*tensor_table=*/nullptr);
+  // Output data to match torch.save, see torch/serialization.py for details
+  // Magic number (0x1950a86a20f9469cfc6c)
+  pickler.protocol();
+  pickler.pushLong(torch_save_magic_number);
+  pickler.stop();
+
+  // Protocol Version
+  pickler.protocol();
+  pickler.pushInt(protocol_version);
+  pickler.stop();
+
+  // sys_info, this isn't actually used in de-serialization so we can leave this
+  // one empty
+  pickler.protocol();
+  IValue dict_ivalue =
+      c10::impl::GenericDict(c10::impl::deprecatedUntypedDict());
+  pickler.pushDict(dict_ivalue);
+  pickler.stop();
+
+  jit::Pickler data_pickler(writer, /*tensor_table=*/nullptr);
+  data_pickler.protocol();
+  data_pickler.pushIValue(ivalue);
+  data_pickler.stop();
+
+  auto writeable_tensors = data_pickler.tensorData();
+
+  std::vector<at::IValue> keys;
+  keys.reserve(writeable_tensors.size());
+  std::vector<at::TypePtr> types(writeable_tensors.size(), at::StringType::get());
+
+  for (size_t i = 0; i < writeable_tensors.size(); i++) {
+    keys.emplace_back(std::to_string(i));
+  }
+
+  auto keys_tuple = at::ivalue::Tuple::create(keys, at::TupleType::create(types));
+  jit::pickle(writer, keys_tuple);
+
+  for (const auto& tensor_data : writeable_tensors) {
+    const char* addr = tensor_data.data();
+    size_t numel = tensor_data.numel();
+    writer(reinterpret_cast<const char*>(&numel), sizeof(numel));
+    writer(addr, tensor_data.sizeInBytes());
+  }
+
+  return data;
+}
+
 IValue unpickle(
     std::function<bool(char*, size_t)> reader,
     ClassResolver class_resolver,
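Taken together with the `pickle()` hunk above, this gives two serialization modes; a hedged sketch contrasting them (assuming the post-change API, for a value that may contain tensors):

    #include <torch/csrc/jit/pickle.h>

    #include <vector>

    void demo(const at::IValue& value) {
      // Tensor-table mode: the pickle bytes reference tensors by index only;
      // the payloads land in `table` and travel separately.
      std::vector<at::Tensor> table;
      std::vector<char> referenced = torch::jit::pickle(value, &table);

      // Self-contained mode: torch.save framing plus inlined tensor bytes,
      // so the single buffer round-trips on its own.
      std::vector<char> standalone = torch::jit::pickle_save(value);
    }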
@@ -1,3 +1,5 @@
 #pragma once

+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/pickler.h>
@@ -6,6 +8,17 @@
 namespace torch {
 namespace jit {

+/// Pickle an IValue by calling a function to handle writing the data.
+///
+/// `writer` is a function that takes in a pointer to a chunk of memory and its
+/// size and consumes it.
+///
+/// See `jit::pickle` for more details.
+TORCH_API void pickle(
+    std::function<void(const char* data_start, size_t data_len)> writer,
+    const IValue& ivalue,
+    std::vector<at::Tensor>* tensor_table = nullptr);
+
 /// Save a `torch::IValue` in a format compatible with Python's `pickle` module
 ///
 /// If present, `tensor_table` is a pointer to a table in which tensors that
@@ -38,16 +51,9 @@ TORCH_API std::vector<char> pickle(
     const IValue& ivalue,
     std::vector<at::Tensor>* tensor_table = nullptr);

-/// Pickle an IValue by calling a function to handle writing the data.
-///
-/// `writer` is a function that takes in a pointer to a chunk of memory and its
-/// size and consumes it.
-///
-/// See `jit::pickle` for more details.
-TORCH_API void pickle(
-    std::function<void(const char* data_start, size_t data_len)> writer,
-    const IValue& ivalue,
-    std::vector<at::Tensor>* tensor_table = nullptr);
+TORCH_API std::vector<char> pickle_save(const IValue& ivalue);


 /// `reader` is a function that takes in a size to read from some pickled
 /// binary. `reader` should remember where it last read, and return
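The writer-based overload declared above leaves buffer ownership with the caller; a small sketch (assumptions: post-change header, a byte-vector sink, and a value without tensors, since no tensor table is passed):

    #include <torch/csrc/jit/pickle.h>

    #include <vector>

    std::vector<char> pickle_to_buffer(const at::IValue& value) {
      std::vector<char> buf;
      // The pickler calls the writer with each chunk it produces.
      torch::jit::pickle(
          [&buf](const char* data_start, size_t data_len) {
            buf.insert(buf.end(), data_start, data_start + data_len);
          },
          value);
      return buf;
    }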
@@ -94,56 +94,6 @@ void Pickler::stop() {
   push<PickleOpCode>(PickleOpCode::STOP);
 }

-void Pickler::torchSaveStop() {
-  // Add the binary data for all the tensors to be included in the same binary
-  // TODO: The pickler should be refactored to stream out to a stream directly
-  // instead of staging in the stack_ array
-  // As another pickle program in the same binary archive, add a list of
-  // keys for each tensor (see torch/serialization.py)
-  protocol();
-  push<PickleOpCode>(PickleOpCode::MARK);
-  for (size_t i = 0; i < tensor_data_.size(); ++i) {
-    std::string key = std::to_string(i);
-    push<PickleOpCode>(PickleOpCode::BINUNICODE);
-    push<uint32_t>(key.size());
-    pushBytes(key);
-  }
-
-  push<PickleOpCode>(PickleOpCode::TUPLE);
-  stop();
-
-  // Now dump the tensor binary data
-  for (const auto& data : tensor_data_) {
-    // first dump size
-    push<size_t>(data.numel());
-    writer_(data.data(), data.sizeInBytes());
-  }
-}
-
-void Pickler::torchSaveStart() {
-  // Output data to match torch.save, see torch/serialization.py for details
-  // Magic number (0x1950a86a20f9469cfc6c)
-  protocol();
-  push<PickleOpCode>(PickleOpCode::LONG1);
-  // LONG1 size
-  pushBytes("\x0a");
-  // LONG1 data
-  pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19");
-  stop();
-
-  // Protocol Version (1001)
-  protocol();
-  push<PickleOpCode>(PickleOpCode::BININT2);
-  pushBytes("\xe9\x03");
-  stop();
-
-  // sys_info, this isn't actually used in de-serialization so we can leave this
-  // one empty
-  protocol();
-  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
-  stop();
-}
-
 // unmemoized version called by pushIValue
 void Pickler::pushIValueImpl(const IValue& ivalue) {
   if (ivalue.isTensor()) {
@@ -332,6 +282,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
   push<PickleOpCode>(PickleOpCode::TUPLE);
   push<PickleOpCode>(PickleOpCode::BINPERSID);

+  // TODO: Skip this if not writing tensors
   memoized_storage_map_[addr] = pushNextBinPut();
   tensor_data_.push_back(getWriteableTensorData(tensor));
 }
@@ -426,19 +377,6 @@ void Pickler::pushClass(PicklerClass cls) {
   pushGlobal("torch.jit._pickle", getClassName(cls));
 }

-void Pickler::pushTensorReference(const IValue& ivalue) {
-  pushClass(PicklerClass::TENSOR);
-  tensor_table_->push_back(ivalue.toTensor());
-  int64_t tensor_id = tensor_table_->size() - 1;
-  // Reduce arguments are spread (e.g. `*args`) before calling the global,
-  // so wrap in a tuple
-  push<PickleOpCode>(PickleOpCode::MARK);
-  pushIValue(tensor_id);
-  push<PickleOpCode>(PickleOpCode::TUPLE);
-
-  push<PickleOpCode>(PickleOpCode::REDUCE);
-}
-
 void Pickler::pushSpecializedList(
     const IValue& ivalue,
     PicklerClass cls,
@@ -476,13 +414,46 @@ void Pickler::pushDouble(double value) {
   }
 }

+void Pickler::pushLong(const std::string& data) {
+  uint64_t size = data.size();
+
+  if (size <= std::numeric_limits<uint8_t>::max()) {
+    push<PickleOpCode>(PickleOpCode::LONG1);
+    push<uint8_t>(size);
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        data.size() <= std::numeric_limits<uint32_t>::max(),
+        "Cannot pickle a long with a size larger than 4 bytes")
+    push<PickleOpCode>(PickleOpCode::LONG4);
+    push<uint64_t>(size);
+  }
+  pushBytes(data);
+}
+
+void Pickler::pushTensorReference(const IValue& ivalue) {
+  pushClass(PicklerClass::TENSOR);
+  tensor_table_->push_back(ivalue.toTensor());
+  int64_t tensor_id = tensor_table_->size() - 1;
+  // Reduce arguments are spread (e.g. `*args`) before calling the global,
+  // so wrap in a tuple
+  push<PickleOpCode>(PickleOpCode::MARK);
+  pushIValue(tensor_id);
+  push<PickleOpCode>(PickleOpCode::TUPLE);
+
+  push<PickleOpCode>(PickleOpCode::REDUCE);
+}
+
 void Pickler::pushDict(const IValue& ivalue) {
   push<PickleOpCode>(PickleOpCode::EMPTY_DICT);

+  auto dict_items = iterationOrder(ivalue.toGenericDict());
+  if (dict_items.size() == 0) {
+    return;
+  }
+
   push<PickleOpCode>(PickleOpCode::MARK);

-  // Sort the dict for deterministic keys
-  auto dict_items = iterationOrder(ivalue.toGenericDict());
   for (const auto& pair : dict_items) {
     pushIValue(pair.first);
     pushIValue(pair.second);
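`pushLong` generalizes the byte sequence the removed `torchSaveStart` wrote by hand for the magic number: pickle's LONG1 opcode, one length byte, then the little-endian magnitude. A standalone sketch of that encoding (illustrative only; `long1_encode` is not a PyTorch function):

    #include <string>
    #include <vector>

    // Encode a small arbitrary-precision integer the way pickle's LONG1
    // opcode does: opcode byte 0x8a, a one-byte length, then the magnitude
    // in little-endian order (0x0a + 10 bytes for the torch.save magic).
    std::vector<char> long1_encode(const std::string& data) {
      std::vector<char> out;
      out.push_back('\x8a');
      out.push_back(static_cast<char>(data.size()));
      out.insert(out.end(), data.begin(), data.end());
      return out;
    }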
@@ -760,6 +731,9 @@ PickleOpCode Unpickler::readInstruction() {
       stack_.pop_back();
       switch (pickler_class) {
         case PicklerClass::TENSOR:
+          TORCH_INTERNAL_ASSERT(
+              tensor_table_,
+              "Pickler tried to write a tensor but had no tensor table to write to");
           stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
           break;
         case PicklerClass::INTLIST:
@@ -779,7 +753,7 @@ PickleOpCode Unpickler::readInstruction() {
       case PicklerClass::TENSOR:
         TORCH_CHECK(
             tensor_table_,
-            "Found a tensor table reference but Pickler"
+            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_->at(data.toInt()));
        break;
@@ -127,7 +127,7 @@ class Pickler {
  public:
   Pickler(
       std::function<void(const char*, size_t)> writer,
-      std::vector<at::Tensor>* tensor_table = nullptr)
+      std::vector<at::Tensor>* tensor_table)
       : writer_(writer), tensor_table_(tensor_table) {}

   // Push protocol onto the stack
@@ -138,12 +138,6 @@ class Pickler {

   void pushIValue(const IValue& ivalue);

-  // See torch/serialization.py for details, pushes a magic number, torch
-  // serialization version, and system info to the pickle archive all as
-  // individual pickle programs
-  void torchSaveStart();
-  void torchSaveStop();
-
   void startTuple();
   void endTuple();
@@ -151,17 +145,19 @@ class Pickler {
     return tensor_data_;
   }

+  void pushDict(const IValue& ivalue);
+  void pushInt(int64_t value);
+  void pushLong(const std::string& data);
+
  private:
   void pushIValueImpl(const IValue& ivalue);
-  void pushDict(const IValue& ivalue);
   void pushDouble(double value);
   void pushGenericList(const IValue& ivalue);
-  void pushInt(int64_t value);
   void pushIntList(const IValue& ivalue);
   void pushList(const IValue& ivalue);
+  void pushLiteralTensor(const IValue& ivalue);
   void pushTensor(const IValue& ivalue);
   void pushTensorReference(const IValue& ivalue);
-  void pushLiteralTensor(const IValue& ivalue);
   void pushTuple(const IValue& ivalue);
   void pushString(const std::string& string);
   // unmemoized version
@@ -745,7 +745,7 @@ RegisterOperators reg(
           auto ivalue = pop(stack);

           // Pickle the tensor
-          auto data = pickle({ivalue});
+          auto data = jit::pickle_save(ivalue);

           // Write file
           std::fstream output(filename, std::ios::out | std::ios::binary);
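For completeness, the surrounding operator body (mostly elided by the diff view) follows a pickle-then-write pattern; roughly, with `save_to_file` as an assumed illustrative name:

    #include <torch/csrc/jit/pickle.h>

    #include <fstream>
    #include <string>
    #include <vector>

    void save_to_file(const at::IValue& ivalue, const std::string& filename) {
      // Pickle the value into a self-contained torch.save-style buffer...
      std::vector<char> data = torch::jit::pickle_save(ivalue);
      // ...then write it out in binary mode, as the op above does.
      std::fstream output(filename, std::ios::out | std::ios::binary);
      output.write(data.data(), data.size());
    }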