Load tensors directly from pickle archive

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23281

Test Plan: Imported from OSS

Differential Revision: D16452815

Pulled By: zdevito

fbshipit-source-id: 918eef3ad444b598ab655c39037e4baafdcb51e1
Author: Zachary DeVito
Date: 2019-08-22 11:44:53 -07:00
Committed by: Facebook Github Bot
parent c33adf539c
commit e2ccccee9a
15 changed files with 352 additions and 176 deletions
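
Taken together, the changes replace the protobuf-based model.json tensor table with self-describing pickle archives inside the zip container. A model saved with the new writer lays out roughly as follows (a sketch reconstructed from the writeArchive, writeCode, and writeExtraFiles hunks below; the archive name my_model.pt and the type foo.bar.Baz are illustrative):

my_model.pt                # a zip file written by PyTorchStreamWriter
    code/foo/bar.py        # serialized source defining the class Baz
    constants.pkl          # pickled tuple of tensor constants used by the code
    constants/0            # raw storage bytes of constant tensor 0
    data.pkl               # pickled module object (attributes, parameters, submodules)
    data/0                 # raw storage bytes of a tensor referenced from data.pkl
    extra/<name>           # user-provided extra files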

View File

@@ -165,7 +165,7 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
return buf;
}
bool PyTorchStreamReader::hasFile(const std::string& name) {
bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::stringstream ss;
ss << archive_name_ << "/" << name;
mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
@@ -177,7 +177,7 @@ bool PyTorchStreamReader::hasFile(const std::string& name) {
return result;
}
size_t PyTorchStreamReader::getFileID(const std::string& name) {
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
std::stringstream ss;
ss << archive_name_ << "/" << name;
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
@@ -190,7 +190,7 @@ size_t PyTorchStreamReader::getFileID(const std::string& name) {
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
size_t key = getFileID(name);
size_t key = getRecordID(name);
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), key, &stat);
valid("retrieving file meta-data");
@@ -208,7 +208,7 @@ static int64_t read_le_16(uint8_t* buf) {
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
valid("retriving file meta-data");
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
in_->read(

View File

@@ -92,9 +92,6 @@ namespace serialize {
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
// Writer-specific constants
constexpr uint64_t kFileFormatVersion = 0x2L;
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;
@@ -107,7 +104,7 @@ class CAFFE2_API PyTorchStreamReader final {
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
size_t getRecordOffset(const std::string& name);
bool hasFile(const std::string& name);
bool hasRecord(const std::string& name);
~PyTorchStreamReader();
@@ -115,7 +112,7 @@ class CAFFE2_API PyTorchStreamReader final {
void init();
size_t read(uint64_t pos, char* buf, size_t n);
void valid(const char* what);
size_t getFileID(const std::string& name);
size_t getRecordID(const std::string& name);
friend size_t
istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
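
Caller-side, only the names change; a minimal sketch of the renamed API, assuming iss is a std::istringstream already holding a written archive (compare the updated test below):

caffe2::serialize::PyTorchStreamReader reader(&iss);
if (reader.hasRecord("key1")) {   // formerly hasFile()
  at::DataPtr ptr;                // owns the record's bytes
  size_t size;                    // record length in bytes
  std::tie(ptr, size) = reader.getRecord("key1");
}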

View File

@@ -39,9 +39,9 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
// read records through readers
PyTorchStreamReader reader(&iss);
ASSERT_TRUE(reader.hasFile("key1"));
ASSERT_TRUE(reader.hasFile("key2"));
ASSERT_FALSE(reader.hasFile("key2000"));
ASSERT_TRUE(reader.hasRecord("key1"));
ASSERT_TRUE(reader.hasRecord("key2"));
ASSERT_FALSE(reader.hasRecord("key2000"));
at::DataPtr data_ptr;
int64_t size;
std::tie(data_ptr, size) = reader.getRecord("key1");

View File

@@ -9036,6 +9036,7 @@ a")
m_import = self.getExportImportCopy(m_orig)
self.assertEqual(m_orig.foo(), m_import.foo())
self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())

View File

@@ -55,7 +55,7 @@ ScriptCall ScriptCall::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value = jit::unpickle(payload, payload_size, &message.tensors());
auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto values = value.toTuple()->elements();

View File

@@ -29,7 +29,7 @@ Message ScriptRet::toMessage() {
ScriptRet ScriptRet::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value = jit::unpickle(payload, payload_size, &message.tensors());
auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors());
return ScriptRet(std::move(value));
}

View File

@@ -618,34 +618,70 @@ class ScriptModuleSerializer {
// 1. Several tests that depend on the serialization format details.
// 2. The emit module hook (since combining the old export and new import code
// is going to cause jitter)
class ScriptModuleSerializer2 : public ScriptModuleSerializer {
class ScriptModuleSerializer2 {
public:
ScriptModuleSerializer2(const std::string& filename)
: ScriptModuleSerializer(filename) {}
: writer_(filename.c_str()) {}
ScriptModuleSerializer2(std::ostream* ofs) : ScriptModuleSerializer(ofs) {}
private:
void convertModel(
ScriptModuleSerializer2(std::ostream* ofs) : ofs_(), writer_(ofs) {}
void serialize(
const script::Module& module,
torch::ModelDef* model_def,
const script::ExtraFilesMap& extra_files) override {
model_def->set_producer_name("pytorch");
model_def->set_producer_version("1.0"); // TODO: set the producer version
// using appropriate function call
model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
// Serialize all code info.
convertNamedType(module.type());
// Then pickle the module
auto data = pickle(module.module_object(), &tensor_table_);
writer_.writeRecord("data.pkl", data.data(), data.size());
writeTensorTable(model_def);
writeLibs(model_def);
const script::ExtraFilesMap& extra_files) {
C10_LOG_API_USAGE_ONCE("torch.script.save");
writeExtraFiles(module, extra_files);
// Serialize all code info.
writeCode(module.type());
// The tensor constants from the code are written to a separate archive
// so loading the code does not depend on loading the data
std::vector<IValue> ivalue_constants(
constant_table_.begin(), constant_table_.end());
writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
// finally we serialize the model
writeArchive("data", module.module_object());
}
void writeLibs(ModelDef* model_def) override {
private:
void writeArchive(const std::string& archive_name, const IValue& value) {
std::vector<char> data;
Pickler data_pickle(
[&](const char* buf, size_t size) {
data.insert(data.end(), buf, buf + size);
},
nullptr);
data_pickle.protocol();
data_pickle.pushIValue(value);
data_pickle.stop();
size_t i = 0;
for (const auto& td : data_pickle.tensorData()) {
std::stringstream fname;
fname << archive_name << "/" << i++;
writer_.writeRecord(fname.str(), td.data(), td.sizeInBytes());
}
std::stringstream fname;
fname << archive_name << ".pkl";
writer_.writeRecord(fname.str(), data.data(), data.size());
}
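// For example (a sketch): writeArchive("data", module.module_object()) with two
// distinct tensor storages emits three records, named by the loop above:
//   data/0     raw storage bytes of the first tensor
//   data/1     raw storage bytes of the second tensor
//   data.pkl   the pickle stream, which refers to tensors only by the
//              persistent-id keys "0" and "1"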
void writeExtraFiles(
const script::Module& module,
const script::ExtraFilesMap& extra_files) {
// Write out extra files.
for (const auto& kv : extra_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
auto hook = GetExtraFilesHook();
if (hook) {
script::ExtraFilesMap hook_files = hook(module);
for (const auto& kv : hook_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
}
}
void writeCode(const at::NamedTypePtr& root_type) {
convertNamedType(root_type);
static const std::string opset_string =
c10::str("op_version_set = ", CURRENT_OP_VERSION_SET, "\n");
@@ -661,7 +697,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
// For the type, foo.bar.Baz
const std::string filename = ImportExportHelpers::qualifierToPath(
converted_type->name()->prefix(), torch::PROTO_VERSION_NEWEST);
converted_type->name()->prefix(), "code/");
// End state: filename is "foo/bar.py", in which we will define a class
// named Baz
auto& stream = fileToSrc[filename];
@@ -708,7 +744,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
/*compress=*/true);
}
}
void convertNamedType(const c10::NamedTypePtr& class_type) override {
void convertNamedType(const c10::NamedTypePtr& class_type) {
if (converted_types_.contains(class_type)) {
return;
}
@@ -720,7 +756,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
source_stream,
source_ranges,
class_type,
tensor_table_,
constant_table_,
class_deps,
/*enforce_importable=*/true);
@@ -739,6 +775,10 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
converted_types_.insert(class_type, std::move(info));
}
std::ofstream ofs_;
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
// all deps used by this module hierarchy
struct TypeInfo {
std::string source;
@@ -804,8 +844,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
const auto& class_src = item.value();
// For the type, foo.bar.Baz
const std::string filename =
ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5);
const std::string filename = ImportExportHelpers::qualifierToPath(
class_type->name()->prefix(), "libs/");
// End state: filename is "foo/bar.py", in which we will define a class
// named Baz
fileToSrc[filename] << class_src;
@@ -816,8 +856,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
std::unordered_set<std::string> written_files;
for (const auto& item : converted_classes_) {
const c10::NamedTypePtr& class_type = item.key();
const std::string filename =
ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5);
const std::string filename = ImportExportHelpers::qualifierToPath(
class_type->name()->prefix(), "libs/");
if (written_files.count(filename)) {
continue;
}

View File

@@ -68,34 +68,34 @@ class ScriptModuleDeserializer final {
script::ExtraFilesMap& extra_files);
private:
at::Tensor loadTensor(
IValue readArchive(const std::string& archive_name);
script::Module LEGACY_deserialize();
at::Tensor LEGACY_loadTensor(
const torch::TensorDef& tensor_proto,
std::unordered_map<std::string, at::Storage>& storageMap);
void loadTensorTable(torch::ModelDef* model_def);
void LEGACY_loadTensorTable(torch::ModelDef* model_def);
void importCallback(const std::string& qualifier);
void moduleSetState(const script::Module& module, IValue state);
void LEGACY_moduleSetState(const script::Module& module, IValue state);
std::shared_ptr<script::CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<std::string> moduleStack_;
std::vector<std::string> LEGACY_moduleStack_;
std::vector<at::Tensor> tensor_table_;
std::vector<at::Tensor> constants_table_;
std::unordered_set<std::string> imported_libs_;
IValue LEGACY_loadPickleArchive(const std::string& name);
script::Module LEGACY_convertModule(const torch::ModuleDef& module_def);
std::vector<IValue> LEGACY_pickled_ivalues_;
size_t proto_version_;
std::string export_prefix_ = "code/";
};
script::Module ScriptModuleDeserializer::deserialize(
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
C10_LOG_API_USAGE_ONCE("torch.script.load");
script::Module ScriptModuleDeserializer::LEGACY_deserialize() {
torch::ModelDef model_def;
at::DataPtr data_ptr;
size_t data_size;
std::tie(data_ptr, data_size) = reader_->getRecord("model.json");
@@ -125,13 +125,69 @@ script::Module ScriptModuleDeserializer::deserialize(
AT_ASSERTM(
model_def.ParseFromString(binary_string),
"JSON transcoder produced invalid protobuf output.");
device_ = device;
proto_version_ = model_def.proto_version();
auto proto_version = model_def.proto_version();
export_prefix_ = "libs/";
LEGACY_loadTensorTable(&model_def);
AT_ASSERT(proto_version < 6);
if (proto_version == 2) {
const auto& list =
LEGACY_loadPickleArchive("attributes.pkl").toGenericList();
LEGACY_pickled_ivalues_.insert(
LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
} else if (proto_version >= 3) {
LEGACY_pickled_ivalues_ =
LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
}
LEGACY_moduleStack_.push_back("__torch__");
const auto& module_def = model_def.main_module();
return LEGACY_convertModule(module_def);
}
IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
std::stringstream picklename;
picklename << archive_name << ".pkl";
at::DataPtr pickle_ptr;
size_t pickle_size;
std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str());
size_t bytes_read = 0;
auto data = reinterpret_cast<const char*>(pickle_ptr.get());
auto reader = [&](char* buffer, size_t len) {
if (bytes_read + len > pickle_size) {
return false;
}
// Copy len bytes into buffer
const char* start = data + bytes_read;
std::memcpy(buffer, start, len);
bytes_read += len;
return true;
};
auto class_resolver = [&](const c10::QualifiedName& qn) {
importCallback(qn.prefix());
return c10::StrongTypePtr(
compilation_unit_, compilation_unit_->get_class(qn));
};
auto read_record = [&](const std::string& name) {
std::stringstream ss;
ss << archive_name << "/" << name;
return std::get<0>(reader_->getRecord(ss.str()));
};
Unpickler unpickler(
reader, std::move(class_resolver), std::move(read_record), device_);
return unpickler.parse_ivalue();
}
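// Note the division of labor above: `reader` streams bytes out of
// <archive_name>.pkl, while `read_record` resolves a persistent-id key such
// as "0" back to the raw storage record <archive_name>/0. A sketch of the
// flow driven by deserialize() below:
//   IValue constants = readArchive("constants"); // constants.pkl + constants/<n>
//   IValue data      = readArchive("data");      // data.pkl + data/<n>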
script::Module ScriptModuleDeserializer::deserialize(
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
C10_LOG_API_USAGE_ONCE("torch.script.load");
device_ = device;
// Load extra files.
for (const auto& kv : extra_files) {
const std::string& key = "extra/" + kv.first;
if (reader_->hasFile(key)) {
if (reader_->hasRecord(key)) {
at::DataPtr meta_ptr;
size_t meta_size;
std::tie(meta_ptr, meta_size) = reader_->getRecord(key);
@@ -139,48 +195,14 @@ script::Module ScriptModuleDeserializer::deserialize(
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
}
}
loadTensorTable(&model_def);
if (proto_version_ < 6) {
if (proto_version_ == 2) {
const auto& list =
LEGACY_loadPickleArchive("attributes.pkl").toGenericList();
LEGACY_pickled_ivalues_.insert(
LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
} else if (proto_version_ >= 3) {
LEGACY_pickled_ivalues_ =
LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
}
moduleStack_.push_back("__torch__");
const auto& module_def = model_def.main_module();
return LEGACY_convertModule(module_def);
} else {
at::DataPtr pickle_ptr;
size_t pickle_size;
std::tie(pickle_ptr, pickle_size) = reader_->getRecord("data.pkl");
size_t bytes_read = 0;
auto data = reinterpret_cast<const char*>(pickle_ptr.get());
auto reader = [&](char* buffer, size_t len) {
if (bytes_read + len > pickle_size) {
return false;
}
// Copy len bytes into buffer
const char* start = data + bytes_read;
std::memcpy(buffer, start, len);
bytes_read += len;
return true;
};
Unpickler unpickler(
reader, &tensor_table_, [&](const c10::QualifiedName& qn) {
importCallback(qn.prefix());
return c10::StrongTypePtr(
compilation_unit_, compilation_unit_->get_class(qn));
});
return script::Module(unpickler.parseModule().toObject());
if (reader_->hasRecord("model.json")) {
return LEGACY_deserialize();
}
auto tuple = readArchive("constants").toTuple();
for (auto constant : tuple->elements()) {
constants_table_.push_back(constant.toTensor());
}
return script::Module(readArchive("data").toObject());
}
IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive(
@@ -191,23 +213,24 @@ IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive(
auto ivalue = unpickle(
reinterpret_cast<const char*>(attributes_ptr.get()),
attributes_size,
&tensor_table_,
[&](const c10::QualifiedName& qn) {
importCallback(qn.prefix());
return c10::StrongTypePtr(
compilation_unit_, compilation_unit_->get_class(qn));
});
},
&constants_table_);
return ivalue;
}
void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
void ScriptModuleDeserializer::LEGACY_loadTensorTable(
torch::ModelDef* model_def) {
std::unordered_map<std::string, at::Storage> storageMap;
for (const torch::TensorDef& tensor : model_def->tensors()) {
tensor_table_.emplace_back(loadTensor(tensor, storageMap));
constants_table_.emplace_back(LEGACY_loadTensor(tensor, storageMap));
}
}
at::Tensor ScriptModuleDeserializer::loadTensor(
at::Tensor ScriptModuleDeserializer::LEGACY_loadTensor(
const torch::TensorDef& tensor_proto,
std::unordered_map<std::string, at::Storage>& storageMap) {
std::vector<int64_t> dims(
@@ -301,17 +324,18 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
std::function<void(const std::string&)> import_callback =
[this](const std::string& qualifier) { importCallback(qualifier); };
const std::string path =
ImportExportHelpers::qualifierToPath(qualifier, proto_version_);
ImportExportHelpers::qualifierToPath(qualifier, export_prefix_);
at::DataPtr data;
size_t size;
std::tie(data, size) = reader_->getRecord(path);
std::shared_ptr<ConcreteSourceRangeUnpickler> gen_ranges = nullptr;
if (proto_version_ >= 6) {
std::string debug_file = path + ".debug_pkl";
if (reader_->hasRecord(debug_file)) {
at::DataPtr debug_data;
size_t debug_size;
std::tie(debug_data, debug_size) = reader_->getRecord(path + ".debug_pkl");
std::tie(debug_data, debug_size) = reader_->getRecord(debug_file);
gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(
std::move(debug_data), debug_size);
}
@@ -323,10 +347,10 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
gen_ranges);
script::import_libs(
compilation_unit_, qualifier, src, tensor_table_, import_callback);
compilation_unit_, qualifier, src, constants_table_, import_callback);
}
void ScriptModuleDeserializer::moduleSetState(
void ScriptModuleDeserializer::LEGACY_moduleSetState(
const script::Module& module,
IValue state) {
auto setstate = module.find_method("__setstate__");
@@ -359,10 +383,10 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
const auto atoms = c10::QualifiedName(module_def.name()).atoms();
const size_t numPushed = atoms.size();
for (const auto& atom : atoms) {
moduleStack_.emplace_back(atom);
LEGACY_moduleStack_.emplace_back(atom);
}
auto module =
script::Module(c10::QualifiedName(moduleStack_), compilation_unit_);
auto module = script::Module(
c10::QualifiedName(LEGACY_moduleStack_), compilation_unit_);
for (int i = 0; i < module_def.submodules_size(); ++i) {
const torch::ModuleDef& sub_def = module_def.submodules(i);
auto submodule = LEGACY_convertModule(sub_def);
@@ -370,7 +394,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
}
for (int i = 0; i < module_def.parameters_size(); ++i) {
const torch::ParameterDef& param_def = module_def.parameters(i);
at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
at::Tensor tensor = constants_table_.at(param_def.tensor_id());
if (param_def.is_buffer()) {
module.register_buffer(param_def.name(), tensor);
} else {
@@ -425,11 +449,12 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
std::function<void(const std::string&)> import_callback =
[&, this](const std::string& qualifier) { importCallback(qualifier); };
script::LEGACY_import_methods(module, src, tensor_table_, import_callback);
script::LEGACY_import_methods(
module, src, constants_table_, import_callback);
}
if (module_def.has_get_state_attribute_id()) {
moduleSetState(
LEGACY_moduleSetState(
module,
LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
}
@@ -450,7 +475,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
}
for (size_t i = 0; i < numPushed; i++) {
moduleStack_.pop_back();
LEGACY_moduleStack_.pop_back();
}
return module;
}

View File

@@ -7,36 +7,30 @@ namespace torch {
namespace jit {
namespace ImportExportHelpers {
static const std::string exportPrefix = "code/";
static const std::string LEGACY_exportPrefix = "libs/";
static const std::string kExportSuffix = "py";
std::string qualifierToPath(
const std::string& qualifier,
size_t proto_version) {
const auto kExportPrefix =
proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix;
const std::string& export_prefix) {
std::string path = qualifier;
std::replace_if(
path.begin(), path.end(), [](char c) { return c == '.'; }, '/');
return kExportPrefix + path + "." + kExportSuffix;
return export_prefix + path + "." + kExportSuffix;
}
std::string pathToQualifier(
const std::string& classPath,
size_t proto_version) {
const auto kExportPrefix =
proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix;
const std::string& export_prefix) {
// strip input suffix
const auto end = classPath.rfind(kExportSuffix);
AT_ASSERT(end != std::string::npos);
// strip input suffix
size_t libs_idx = classPath.find(kExportPrefix);
size_t libs_idx = classPath.find(export_prefix);
AT_ASSERT(libs_idx == 0);
AT_ASSERT(classPath.size() > kExportPrefix.size());
const auto start = kExportPrefix.size();
AT_ASSERT(classPath.size() > export_prefix.size());
const auto start = export_prefix.size();
std::string class_qualifier = classPath.substr(start, end - start);
std::replace_if(

View File

@@ -13,13 +13,17 @@ namespace ImportExportHelpers {
//
// Qualifier is like: foo.bar.baz
// Returns: libs/foo/bar/baz.py
std::string qualifierToPath(const std::string& qualifier, size_t proto_version);
std::string qualifierToPath(
const std::string& qualifier,
const std::string& export_prefix);
// Convert a source file path to a class type's qualifier name.
//
// Path is like: libs/foo/bar/baz.py
// Returns: foo.bar.baz
std::string pathToQualifier(const std::string& classPath, size_t proto_version);
std::string pathToQualifier(
const std::string& classPath,
const std::string& export_prefix);
} // namespace ImportExportHelpers
} // namespace jit
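
With the proto-version switch gone, callers pass the prefix explicitly. A hedged example of the round trip, reusing the foo.bar qualifier from the comments above:

qualifierToPath("foo.bar", "code/");           // -> "code/foo/bar.py"
pathToQualifier("code/foo/bar.py", "code/");   // -> "foo.bar"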

View File

@@ -102,7 +102,8 @@ struct ConstantTableValue : public SugaredValue {
<< " is out of bounds (constant table has "
<< constants_.size() << " entries)";
}
Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc);
Value* value =
m.graph()->insertConstant(constants_.at(offset), nullptr, loc);
// specializing tensor type on compilation messes up typing relations
value->setType(unshapedType(value->type()));

View File

@@ -47,18 +47,18 @@ std::vector<char> pickle(
IValue unpickle(
std::function<bool(char*, size_t)> reader,
const std::vector<at::Tensor>* tensor_table,
ClassResolver class_resolver) {
ClassResolver class_resolver,
const std::vector<at::Tensor>* tensor_table) {
Unpickler unpickler(
std::move(reader), tensor_table, std::move(class_resolver));
std::move(reader), std::move(class_resolver), tensor_table);
return unpickler.parse_ivalue();
}
IValue unpickle(
const char* data,
size_t size,
const std::vector<at::Tensor>* tensor_table,
ClassResolver class_resolver) {
ClassResolver class_resolver,
const std::vector<at::Tensor>* tensor_table) {
size_t bytes_read = 0;
return unpickle(
[&](char* buffer, size_t len) {
@@ -71,8 +71,8 @@ IValue unpickle(
bytes_read += len;
return true;
},
tensor_table,
std::move(class_resolver));
std::move(class_resolver),
tensor_table);
}
} // namespace jit

View File

@@ -50,18 +50,13 @@ TORCH_API void pickle(
std::vector<at::Tensor>* tensor_table = nullptr);
/// `reader` is a function that takes in a size to read from some pickled
/// binary. `reader` should remember where it last read.
///
/// `bounds_checker` is a function that returns `true` if the reader can read
/// more data, and `false` if it cannot (i.e. if a stream has hit its end of
/// file)
///
/// binary. `reader` should remember where it last read, and return
/// false if the read was not successful.
/// See `torch::pickle` for details.
TORCH_API IValue unpickle(
std::function<const char*(size_t)> reader,
std::function<bool()> bounds_chcker,
const std::vector<at::Tensor>* tensor_table = nullptr,
ClassResolver class_resolver = nullptr);
std::function<bool(char*, size_t)> reader,
ClassResolver class_resolver,
const std::vector<at::Tensor>* tensor_table);
/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
///
@@ -72,8 +67,8 @@ TORCH_API IValue unpickle(
TORCH_API IValue unpickle(
const char* data,
size_t size,
const std::vector<at::Tensor>* tensor_table = nullptr,
ClassResolver class_resolver = nullptr);
ClassResolver class_resolver = nullptr,
const std::vector<at::Tensor>* tensor_table = nullptr);
} // namespace jit
} // namespace torch
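
The reordering lets most call sites omit the tensor table entirely; the RPC hunks above show the resulting idiomatic usage:

auto value = torch::jit::unpickle(
    payload, payload_size, /*class_resolver=*/nullptr, &message.tensors());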

View File

@@ -151,9 +151,9 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
} else if (ivalue.isTuple()) {
pushTuple(ivalue);
} else if (ivalue.isDouble()) {
pushDouble(ivalue);
pushDouble(ivalue.toDouble());
} else if (ivalue.isInt()) {
pushInt(ivalue);
pushInt(ivalue.toInt());
} else if (ivalue.isBool()) {
if (ivalue.toBool()) {
push<OpCode>(OpCode::NEWTRUE);
@@ -244,8 +244,7 @@ void Pickler::pushIValue(const IValue& ivalue) {
}
}
void Pickler::pushInt(const IValue& ivalue) {
auto n = ivalue.toInt();
void Pickler::pushInt(int64_t n) {
if (n >= std::numeric_limits<int8_t>::min() &&
n <= std::numeric_limits<int8_t>::max()) {
push<OpCode>(OpCode::BININT1);
@@ -311,9 +310,11 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
// root_key
pushString(std::to_string(tensor_data_.size()));
// location
pushString("cpu");
std::stringstream ss;
ss << tensor.device();
pushString(ss.str());
// size
pushInt(tensor.numel());
pushInt(tensor.storage().size());
// view_metadata
push<OpCode>(OpCode::NONE);
push<OpCode>(OpCode::TUPLE);
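
Pairing this write side with the new BINPERSID case in the Unpickler further down, the persistent-id tuple that crosses the pickle boundary is, roughly:

// ("storage",          tag checked by the BINPERSID handler
//  <torch.XStorage>,   storage-class global, mapped back to a ScalarType
//  <root_key>,         index into tensor_data_; also names the record "<archive>/<key>"
//  <device string>,    e.g. "cpu" or "cuda:0", now taken from tensor.device()
//  <numel>)            storage size, now tensor.storage().size()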
@@ -362,17 +363,18 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
// The format here is the same one used by `torch.save()`. The code for the
// format can be found in `torch/serialization.py`.
auto tensor = ivalue.toTensor();
bool quantized = tensor.is_quantized();
// The arguments to this function are:
// storage, storage_offset, size, stride, requires_grad, backward_hooks
pushGlobal("torch._utils", "_rebuild_tensor_v2");
pushGlobal(
"torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");
push<OpCode>(OpCode::MARK);
pushStorageOfTensor(tensor);
// storage offset
int64_t storage_offset = 0;
pushInt(storage_offset);
pushInt(tensor.storage_offset());
// size
push<OpCode>(OpCode::MARK);
@@ -388,6 +390,11 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
}
push<OpCode>(OpCode::TUPLE);
if (quantized) {
pushDouble(tensor.q_scale());
pushInt(tensor.q_zero_point());
}
// requires_grad
pushIValue(tensor.requires_grad());
@@ -447,8 +454,7 @@ void Pickler::pushSpecializedList(
push<OpCode>(OpCode::REDUCE);
}
void Pickler::pushDouble(const IValue& ivalue) {
double value = ivalue.toDouble();
void Pickler::pushDouble(double value) {
AT_ASSERT(sizeof(double) == 8);
char* bytes = reinterpret_cast<char*>(&value);
@@ -522,16 +528,6 @@ IValue Unpickler::parse_ivalue() {
return stack_[0];
}
IValue Unpickler::parseModule() {
run();
TORCH_CHECK(
stack_.size() == 1,
"Unpickler expected 1 element on the stack, but found ",
stack_.size());
return stack_[0];
}
double Unpickler::readFloat() {
AT_ASSERT(sizeof(double) == 8);
double big_endian = read<double>();
@@ -600,6 +596,12 @@ static IValue toSpecializedList(const IValue& generic) {
return IValue(std::move(specialized));
}
static std::vector<int64_t> tupleToIntList(const IValue& v) {
return fmap(v.toTuple()->elements(), [](const IValue& v) -> int64_t {
return v.toInt();
});
}
OpCode Unpickler::readInstruction() {
auto opcode = readOpCode();
switch (opcode) {
@@ -748,6 +750,63 @@ OpCode Unpickler::readInstruction() {
AT_ERROR("Unknown pickler class id");
}
});
} else if (
module_name == "torch._utils" &&
(class_name == "_rebuild_tensor_v2" ||
class_name == "_rebuild_qtensor")) {
bool quantized = class_name == "_rebuild_qtensor";
globals_.emplace_back([this, quantized] {
auto tup = pop(stack_).toTuple();
const auto& elements = tup->elements();
size_t idx = 0;
auto storage_tensor = elements.at(idx++).toTensor();
int64_t storage_offset = elements.at(idx++).toInt();
std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
double q_scale = 0.;
int64_t q_zero_point = 0;
if (quantized) {
q_scale = elements.at(idx++).toDouble();
q_zero_point = elements.at(idx++).toInt();
}
bool requires_grad = elements.at(idx++).toBool();
// elements[idx++] is empty backwards hooks
at::Tensor result = quantized
? at::_empty_affine_quantized(
{}, storage_tensor.options(), q_scale, q_zero_point)
: at::empty({0}, storage_tensor.options());
at::TensorImpl* impl = result.unsafeGetTensorImpl();
impl->set_storage(storage_tensor.storage());
impl->set_storage_offset(storage_offset);
impl->set_sizes_and_strides(size, stride);
result = autograd::make_variable(result, requires_grad);
stack_.push_back(std::move(result));
});
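// The argument tuple unpacked above mirrors what pushLiteralTensor emits
// (a sketch of the layout):
//   (storage, storage_offset, size, stride,
//    q_scale, q_zero_point,   <- present only for _rebuild_qtensor
//    requires_grad, backward_hooks)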
} else if (module_name == "collections" && class_name == "OrderedDict") {
globals_.emplace_back([this] {
// drop the Tuple that was argument to OrderedDict, and replace it
// with None. OrderedDicts only appear in tensor deserialization and
// their value is never used
stack_.back() = IValue();
});
} else if (module_name == "torch") {
c10::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name) \
if (class_name == #name "Storage") { \
scalar_type = c10::k##name; \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
// NOTE: this does not put a global into the global table,
// like the other branches here because no REDUCE or BUILD will
// be called on this value. Instead, we just put it on the stack
// and return early
AT_ASSERT(
scalar_type.has_value(),
"class name not understood: torch.",
class_name);
stack_.emplace_back(int64_t(*scalar_type));
return opcode;
} else {
AT_ASSERT(class_resolver_);
at::StrongTypePtr type =
@@ -797,12 +856,52 @@ OpCode Unpickler::readInstruction() {
// stack is: <functor_arg>
globals_.at(idx)();
} break;
default:
case OpCode::BINPERSID: {
auto args = pop(stack_).toTuple()->elements();
AT_ASSERT(
args.at(0).toStringRef() == "storage",
"unknown PERSID key ",
args.at(0).toStringRef());
at::ScalarType type = args.at(1).toScalarType();
const std::string& key = args.at(2).toStringRef();
at::Device device(args.at(3).toStringRef());
if (device_) {
device = *device_;
}
at::DataPtr storage_ptr = read_record_(key);
int64_t numel = args.at(4).toInt();
at::Storage storage(
at::CPU(type).typeMeta(),
numel,
std::move(storage_ptr),
/*allocator=*/nullptr,
/*resizable=*/false); // NB: we didn't set any allocator for the
// tensor
auto options = at::CPU(type).options();
at::Tensor tensor;
if (options.backend() == c10::Backend::QuantizedCPU) {
tensor = at::_empty_affine_quantized({}, options, 0, 0)
.set_(storage, 0, {}, {});
} else {
tensor = at::empty({0}, options).set_(storage);
}
if (device.type() == at::DeviceType::CUDA) {
tensor = tensor.to(device, tensor.scalar_type());
} else if (device.type() != at::DeviceType::CPU) {
AT_ERROR(
"supported devices include CPU and CUDA, however got ",
at::DeviceTypeName(device.type(), false));
}
stack_.push_back(std::move(tensor));
} break;
default: {
AT_ERROR(
"Unknown opcode for unpickling at ",
reinterpret_cast<void*>(opcode),
": ",
int(static_cast<uint8_t>(opcode)));
} break;
}
return opcode;
}

View File

@@ -147,12 +147,16 @@ class Pickler {
void startTuple();
void endTuple();
const std::vector<WriteableTensorData>& tensorData() {
return tensor_data_;
}
private:
void pushIValueImpl(const IValue& ivalue);
void pushDict(const IValue& ivalue);
void pushDouble(const IValue& ivalue);
void pushDouble(double value);
void pushGenericList(const IValue& ivalue);
void pushInt(const IValue& ivalue);
void pushInt(int64_t value);
void pushIntList(const IValue& ivalue);
void pushList(const IValue& ivalue);
void pushLiteralTensor(const IValue& ivalue);
@@ -233,16 +237,29 @@ class Unpickler {
TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
public:
// tensors inside the pickle are references to the tensor_table
Unpickler(
std::function<bool(char*, size_t)> reader,
const std::vector<at::Tensor>* tensor_table,
ClassResolver class_resolver)
ClassResolver class_resolver,
const std::vector<at::Tensor>* tensor_table)
: reader_(reader),
tensor_table_(tensor_table),
class_resolver_(std::move(class_resolver)) {}
// tensors inside the pickle contain meta-data; the raw tensor
// data is retrieved by calling `read_record`.
Unpickler(
std::function<bool(char*, size_t)> reader,
ClassResolver class_resolver,
std::function<at::DataPtr(const std::string&)> read_record,
c10::optional<at::Device> device)
: reader_(reader),
tensor_table_(nullptr),
class_resolver_(std::move(class_resolver)),
read_record_(std::move(read_record)),
device_(std::move(device)) {}
IValue parse_ivalue();
IValue parseModule();
private:
// No arguments ensures that a template argument must be specified
@@ -282,6 +299,9 @@ class Unpickler {
// optionally nullptr, needs to be present for creating classes
ClassResolver class_resolver_;
IValue empty_tuple_;
std::function<at::DataPtr(const std::string&)> read_record_;
c10::optional<at::Device> device_;
};
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor