Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Load tensors directly from pickle archive
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23281

Test Plan: Imported from OSS

Differential Revision: D16452815

Pulled By: zdevito

fbshipit-source-id: 918eef3ad444b598ab655c39037e4baafdcb51e1
committed by Facebook Github Bot
parent c33adf539c
commit e2ccccee9a
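At a high level, this change replaces the protobuf-centric `model.json` serialization path with pickle archives stored directly in the zip container: each top-level value becomes a record `<archive>.pkl` holding the pickle program, with raw tensor payloads stored beside it as `<archive>/<n>`. To match, `PyTorchStreamReader` trades its file-oriented names (`hasFile`, `getFileID`) for record-oriented ones (`hasRecord`, `getRecordID`). A minimal sketch of the renamed reader API, with the archive contents assumed to come from `PyTorchStreamWriter`:

#include <caffe2/serialize/inline_container.h>
#include <sstream>
#include <string>
#include <tuple>

// Sketch only: `serialized` is assumed to hold a zip produced by
// PyTorchStreamWriter (for example, the bytes of a torch.jit.save output).
void inspectArchive(const std::string& serialized) {
  std::istringstream iss(serialized);
  caffe2::serialize::PyTorchStreamReader reader(&iss);
  if (reader.hasRecord("data.pkl")) { // renamed from hasFile()
    at::DataPtr ptr;
    size_t size;
    std::tie(ptr, size) = reader.getRecord("data.pkl");
    // ptr/size now cover the pickle program; the tensor payloads it
    // references live in sibling records "data/0", "data/1", ...
  }
}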
@@ -165,7 +165,7 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
   return buf;
 }
 
-bool PyTorchStreamReader::hasFile(const std::string& name) {
+bool PyTorchStreamReader::hasRecord(const std::string& name) {
   std::stringstream ss;
   ss << archive_name_ << "/" << name;
   mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
@@ -177,7 +177,7 @@ bool PyTorchStreamReader::hasFile(const std::string& name) {
   return result;
 }
 
-size_t PyTorchStreamReader::getFileID(const std::string& name) {
+size_t PyTorchStreamReader::getRecordID(const std::string& name) {
   std::stringstream ss;
   ss << archive_name_ << "/" << name;
   size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
@@ -190,7 +190,7 @@ size_t PyTorchStreamReader::getFileID(const std::string& name) {
 
 // return dataptr, size
 std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
-  size_t key = getFileID(name);
+  size_t key = getRecordID(name);
   mz_zip_archive_file_stat stat;
   mz_zip_reader_file_stat(ar_.get(), key, &stat);
   valid("retrieving file meta-data");
@@ -208,7 +208,7 @@ static int64_t read_le_16(uint8_t* buf) {
 
 size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   mz_zip_archive_file_stat stat;
-  mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
+  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
   valid("retriving file meta-data");
   uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
   in_->read(
@@ -92,9 +92,6 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
 
-// Writer-specific constants
-constexpr uint64_t kFileFormatVersion = 0x2L;
-
 // Writer-specific constants
 constexpr uint64_t kFieldAlignment = 64;
 
@@ -107,7 +104,7 @@ class CAFFE2_API PyTorchStreamReader final {
   // return dataptr, size
   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
   size_t getRecordOffset(const std::string& name);
-  bool hasFile(const std::string& name);
+  bool hasRecord(const std::string& name);
 
   ~PyTorchStreamReader();
 
@@ -115,7 +112,7 @@ class CAFFE2_API PyTorchStreamReader final {
   void init();
   size_t read(uint64_t pos, char* buf, size_t n);
   void valid(const char* what);
-  size_t getFileID(const std::string& name);
+  size_t getRecordID(const std::string& name);
 
   friend size_t
   istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
@@ -39,9 +39,9 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 
   // read records through readers
   PyTorchStreamReader reader(&iss);
-  ASSERT_TRUE(reader.hasFile("key1"));
-  ASSERT_TRUE(reader.hasFile("key2"));
-  ASSERT_FALSE(reader.hasFile("key2000"));
+  ASSERT_TRUE(reader.hasRecord("key1"));
+  ASSERT_TRUE(reader.hasRecord("key2"));
+  ASSERT_FALSE(reader.hasRecord("key2000"));
   at::DataPtr data_ptr;
   int64_t size;
   std::tie(data_ptr, size) = reader.getRecord("key1");
@@ -9036,6 +9036,7 @@ a")
         m_import = self.getExportImportCopy(m_orig)
 
+        self.assertEqual(m_orig.foo(), m_import.foo())
 
         self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
         self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
@@ -55,7 +55,7 @@ ScriptCall ScriptCall::fromMessage(const Message& message) {
   auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
 
-  auto value = jit::unpickle(payload, payload_size, &message.tensors());
+  auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors());
 
   auto values = value.toTuple()->elements();
@@ -29,7 +29,7 @@ Message ScriptRet::toMessage() {
 ScriptRet ScriptRet::fromMessage(const Message& message) {
   auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
-  auto value = jit::unpickle(payload, payload_size, &message.tensors());
+  auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors());
   return ScriptRet(std::move(value));
 }
@@ -618,34 +618,70 @@ class ScriptModuleSerializer {
 // 1. Several tests that depend on the serialization format details.
 // 2. The emit module hook (since combining the old export and new import code
 // is going to cause jitter)
-class ScriptModuleSerializer2 : public ScriptModuleSerializer {
+class ScriptModuleSerializer2 {
  public:
   ScriptModuleSerializer2(const std::string& filename)
-      : ScriptModuleSerializer(filename) {}
+      : writer_(filename.c_str()) {}
 
-  ScriptModuleSerializer2(std::ostream* ofs) : ScriptModuleSerializer(ofs) {}
-
- private:
-  void convertModel(
+  ScriptModuleSerializer2(std::ostream* ofs) : ofs_(), writer_(ofs) {}
+  void serialize(
       const script::Module& module,
-      torch::ModelDef* model_def,
-      const script::ExtraFilesMap& extra_files) override {
-    model_def->set_producer_name("pytorch");
-    model_def->set_producer_version("1.0"); // TODO: set the producer version
-                                            // using appropriate function call
-    model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
-    // Serialize all code info.
-    convertNamedType(module.type());
-    // Then pickle the module
-    auto data = pickle(module.module_object(), &tensor_table_);
-    writer_.writeRecord("data.pkl", data.data(), data.size());
-
-    writeTensorTable(model_def);
-    writeLibs(model_def);
+      const script::ExtraFilesMap& extra_files) {
+    C10_LOG_API_USAGE_ONCE("torch.script.save");
+    writeExtraFiles(module, extra_files);
+    // Serialize all code info.
+    writeCode(module.type());
+    // The tensor constants from the code are written to a separate archive
+    // so loading the code does not depend on loading the data
+    std::vector<IValue> ivalue_constants(
+        constant_table_.begin(), constant_table_.end());
+    writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
+    // finally we serialize the model
+    writeArchive("data", module.module_object());
   }
 
-  void writeLibs(ModelDef* model_def) override {
+ private:
+  void writeArchive(const std::string& archive_name, const IValue& value) {
+    std::vector<char> data;
+    Pickler data_pickle(
+        [&](const char* buf, size_t size) {
+          data.insert(data.end(), buf, buf + size);
+        },
+        nullptr);
+    data_pickle.protocol();
+    data_pickle.pushIValue(value);
+    data_pickle.stop();
+    size_t i = 0;
+    for (const auto& td : data_pickle.tensorData()) {
+      std::stringstream fname;
+      fname << archive_name << "/" << i++;
+      writer_.writeRecord(fname.str(), td.data(), td.sizeInBytes());
+    }
+    std::stringstream fname;
+    fname << archive_name << ".pkl";
+    writer_.writeRecord(fname.str(), data.data(), data.size());
+  }
+
+  void writeExtraFiles(
+      const script::Module& module,
+      const script::ExtraFilesMap& extra_files) {
+    // Write out extra files.
+    for (const auto& kv : extra_files) {
+      const std::string key = "extra/" + kv.first;
+      writer_.writeRecord(key, kv.second.data(), kv.second.size());
+    }
+    auto hook = GetExtraFilesHook();
+    if (hook) {
+      script::ExtraFilesMap hook_files = hook(module);
+      for (const auto& kv : hook_files) {
+        const std::string key = "extra/" + kv.first;
+        writer_.writeRecord(key, kv.second.data(), kv.second.size());
+      }
+    }
+  }
+
+  void writeCode(const at::NamedTypePtr& root_type) {
+    convertNamedType(root_type);
     static const std::string opset_string =
         c10::str("op_version_set = ", CURRENT_OP_VERSION_SET, "\n");
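writeArchive above pins down the on-disk naming: the pickle program for an archive lands in `<archive_name>.pkl`, and each tensor the Pickler encountered is written out as a raw record `<archive_name>/<i>` in encounter order. As a sketch (assuming a module with two tensor constants and one parameter tensor), a saved file would contain records along these lines:

  extra/<name>     optional extra files
  code/<...>.py    generated source, one file per qualifier prefix
  constants.pkl    pickle program for the tuple of tensor constants
  constants/0      raw data of the first constant
  constants/1      raw data of the second constant
  data.pkl         pickle program for the module object
  data/0           raw data of the parameter tensor

readArchive on the import side mirrors this naming: it reads `<archive_name>.pkl` and resolves each persistent ID `<key>` against the record `<archive_name>/<key>`.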
@@ -661,7 +697,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
 
     // For the type, foo.bar.Baz
     const std::string filename = ImportExportHelpers::qualifierToPath(
-        converted_type->name()->prefix(), torch::PROTO_VERSION_NEWEST);
+        converted_type->name()->prefix(), "code/");
     // End state: filename is "foo/bar.py", in which we will define a class
     // named Baz
     auto& stream = fileToSrc[filename];
@@ -708,7 +744,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
           /*compress=*/true);
     }
   }
-  void convertNamedType(const c10::NamedTypePtr& class_type) override {
+  void convertNamedType(const c10::NamedTypePtr& class_type) {
     if (converted_types_.contains(class_type)) {
       return;
     }
@@ -720,7 +756,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
         source_stream,
         source_ranges,
         class_type,
-        tensor_table_,
+        constant_table_,
         class_deps,
         /*enforce_importable=*/true);
 
@@ -739,6 +775,10 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer {
     converted_types_.insert(class_type, std::move(info));
   }
 
+  std::ofstream ofs_;
+  caffe2::serialize::PyTorchStreamWriter writer_;
+  std::vector<at::Tensor> constant_table_;
+
   // all deps used by this module hierarchy
   struct TypeInfo {
     std::string source;
@@ -804,8 +844,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
     const auto& class_src = item.value();
 
     // For the type, foo.bar.Baz
-    const std::string filename =
-        ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5);
+    const std::string filename = ImportExportHelpers::qualifierToPath(
+        class_type->name()->prefix(), "libs/");
     // End state: filename is "foo/bar.py", in which we will define a class
     // named Baz
     fileToSrc[filename] << class_src;
@@ -816,8 +856,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
   std::unordered_set<std::string> written_files;
   for (const auto& item : converted_classes_) {
     const c10::NamedTypePtr& class_type = item.key();
-    const std::string filename =
-        ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5);
+    const std::string filename = ImportExportHelpers::qualifierToPath(
+        class_type->name()->prefix(), "libs/");
     if (written_files.count(filename)) {
       continue;
     }
@@ -68,34 +68,34 @@ class ScriptModuleDeserializer final {
       script::ExtraFilesMap& extra_files);
 
  private:
-  at::Tensor loadTensor(
+  IValue readArchive(const std::string& archive_name);
+  script::Module LEGACY_deserialize();
+  at::Tensor LEGACY_loadTensor(
       const torch::TensorDef& tensor_proto,
       std::unordered_map<std::string, at::Storage>& storageMap);
 
-  void loadTensorTable(torch::ModelDef* model_def);
+  void LEGACY_loadTensorTable(torch::ModelDef* model_def);
   void importCallback(const std::string& qualifier);
-  void moduleSetState(const script::Module& module, IValue state);
+  void LEGACY_moduleSetState(const script::Module& module, IValue state);
 
   std::shared_ptr<script::CompilationUnit> compilation_unit_;
 
   std::unique_ptr<PyTorchStreamReader> reader_;
   c10::optional<at::Device> device_;
-  std::vector<std::string> moduleStack_;
+  std::vector<std::string> LEGACY_moduleStack_;
 
-  std::vector<at::Tensor> tensor_table_;
+  std::vector<at::Tensor> constants_table_;
   std::unordered_set<std::string> imported_libs_;
 
   IValue LEGACY_loadPickleArchive(const std::string& name);
   script::Module LEGACY_convertModule(const torch::ModuleDef& module_def);
   std::vector<IValue> LEGACY_pickled_ivalues_;
-  size_t proto_version_;
+  std::string export_prefix_ = "code/";
 };
 
-script::Module ScriptModuleDeserializer::deserialize(
-    c10::optional<at::Device> device,
-    script::ExtraFilesMap& extra_files) {
-  C10_LOG_API_USAGE_ONCE("torch.script.load");
+script::Module ScriptModuleDeserializer::LEGACY_deserialize() {
   torch::ModelDef model_def;
 
   at::DataPtr data_ptr;
   size_t data_size;
   std::tie(data_ptr, data_size) = reader_->getRecord("model.json");
@@ -125,13 +125,69 @@ script::Module ScriptModuleDeserializer::deserialize(
   AT_ASSERTM(
       model_def.ParseFromString(binary_string),
       "JSON transcoder produced invalid protobuf output.");
-  device_ = device;
-  proto_version_ = model_def.proto_version();
+  auto proto_version = model_def.proto_version();
+  export_prefix_ = "libs/";
 
+  LEGACY_loadTensorTable(&model_def);
+  AT_ASSERT(proto_version < 6);
+  if (proto_version == 2) {
+    const auto& list =
+        LEGACY_loadPickleArchive("attributes.pkl").toGenericList();
+    LEGACY_pickled_ivalues_.insert(
+        LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
+  } else if (proto_version >= 3) {
+    LEGACY_pickled_ivalues_ =
+        LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
+  }
+  LEGACY_moduleStack_.push_back("__torch__");
+  const auto& module_def = model_def.main_module();
+  return LEGACY_convertModule(module_def);
+}
+
+IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
+  std::stringstream picklename;
+  picklename << archive_name << ".pkl";
+  at::DataPtr pickle_ptr;
+  size_t pickle_size;
+  std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str());
+
+  size_t bytes_read = 0;
+  auto data = reinterpret_cast<const char*>(pickle_ptr.get());
+  auto reader = [&](char* buffer, size_t len) {
+    if (bytes_read + len > pickle_size) {
+      return false;
+    }
+    // Copy len bytes into buffer
+    const char* start = data + bytes_read;
+    std::memcpy(buffer, start, len);
+    bytes_read += len;
+    return true;
+  };
+
+  auto class_resolver = [&](const c10::QualifiedName& qn) {
+    importCallback(qn.prefix());
+    return c10::StrongTypePtr(
+        compilation_unit_, compilation_unit_->get_class(qn));
+  };
+  auto read_record = [&](const std::string& name) {
+    std::stringstream ss;
+    ss << archive_name << "/" << name;
+    return std::get<0>(reader_->getRecord(ss.str()));
+  };
+  Unpickler unpickler(
+      reader, std::move(class_resolver), std::move(read_record), device_);
+  return unpickler.parse_ivalue();
+}
+
+script::Module ScriptModuleDeserializer::deserialize(
+    c10::optional<at::Device> device,
+    script::ExtraFilesMap& extra_files) {
+  C10_LOG_API_USAGE_ONCE("torch.script.load");
+  device_ = device;
   // Load extra files.
   for (const auto& kv : extra_files) {
     const std::string& key = "extra/" + kv.first;
-    if (reader_->hasFile(key)) {
+    if (reader_->hasRecord(key)) {
      at::DataPtr meta_ptr;
      size_t meta_size;
      std::tie(meta_ptr, meta_size) = reader_->getRecord(key);
@@ -139,48 +195,14 @@ script::Module ScriptModuleDeserializer::deserialize(
           std::string(static_cast<char*>(meta_ptr.get()), meta_size);
     }
   }
 
-  loadTensorTable(&model_def);
-
-  if (proto_version_ < 6) {
-    if (proto_version_ == 2) {
-      const auto& list =
-          LEGACY_loadPickleArchive("attributes.pkl").toGenericList();
-      LEGACY_pickled_ivalues_.insert(
-          LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
-    } else if (proto_version_ >= 3) {
-      LEGACY_pickled_ivalues_ =
-          LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
-    }
-    moduleStack_.push_back("__torch__");
-    const auto& module_def = model_def.main_module();
-    return LEGACY_convertModule(module_def);
-  } else {
-    at::DataPtr pickle_ptr;
-    size_t pickle_size;
-    std::tie(pickle_ptr, pickle_size) = reader_->getRecord("data.pkl");
-
-    size_t bytes_read = 0;
-    auto data = reinterpret_cast<const char*>(pickle_ptr.get());
-    auto reader = [&](char* buffer, size_t len) {
-      if (bytes_read + len > pickle_size) {
-        return false;
-      }
-      // Copy len bytes into buffer
-      const char* start = data + bytes_read;
-      std::memcpy(buffer, start, len);
-      bytes_read += len;
-      return true;
-    };
-
-    Unpickler unpickler(
-        reader, &tensor_table_, [&](const c10::QualifiedName& qn) {
-          importCallback(qn.prefix());
-          return c10::StrongTypePtr(
-              compilation_unit_, compilation_unit_->get_class(qn));
-        });
-    return script::Module(unpickler.parseModule().toObject());
+  if (reader_->hasRecord("model.json")) {
+    return LEGACY_deserialize();
   }
+  auto tuple = readArchive("constants").toTuple();
+  for (auto constant : tuple->elements()) {
+    constants_table_.push_back(constant.toTensor());
+  }
+  return script::Module(readArchive("data").toObject());
 }
 
 IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive(
@@ -191,23 +213,24 @@ IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive(
   auto ivalue = unpickle(
       reinterpret_cast<const char*>(attributes_ptr.get()),
       attributes_size,
-      &tensor_table_,
       [&](const c10::QualifiedName& qn) {
         importCallback(qn.prefix());
         return c10::StrongTypePtr(
             compilation_unit_, compilation_unit_->get_class(qn));
-      });
+      },
+      &constants_table_);
   return ivalue;
 }
 
-void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
+void ScriptModuleDeserializer::LEGACY_loadTensorTable(
+    torch::ModelDef* model_def) {
   std::unordered_map<std::string, at::Storage> storageMap;
   for (const torch::TensorDef& tensor : model_def->tensors()) {
-    tensor_table_.emplace_back(loadTensor(tensor, storageMap));
+    constants_table_.emplace_back(LEGACY_loadTensor(tensor, storageMap));
   }
 }
 
-at::Tensor ScriptModuleDeserializer::loadTensor(
+at::Tensor ScriptModuleDeserializer::LEGACY_loadTensor(
     const torch::TensorDef& tensor_proto,
     std::unordered_map<std::string, at::Storage>& storageMap) {
   std::vector<int64_t> dims(
@@ -301,17 +324,18 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
   std::function<void(const std::string&)> import_callback =
       [this](const std::string& qualifier) { importCallback(qualifier); };
   const std::string path =
-      ImportExportHelpers::qualifierToPath(qualifier, proto_version_);
+      ImportExportHelpers::qualifierToPath(qualifier, export_prefix_);
   at::DataPtr data;
   size_t size;
   std::tie(data, size) = reader_->getRecord(path);
 
   std::shared_ptr<ConcreteSourceRangeUnpickler> gen_ranges = nullptr;
-  if (proto_version_ >= 6) {
+
+  std::string debug_file = path + ".debug_pkl";
+  if (reader_->hasRecord(debug_file)) {
     at::DataPtr debug_data;
     size_t debug_size;
-    std::tie(debug_data, debug_size) = reader_->getRecord(path + ".debug_pkl");
+    std::tie(debug_data, debug_size) = reader_->getRecord(debug_file);
     gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(
         std::move(debug_data), debug_size);
   }
@@ -323,10 +347,10 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
       gen_ranges);
 
   script::import_libs(
-      compilation_unit_, qualifier, src, tensor_table_, import_callback);
+      compilation_unit_, qualifier, src, constants_table_, import_callback);
 }
 
-void ScriptModuleDeserializer::moduleSetState(
+void ScriptModuleDeserializer::LEGACY_moduleSetState(
     const script::Module& module,
     IValue state) {
   auto setstate = module.find_method("__setstate__");
@@ -359,10 +383,10 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
   const auto atoms = c10::QualifiedName(module_def.name()).atoms();
   const size_t numPushed = atoms.size();
   for (const auto& atom : atoms) {
-    moduleStack_.emplace_back(atom);
+    LEGACY_moduleStack_.emplace_back(atom);
   }
-  auto module =
-      script::Module(c10::QualifiedName(moduleStack_), compilation_unit_);
+  auto module = script::Module(
+      c10::QualifiedName(LEGACY_moduleStack_), compilation_unit_);
   for (int i = 0; i < module_def.submodules_size(); ++i) {
     const torch::ModuleDef& sub_def = module_def.submodules(i);
     auto submodule = LEGACY_convertModule(sub_def);
@@ -370,7 +394,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
   }
   for (int i = 0; i < module_def.parameters_size(); ++i) {
     const torch::ParameterDef& param_def = module_def.parameters(i);
-    at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
+    at::Tensor tensor = constants_table_.at(param_def.tensor_id());
     if (param_def.is_buffer()) {
       module.register_buffer(param_def.name(), tensor);
     } else {
@@ -425,11 +449,12 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
 
     std::function<void(const std::string&)> import_callback =
         [&, this](const std::string& qualifier) { importCallback(qualifier); };
-    script::LEGACY_import_methods(module, src, tensor_table_, import_callback);
+    script::LEGACY_import_methods(
+        module, src, constants_table_, import_callback);
   }
 
   if (module_def.has_get_state_attribute_id()) {
-    moduleSetState(
+    LEGACY_moduleSetState(
         module,
         LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
   }
@@ -450,7 +475,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
   }
 
   for (size_t i = 0; i < numPushed; i++) {
-    moduleStack_.pop_back();
+    LEGACY_moduleStack_.pop_back();
   }
   return module;
 }
@@ -7,36 +7,30 @@ namespace torch {
 namespace jit {
 namespace ImportExportHelpers {
 
-static const std::string exportPrefix = "code/";
-static const std::string LEGACY_exportPrefix = "libs/";
 static const std::string kExportSuffix = "py";
 
 std::string qualifierToPath(
     const std::string& qualifier,
-    size_t proto_version) {
-  const auto kExportPrefix =
-      proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix;
+    const std::string& export_prefix) {
   std::string path = qualifier;
   std::replace_if(
       path.begin(), path.end(), [](char c) { return c == '.'; }, '/');
-  return kExportPrefix + path + "." + kExportSuffix;
+  return export_prefix + path + "." + kExportSuffix;
 }
 
 std::string pathToQualifier(
     const std::string& classPath,
-    size_t proto_version) {
-  const auto kExportPrefix =
-      proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix;
+    const std::string& export_prefix) {
   // strip input suffix
   const auto end = classPath.rfind(kExportSuffix);
   AT_ASSERT(end != std::string::npos);
 
   // strip input suffix
-  size_t libs_idx = classPath.find(kExportPrefix);
+  size_t libs_idx = classPath.find(export_prefix);
   AT_ASSERT(libs_idx == 0);
 
-  AT_ASSERT(classPath.size() > kExportPrefix.size());
-  const auto start = kExportPrefix.size();
+  AT_ASSERT(classPath.size() > export_prefix.size());
+  const auto start = export_prefix.size();
 
   std::string class_qualifier = classPath.substr(start, end - start);
   std::replace_if(
@@ -13,13 +13,17 @@ namespace ImportExportHelpers {
 //
 // Qualifier is like: foo.bar.baz
 // Returns: libs/foo/bar/baz.py
-std::string qualifierToPath(const std::string& qualifier, size_t proto_version);
+std::string qualifierToPath(
+    const std::string& qualifier,
+    const std::string& export_prefix);
 
 // Convert a source file path to a class type's qualifier name.
 //
 // Path is like: libs/foo/bar/baz.py
 // Returns: foo.bar.baz
-std::string pathToQualifier(const std::string& classPath, size_t proto_version);
+std::string pathToQualifier(
+    const std::string& classPath,
+    const std::string& export_prefix);
 
 } // namespace ImportExportHelpers
 } // namespace jit
@@ -102,7 +102,8 @@ struct ConstantTableValue : public SugaredValue {
              << " is out of bounds (constant table has "
              << constants_.size() << " entries)";
     }
-    Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc);
+    Value* value =
+        m.graph()->insertConstant(constants_.at(offset), nullptr, loc);
 
     // specializing tensor type on compilation messes up typing relations
     value->setType(unshapedType(value->type()));
@@ -47,18 +47,18 @@ std::vector<char> pickle(
 
 IValue unpickle(
     std::function<bool(char*, size_t)> reader,
-    const std::vector<at::Tensor>* tensor_table,
-    ClassResolver class_resolver) {
+    ClassResolver class_resolver,
+    const std::vector<at::Tensor>* tensor_table) {
   Unpickler unpickler(
-      std::move(reader), tensor_table, std::move(class_resolver));
+      std::move(reader), std::move(class_resolver), tensor_table);
   return unpickler.parse_ivalue();
 }
 
 IValue unpickle(
     const char* data,
     size_t size,
-    const std::vector<at::Tensor>* tensor_table,
-    ClassResolver class_resolver) {
+    ClassResolver class_resolver,
+    const std::vector<at::Tensor>* tensor_table) {
   size_t bytes_read = 0;
   return unpickle(
       [&](char* buffer, size_t len) {
@@ -71,8 +71,8 @@ IValue unpickle(
         bytes_read += len;
         return true;
       },
-      tensor_table,
-      std::move(class_resolver));
+      std::move(class_resolver),
+      tensor_table);
 }
 
 } // namespace jit
@@ -50,18 +50,13 @@ TORCH_API void pickle(
     std::vector<at::Tensor>* tensor_table = nullptr);
 
 /// `reader` is a function that takes in a size to read from some pickled
-/// binary. `reader` should remember where it last read.
-///
-/// `bounds_checker` is a function that returns `true` if the reader can read
-/// more data, and `false` if it cannot (i.e. if a stream has hit its end of
-/// file)
-///
+/// binary. `reader` should remember where it last read, and return
+/// false if the read was not successful.
 /// See `torch::pickle` for details.
 TORCH_API IValue unpickle(
-    std::function<const char*(size_t)> reader,
-    std::function<bool()> bounds_chcker,
-    const std::vector<at::Tensor>* tensor_table = nullptr,
-    ClassResolver class_resolver = nullptr);
+    std::function<bool(char*, size_t)> reader,
+    ClassResolver class_resolver,
+    const std::vector<at::Tensor>* tensor_table);
 
 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
 ///
@@ -72,8 +67,8 @@ TORCH_API IValue unpickle(
 TORCH_API IValue unpickle(
     const char* data,
     size_t size,
-    const std::vector<at::Tensor>* tensor_table = nullptr,
-    ClassResolver class_resolver = nullptr);
+    ClassResolver class_resolver = nullptr,
+    const std::vector<at::Tensor>* tensor_table = nullptr);
 
 } // namespace jit
 } // namespace torch
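The reader contract documented above (fill `buffer` with `len` bytes, advance an internal cursor, return false once the input is exhausted) is exactly what the `(data, size)` overload builds internally. A self-contained sketch of calling the function-based overload, assuming `data`/`size` hold a pickle program produced by `torch::jit::pickle` (the include path is also an assumption):

#include <torch/csrc/jit/pickle.h>
#include <cstring>

torch::jit::IValue unpickleBuffer(const char* data, size_t size) {
  size_t bytes_read = 0;
  auto reader = [&](char* buffer, size_t len) {
    if (bytes_read + len > size) {
      return false; // signal end-of-data, per the contract above
    }
    std::memcpy(buffer, data + bytes_read, len);
    bytes_read += len;
    return true;
  };
  // No class resolver or tensor table: fine for plain values
  // (ints, strings, lists); pickles containing tensors need one of the two.
  return torch::jit::unpickle(reader, /*class_resolver=*/nullptr,
                              /*tensor_table=*/nullptr);
}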
@@ -151,9 +151,9 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
   } else if (ivalue.isTuple()) {
     pushTuple(ivalue);
   } else if (ivalue.isDouble()) {
-    pushDouble(ivalue);
+    pushDouble(ivalue.toDouble());
   } else if (ivalue.isInt()) {
-    pushInt(ivalue);
+    pushInt(ivalue.toInt());
   } else if (ivalue.isBool()) {
     if (ivalue.toBool()) {
       push<OpCode>(OpCode::NEWTRUE);
@@ -244,8 +244,7 @@ void Pickler::pushIValue(const IValue& ivalue) {
   }
 }
 
-void Pickler::pushInt(const IValue& ivalue) {
-  auto n = ivalue.toInt();
+void Pickler::pushInt(int64_t n) {
   if (n >= std::numeric_limits<int8_t>::min() &&
       n <= std::numeric_limits<int8_t>::max()) {
     push<OpCode>(OpCode::BININT1);
@@ -311,9 +310,11 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
   // root_key
   pushString(std::to_string(tensor_data_.size()));
   // location
-  pushString("cpu");
+  std::stringstream ss;
+  ss << tensor.device();
+  pushString(ss.str());
   // size
-  pushInt(tensor.numel());
+  pushInt(tensor.storage().size());
   // view_metadata
   push<OpCode>(OpCode::NONE);
   push<OpCode>(OpCode::TUPLE);
@@ -362,17 +363,18 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
   // The format here is the same one used by `torch.save()`. The code for the
   // format can be found in `torch/serialization.py`.
   auto tensor = ivalue.toTensor();
 
+  bool quantized = tensor.is_quantized();
   // The arguments to this function are:
   // storage, storage_offset, size, stride, requires_grad, backward_hooks
-  pushGlobal("torch._utils", "_rebuild_tensor_v2");
+  pushGlobal(
+      "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");
 
   push<OpCode>(OpCode::MARK);
 
   pushStorageOfTensor(tensor);
 
   // storage offset
-  int64_t storage_offset = 0;
-  pushInt(storage_offset);
+  pushInt(tensor.storage_offset());
 
   // size
   push<OpCode>(OpCode::MARK);
@@ -388,6 +390,11 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) {
   }
   push<OpCode>(OpCode::TUPLE);
 
+  if (quantized) {
+    pushDouble(tensor.q_scale());
+    pushInt(tensor.q_zero_point());
+  }
+
   // requires_grad
   pushIValue(tensor.requires_grad());
 
@@ -447,8 +454,7 @@ void Pickler::pushSpecializedList(
   push<OpCode>(OpCode::REDUCE);
 }
 
-void Pickler::pushDouble(const IValue& ivalue) {
-  double value = ivalue.toDouble();
+void Pickler::pushDouble(double value) {
   AT_ASSERT(sizeof(double) == 8);
   char* bytes = reinterpret_cast<char*>(&value);
@@ -522,16 +528,6 @@ IValue Unpickler::parse_ivalue() {
   return stack_[0];
 }
 
-IValue Unpickler::parseModule() {
-  run();
-  TORCH_CHECK(
-      stack_.size() == 1,
-      "Unpickler expected 1 element on the stack, but found ",
-      stack_.size());
-
-  return stack_[0];
-}
-
 double Unpickler::readFloat() {
   AT_ASSERT(sizeof(double) == 8);
   double big_endian = read<double>();
@@ -600,6 +596,12 @@ static IValue toSpecializedList(const IValue& generic) {
   return IValue(std::move(specialized));
 }
 
+static std::vector<int64_t> tupleToIntList(const IValue& v) {
+  return fmap(v.toTuple()->elements(), [](const IValue& v) -> int64_t {
+    return v.toInt();
+  });
+}
+
 OpCode Unpickler::readInstruction() {
   auto opcode = readOpCode();
   switch (opcode) {
@@ -748,6 +750,63 @@ OpCode Unpickler::readInstruction() {
           AT_ERROR("Unknown pickler class id");
         }
       });
+    } else if (
+        module_name == "torch._utils" &&
+        (class_name == "_rebuild_tensor_v2" ||
+         class_name == "_rebuild_qtensor")) {
+      bool quantized = class_name == "_rebuild_qtensor";
+      globals_.emplace_back([this, quantized] {
+        auto tup = pop(stack_).toTuple();
+        const auto& elements = tup->elements();
+        size_t idx = 0;
+        auto storage_tensor = elements.at(idx++).toTensor();
+        int64_t storage_offset = elements.at(idx++).toInt();
+        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
+        std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
+        double q_scale = 0.;
+        int64_t q_zero_point = 0;
+        if (quantized) {
+          q_scale = elements.at(idx++).toDouble();
+          q_zero_point = elements.at(idx++).toInt();
+        }
+        bool requires_grad = elements.at(idx++).toBool();
+        // elements[idx++] is empty backwards hooks
+        at::Tensor result = quantized
+            ? at::_empty_affine_quantized(
+                  {}, storage_tensor.options(), q_scale, q_zero_point)
+            : at::empty({0}, storage_tensor.options());
+        at::TensorImpl* impl = result.unsafeGetTensorImpl();
+        impl->set_storage(storage_tensor.storage());
+        impl->set_storage_offset(storage_offset);
+        impl->set_sizes_and_strides(size, stride);
+        result = autograd::make_variable(result, requires_grad);
+        stack_.push_back(std::move(result));
+      });
+    } else if (module_name == "collections" && class_name == "OrderedDict") {
+      globals_.emplace_back([this] {
+        // drop the Tuple that was argument to OrderedDict, and replace it
+        // with None OrderedDicts only appear in tensor deserialization and
+        // their value is never used
+        stack_.back() = IValue();
+      });
+    } else if (module_name == "torch") {
+      c10::optional<c10::ScalarType> scalar_type;
+#define CHECK_SCALAR(_, name)          \
+  if (class_name == #name "Storage") { \
+    scalar_type = c10::k##name;        \
+  }
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
+#undef CHECK_SCALAR
+      // NOTE: this does not put a global into the global table,
+      // like the other branches here because no REDUCE or BUILD will
+      // be called on this value. Instead, we just put it on the stack
+      // and return early
+      AT_ASSERT(
+          scalar_type.has_value(),
+          "class name not understood: torch.",
+          class_name);
+      stack_.emplace_back(int64_t(*scalar_type));
+      return opcode;
+    } else {
       AT_ASSERT(class_resolver_);
       at::StrongTypePtr type =
@@ -797,12 +856,52 @@ OpCode Unpickler::readInstruction() {
       // stack is: <functor_arg>
       globals_.at(idx)();
     } break;
-    default:
+    case OpCode::BINPERSID: {
+      auto args = pop(stack_).toTuple()->elements();
+      AT_ASSERT(
+          args.at(0).toStringRef() == "storage",
+          "unknown PERSID key ",
+          args.at(0).toStringRef());
+      at::ScalarType type = args.at(1).toScalarType();
+      const std::string& key = args.at(2).toStringRef();
+      at::Device device(args.at(3).toStringRef());
+      if (device_) {
+        device = *device_;
+      }
+      at::DataPtr storage_ptr = read_record_(key);
+      int64_t numel = args.at(4).toInt();
+      at::Storage storage(
+          at::CPU(type).typeMeta(),
+          numel,
+          std::move(storage_ptr),
+          /*allocator=*/nullptr,
+          /*resizable=*/false); // NB: we didn't set any allocator for the
+                                // tensor
+      auto options = at::CPU(type).options();
+      at::Tensor tensor;
+      if (options.backend() == c10::Backend::QuantizedCPU) {
+        tensor = at::_empty_affine_quantized({}, options, 0, 0)
+                     .set_(storage, 0, {}, {});
+      } else {
+        tensor = at::empty({0}, options).set_(storage);
+      }
+
+      if (device.type() == at::DeviceType::CUDA) {
+        tensor = tensor.to(device, tensor.scalar_type());
+      } else if (device.type() != at::DeviceType::CPU) {
+        AT_ERROR(
+            "supported devices include CPU and CUDA, however got ",
+            at::DeviceTypeName(device.type(), false));
+      }
+      stack_.push_back(std::move(tensor));
+    } break;
+    default: {
       AT_ERROR(
           "Unknown opcode for unpickling at ",
           reinterpret_cast<void*>(opcode),
           ": ",
          int(static_cast<uint8_t>(opcode)));
+    } break;
   }
   return opcode;
 }
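Read together with Pickler::pushStorageOfTensor earlier, the persistent ID consumed by BINPERSID is the tuple ("storage", storage_type, root_key, location, size). As a hypothetical worked example, the first 4-element float storage written into the "data" archive would unpickle roughly as ("storage", FloatStorage, "0", "cpu", 4), after which read_record_("0") fetches the record "data/0" to back the rebuilt storage.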
@@ -147,12 +147,16 @@ class Pickler {
   void startTuple();
   void endTuple();
 
+  const std::vector<WriteableTensorData>& tensorData() {
+    return tensor_data_;
+  }
+
  private:
   void pushIValueImpl(const IValue& ivalue);
   void pushDict(const IValue& ivalue);
-  void pushDouble(const IValue& ivalue);
+  void pushDouble(double value);
   void pushGenericList(const IValue& ivalue);
-  void pushInt(const IValue& ivalue);
+  void pushInt(int64_t value);
   void pushIntList(const IValue& ivalue);
   void pushList(const IValue& ivalue);
   void pushLiteralTensor(const IValue& ivalue);
@@ -233,16 +237,29 @@ class Unpickler {
   TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
 
  public:
+  // tensors inside the pickle are references to the tensor_table
   Unpickler(
       std::function<bool(char*, size_t)> reader,
-      const std::vector<at::Tensor>* tensor_table,
-      ClassResolver class_resolver)
+      ClassResolver class_resolver,
+      const std::vector<at::Tensor>* tensor_table)
       : reader_(reader),
         tensor_table_(tensor_table),
         class_resolver_(std::move(class_resolver)) {}
 
+  // tensors inside the pickle contain meta-data, the raw tensor
+  // data is retrieved by calling `read_record`.
+  Unpickler(
+      std::function<bool(char*, size_t)> reader,
+      ClassResolver class_resolver,
+      std::function<at::DataPtr(const std::string&)> read_record,
+      c10::optional<at::Device> device)
+      : reader_(reader),
+        tensor_table_(nullptr),
+        class_resolver_(std::move(class_resolver)),
+        read_record_(std::move(read_record)),
+        device_(std::move(device)) {}
+
   IValue parse_ivalue();
-  IValue parseModule();
 
  private:
   // No arguments ensures that a template arugment must be specified
@@ -282,6 +299,9 @@ class Unpickler {
   // optionally nullptr, needs to be present for creating classes
   ClassResolver class_resolver_;
   IValue empty_tuple_;
+
+  std::function<at::DataPtr(const std::string&)> read_record_;
+  c10::optional<at::Device> device_;
 };
 
 // returns a (tensor, record_size) for a tensor, converting it to a CPU tensor