mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Load tensors directly from pickle archive
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23281 Test Plan: Imported from OSS Differential Revision: D16452815 Pulled By: zdevito fbshipit-source-id: 918eef3ad444b598ab655c39037e4baafdcb51e1
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							c33adf539c
						
					
				
				
					commit
					e2ccccee9a
				
			| @ -165,7 +165,7 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t | ||||
|   return buf; | ||||
| } | ||||
|  | ||||
| bool PyTorchStreamReader::hasFile(const std::string& name) { | ||||
| bool PyTorchStreamReader::hasRecord(const std::string& name) { | ||||
|   std::stringstream ss; | ||||
|   ss << archive_name_ << "/" << name; | ||||
|   mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0); | ||||
| @ -177,7 +177,7 @@ bool PyTorchStreamReader::hasFile(const std::string& name) { | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| size_t PyTorchStreamReader::getFileID(const std::string& name) { | ||||
| size_t PyTorchStreamReader::getRecordID(const std::string& name) { | ||||
|   std::stringstream ss; | ||||
|   ss << archive_name_ << "/" << name; | ||||
|   size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0); | ||||
| @ -190,7 +190,7 @@ size_t PyTorchStreamReader::getFileID(const std::string& name) { | ||||
|  | ||||
| // return dataptr, size | ||||
| std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) { | ||||
|   size_t key = getFileID(name); | ||||
|   size_t key = getRecordID(name); | ||||
|   mz_zip_archive_file_stat stat; | ||||
|   mz_zip_reader_file_stat(ar_.get(), key, &stat); | ||||
|   valid("retrieving file meta-data"); | ||||
| @ -208,7 +208,7 @@ static int64_t read_le_16(uint8_t* buf) { | ||||
|  | ||||
| size_t PyTorchStreamReader::getRecordOffset(const std::string& name) { | ||||
|   mz_zip_archive_file_stat stat; | ||||
|   mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat); | ||||
|   mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat); | ||||
|   valid("retriving file meta-data"); | ||||
|   uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE]; | ||||
|   in_->read( | ||||
|  | ||||
| @ -92,9 +92,6 @@ namespace serialize { | ||||
| constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; | ||||
| constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L; | ||||
|  | ||||
| // Writer-specific constants | ||||
| constexpr uint64_t kFileFormatVersion = 0x2L; | ||||
|  | ||||
| // Writer-specific constants | ||||
| constexpr uint64_t kFieldAlignment = 64; | ||||
|  | ||||
| @ -107,7 +104,7 @@ class CAFFE2_API PyTorchStreamReader final { | ||||
|   // return dataptr, size | ||||
|   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name); | ||||
|   size_t getRecordOffset(const std::string& name); | ||||
|   bool hasFile(const std::string& name); | ||||
|   bool hasRecord(const std::string& name); | ||||
|  | ||||
|   ~PyTorchStreamReader(); | ||||
|  | ||||
| @ -115,7 +112,7 @@ class CAFFE2_API PyTorchStreamReader final { | ||||
|   void init(); | ||||
|   size_t read(uint64_t pos, char* buf, size_t n); | ||||
|   void valid(const char* what); | ||||
|   size_t getFileID(const std::string& name); | ||||
|   size_t getRecordID(const std::string& name); | ||||
|  | ||||
|   friend size_t | ||||
|   istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n); | ||||
|  | ||||
| @ -39,9 +39,9 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { | ||||
|  | ||||
|   // read records through readers | ||||
|   PyTorchStreamReader reader(&iss); | ||||
|   ASSERT_TRUE(reader.hasFile("key1")); | ||||
|   ASSERT_TRUE(reader.hasFile("key2")); | ||||
|   ASSERT_FALSE(reader.hasFile("key2000")); | ||||
|   ASSERT_TRUE(reader.hasRecord("key1")); | ||||
|   ASSERT_TRUE(reader.hasRecord("key2")); | ||||
|   ASSERT_FALSE(reader.hasRecord("key2000")); | ||||
|   at::DataPtr data_ptr; | ||||
|   int64_t size; | ||||
|   std::tie(data_ptr, size) = reader.getRecord("key1"); | ||||
|  | ||||
| @ -9036,6 +9036,7 @@ a") | ||||
|             m_import = self.getExportImportCopy(m_orig) | ||||
|  | ||||
|             self.assertEqual(m_orig.foo(), m_import.foo()) | ||||
|  | ||||
|             self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr()) | ||||
|             self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr()) | ||||
|  | ||||
|  | ||||
| @ -55,7 +55,7 @@ ScriptCall ScriptCall::fromMessage(const Message& message) { | ||||
|   auto payload = static_cast<const char*>(message.payload().data()); | ||||
|   auto payload_size = message.payload().size(); | ||||
|  | ||||
|   auto value = jit::unpickle(payload, payload_size, &message.tensors()); | ||||
|   auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors()); | ||||
|  | ||||
|   auto values = value.toTuple()->elements(); | ||||
|  | ||||
|  | ||||
| @ -29,7 +29,7 @@ Message ScriptRet::toMessage() { | ||||
| ScriptRet ScriptRet::fromMessage(const Message& message) { | ||||
|   auto payload = static_cast<const char*>(message.payload().data()); | ||||
|   auto payload_size = message.payload().size(); | ||||
|   auto value = jit::unpickle(payload, payload_size, &message.tensors()); | ||||
|   auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors()); | ||||
|   return ScriptRet(std::move(value)); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -618,34 +618,70 @@ class ScriptModuleSerializer { | ||||
| // 1. Several tests that depend on the serialization format details. | ||||
| // 2. The emit module hook (since combining the old export and new import code | ||||
| //    is going to cause jitter) | ||||
| class ScriptModuleSerializer2 : public ScriptModuleSerializer { | ||||
| class ScriptModuleSerializer2 { | ||||
|  public: | ||||
|   ScriptModuleSerializer2(const std::string& filename) | ||||
|       : ScriptModuleSerializer(filename) {} | ||||
|       : writer_(filename.c_str()) {} | ||||
|  | ||||
|   ScriptModuleSerializer2(std::ostream* ofs) : ScriptModuleSerializer(ofs) {} | ||||
|  | ||||
|  private: | ||||
|   void convertModel( | ||||
|   ScriptModuleSerializer2(std::ostream* ofs) : ofs_(), writer_(ofs) {} | ||||
|   void serialize( | ||||
|       const script::Module& module, | ||||
|       torch::ModelDef* model_def, | ||||
|       const script::ExtraFilesMap& extra_files) override { | ||||
|     model_def->set_producer_name("pytorch"); | ||||
|     model_def->set_producer_version("1.0"); // TODO: set the producer version | ||||
|                                             // using appropriate function call | ||||
|     model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST); | ||||
|     // Serialize all code info. | ||||
|     convertNamedType(module.type()); | ||||
|     // Then pickle the module | ||||
|     auto data = pickle(module.module_object(), &tensor_table_); | ||||
|     writer_.writeRecord("data.pkl", data.data(), data.size()); | ||||
|  | ||||
|     writeTensorTable(model_def); | ||||
|     writeLibs(model_def); | ||||
|       const script::ExtraFilesMap& extra_files) { | ||||
|     C10_LOG_API_USAGE_ONCE("torch.script.save"); | ||||
|     writeExtraFiles(module, extra_files); | ||||
|     // Serialize all code info. | ||||
|     writeCode(module.type()); | ||||
|     // The tensor constants from the code are written to a separate archive | ||||
|     // so loading the code does not depend on loading the data | ||||
|     std::vector<IValue> ivalue_constants( | ||||
|         constant_table_.begin(), constant_table_.end()); | ||||
|     writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants)); | ||||
|     // finally we serialize the model | ||||
|     writeArchive("data", module.module_object()); | ||||
|   } | ||||
|  | ||||
|   void writeLibs(ModelDef* model_def) override { | ||||
|  private: | ||||
|   void writeArchive(const std::string& archive_name, const IValue& value) { | ||||
|     std::vector<char> data; | ||||
|     Pickler data_pickle( | ||||
|         [&](const char* buf, size_t size) { | ||||
|           data.insert(data.end(), buf, buf + size); | ||||
|         }, | ||||
|         nullptr); | ||||
|     data_pickle.protocol(); | ||||
|     data_pickle.pushIValue(value); | ||||
|     data_pickle.stop(); | ||||
|     size_t i = 0; | ||||
|     for (const auto& td : data_pickle.tensorData()) { | ||||
|       std::stringstream fname; | ||||
|       fname << archive_name << "/" << i++; | ||||
|       writer_.writeRecord(fname.str(), td.data(), td.sizeInBytes()); | ||||
|     } | ||||
|     std::stringstream fname; | ||||
|     fname << archive_name << ".pkl"; | ||||
|     writer_.writeRecord(fname.str(), data.data(), data.size()); | ||||
|   } | ||||
|  | ||||
|   void writeExtraFiles( | ||||
|       const script::Module& module, | ||||
|       const script::ExtraFilesMap& extra_files) { | ||||
|     // Write out extra files. | ||||
|     for (const auto& kv : extra_files) { | ||||
|       const std::string key = "extra/" + kv.first; | ||||
|       writer_.writeRecord(key, kv.second.data(), kv.second.size()); | ||||
|     } | ||||
|     auto hook = GetExtraFilesHook(); | ||||
|     if (hook) { | ||||
|       script::ExtraFilesMap hook_files = hook(module); | ||||
|       for (const auto& kv : hook_files) { | ||||
|         const std::string key = "extra/" + kv.first; | ||||
|         writer_.writeRecord(key, kv.second.data(), kv.second.size()); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void writeCode(const at::NamedTypePtr& root_type) { | ||||
|     convertNamedType(root_type); | ||||
|     static const std::string opset_string = | ||||
|         c10::str("op_version_set = ", CURRENT_OP_VERSION_SET, "\n"); | ||||
|  | ||||
| @ -661,7 +697,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer { | ||||
|  | ||||
|       // For the type, foo.bar.Baz | ||||
|       const std::string filename = ImportExportHelpers::qualifierToPath( | ||||
|           converted_type->name()->prefix(), torch::PROTO_VERSION_NEWEST); | ||||
|           converted_type->name()->prefix(), "code/"); | ||||
|       // End state: filename is "foo/bar.py", in which we will define a class | ||||
|       // named Baz | ||||
|       auto& stream = fileToSrc[filename]; | ||||
| @ -708,7 +744,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer { | ||||
|           /*compress=*/true); | ||||
|     } | ||||
|   } | ||||
|   void convertNamedType(const c10::NamedTypePtr& class_type) override { | ||||
|   void convertNamedType(const c10::NamedTypePtr& class_type) { | ||||
|     if (converted_types_.contains(class_type)) { | ||||
|       return; | ||||
|     } | ||||
| @ -720,7 +756,7 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer { | ||||
|         source_stream, | ||||
|         source_ranges, | ||||
|         class_type, | ||||
|         tensor_table_, | ||||
|         constant_table_, | ||||
|         class_deps, | ||||
|         /*enforce_importable=*/true); | ||||
|  | ||||
| @ -739,6 +775,10 @@ class ScriptModuleSerializer2 : public ScriptModuleSerializer { | ||||
|     converted_types_.insert(class_type, std::move(info)); | ||||
|   } | ||||
|  | ||||
|   std::ofstream ofs_; | ||||
|   caffe2::serialize::PyTorchStreamWriter writer_; | ||||
|   std::vector<at::Tensor> constant_table_; | ||||
|  | ||||
|   // all deps used by this module hierarchy | ||||
|   struct TypeInfo { | ||||
|     std::string source; | ||||
| @ -804,8 +844,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) { | ||||
|     const auto& class_src = item.value(); | ||||
|  | ||||
|     // For the type, foo.bar.Baz | ||||
|     const std::string filename = | ||||
|         ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5); | ||||
|     const std::string filename = ImportExportHelpers::qualifierToPath( | ||||
|         class_type->name()->prefix(), "libs/"); | ||||
|     // End state: filename is "foo/bar.py", in which we will define a class | ||||
|     // named Baz | ||||
|     fileToSrc[filename] << class_src; | ||||
| @ -816,8 +856,8 @@ void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) { | ||||
|   std::unordered_set<std::string> written_files; | ||||
|   for (const auto& item : converted_classes_) { | ||||
|     const c10::NamedTypePtr& class_type = item.key(); | ||||
|     const std::string filename = | ||||
|         ImportExportHelpers::qualifierToPath(class_type->name()->prefix(), 5); | ||||
|     const std::string filename = ImportExportHelpers::qualifierToPath( | ||||
|         class_type->name()->prefix(), "libs/"); | ||||
|     if (written_files.count(filename)) { | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
| @ -68,34 +68,34 @@ class ScriptModuleDeserializer final { | ||||
|       script::ExtraFilesMap& extra_files); | ||||
|  | ||||
|  private: | ||||
|   at::Tensor loadTensor( | ||||
|   IValue readArchive(const std::string& archive_name); | ||||
|   script::Module LEGACY_deserialize(); | ||||
|   at::Tensor LEGACY_loadTensor( | ||||
|       const torch::TensorDef& tensor_proto, | ||||
|       std::unordered_map<std::string, at::Storage>& storageMap); | ||||
|  | ||||
|   void loadTensorTable(torch::ModelDef* model_def); | ||||
|   void LEGACY_loadTensorTable(torch::ModelDef* model_def); | ||||
|   void importCallback(const std::string& qualifier); | ||||
|   void moduleSetState(const script::Module& module, IValue state); | ||||
|   void LEGACY_moduleSetState(const script::Module& module, IValue state); | ||||
|  | ||||
|   std::shared_ptr<script::CompilationUnit> compilation_unit_; | ||||
|  | ||||
|   std::unique_ptr<PyTorchStreamReader> reader_; | ||||
|   c10::optional<at::Device> device_; | ||||
|   std::vector<std::string> moduleStack_; | ||||
|   std::vector<std::string> LEGACY_moduleStack_; | ||||
|  | ||||
|   std::vector<at::Tensor> tensor_table_; | ||||
|   std::vector<at::Tensor> constants_table_; | ||||
|   std::unordered_set<std::string> imported_libs_; | ||||
|  | ||||
|   IValue LEGACY_loadPickleArchive(const std::string& name); | ||||
|   script::Module LEGACY_convertModule(const torch::ModuleDef& module_def); | ||||
|   std::vector<IValue> LEGACY_pickled_ivalues_; | ||||
|   size_t proto_version_; | ||||
|   std::string export_prefix_ = "code/"; | ||||
| }; | ||||
|  | ||||
| script::Module ScriptModuleDeserializer::deserialize( | ||||
|     c10::optional<at::Device> device, | ||||
|     script::ExtraFilesMap& extra_files) { | ||||
|   C10_LOG_API_USAGE_ONCE("torch.script.load"); | ||||
| script::Module ScriptModuleDeserializer::LEGACY_deserialize() { | ||||
|   torch::ModelDef model_def; | ||||
|  | ||||
|   at::DataPtr data_ptr; | ||||
|   size_t data_size; | ||||
|   std::tie(data_ptr, data_size) = reader_->getRecord("model.json"); | ||||
| @ -125,13 +125,69 @@ script::Module ScriptModuleDeserializer::deserialize( | ||||
|   AT_ASSERTM( | ||||
|       model_def.ParseFromString(binary_string), | ||||
|       "JSON transcoder produced invalid protobuf output."); | ||||
|   device_ = device; | ||||
|   proto_version_ = model_def.proto_version(); | ||||
|   auto proto_version = model_def.proto_version(); | ||||
|   export_prefix_ = "libs/"; | ||||
|  | ||||
|   LEGACY_loadTensorTable(&model_def); | ||||
|   AT_ASSERT(proto_version < 6); | ||||
|   if (proto_version == 2) { | ||||
|     const auto& list = | ||||
|         LEGACY_loadPickleArchive("attributes.pkl").toGenericList(); | ||||
|     LEGACY_pickled_ivalues_.insert( | ||||
|         LEGACY_pickled_ivalues_.end(), list.begin(), list.end()); | ||||
|   } else if (proto_version >= 3) { | ||||
|     LEGACY_pickled_ivalues_ = | ||||
|         LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements(); | ||||
|   } | ||||
|   LEGACY_moduleStack_.push_back("__torch__"); | ||||
|   const auto& module_def = model_def.main_module(); | ||||
|   return LEGACY_convertModule(module_def); | ||||
| } | ||||
|  | ||||
| IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { | ||||
|   std::stringstream picklename; | ||||
|   picklename << archive_name << ".pkl"; | ||||
|   at::DataPtr pickle_ptr; | ||||
|   size_t pickle_size; | ||||
|   std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str()); | ||||
|  | ||||
|   size_t bytes_read = 0; | ||||
|   auto data = reinterpret_cast<const char*>(pickle_ptr.get()); | ||||
|   auto reader = [&](char* buffer, size_t len) { | ||||
|     if (bytes_read + len > pickle_size) { | ||||
|       return false; | ||||
|     } | ||||
|     // Copy len bytes into buffer | ||||
|     const char* start = data + bytes_read; | ||||
|     std::memcpy(buffer, start, len); | ||||
|     bytes_read += len; | ||||
|     return true; | ||||
|   }; | ||||
|  | ||||
|   auto class_resolver = [&](const c10::QualifiedName& qn) { | ||||
|     importCallback(qn.prefix()); | ||||
|     return c10::StrongTypePtr( | ||||
|         compilation_unit_, compilation_unit_->get_class(qn)); | ||||
|   }; | ||||
|   auto read_record = [&](const std::string& name) { | ||||
|     std::stringstream ss; | ||||
|     ss << archive_name << "/" << name; | ||||
|     return std::get<0>(reader_->getRecord(ss.str())); | ||||
|   }; | ||||
|   Unpickler unpickler( | ||||
|       reader, std::move(class_resolver), std::move(read_record), device_); | ||||
|   return unpickler.parse_ivalue(); | ||||
| } | ||||
|  | ||||
| script::Module ScriptModuleDeserializer::deserialize( | ||||
|     c10::optional<at::Device> device, | ||||
|     script::ExtraFilesMap& extra_files) { | ||||
|   C10_LOG_API_USAGE_ONCE("torch.script.load"); | ||||
|   device_ = device; | ||||
|   // Load extra files. | ||||
|   for (const auto& kv : extra_files) { | ||||
|     const std::string& key = "extra/" + kv.first; | ||||
|     if (reader_->hasFile(key)) { | ||||
|     if (reader_->hasRecord(key)) { | ||||
|       at::DataPtr meta_ptr; | ||||
|       size_t meta_size; | ||||
|       std::tie(meta_ptr, meta_size) = reader_->getRecord(key); | ||||
| @ -139,48 +195,14 @@ script::Module ScriptModuleDeserializer::deserialize( | ||||
|           std::string(static_cast<char*>(meta_ptr.get()), meta_size); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   loadTensorTable(&model_def); | ||||
|  | ||||
|   if (proto_version_ < 6) { | ||||
|     if (proto_version_ == 2) { | ||||
|       const auto& list = | ||||
|           LEGACY_loadPickleArchive("attributes.pkl").toGenericList(); | ||||
|       LEGACY_pickled_ivalues_.insert( | ||||
|           LEGACY_pickled_ivalues_.end(), list.begin(), list.end()); | ||||
|     } else if (proto_version_ >= 3) { | ||||
|       LEGACY_pickled_ivalues_ = | ||||
|           LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements(); | ||||
|     } | ||||
|     moduleStack_.push_back("__torch__"); | ||||
|     const auto& module_def = model_def.main_module(); | ||||
|     return LEGACY_convertModule(module_def); | ||||
|   } else { | ||||
|     at::DataPtr pickle_ptr; | ||||
|     size_t pickle_size; | ||||
|     std::tie(pickle_ptr, pickle_size) = reader_->getRecord("data.pkl"); | ||||
|  | ||||
|     size_t bytes_read = 0; | ||||
|     auto data = reinterpret_cast<const char*>(pickle_ptr.get()); | ||||
|     auto reader = [&](char* buffer, size_t len) { | ||||
|       if (bytes_read + len > pickle_size) { | ||||
|         return false; | ||||
|       } | ||||
|       // Copy len bytes into buffer | ||||
|       const char* start = data + bytes_read; | ||||
|       std::memcpy(buffer, start, len); | ||||
|       bytes_read += len; | ||||
|       return true; | ||||
|     }; | ||||
|  | ||||
|     Unpickler unpickler( | ||||
|         reader, &tensor_table_, [&](const c10::QualifiedName& qn) { | ||||
|           importCallback(qn.prefix()); | ||||
|           return c10::StrongTypePtr( | ||||
|               compilation_unit_, compilation_unit_->get_class(qn)); | ||||
|         }); | ||||
|     return script::Module(unpickler.parseModule().toObject()); | ||||
|   if (reader_->hasRecord("model.json")) { | ||||
|     return LEGACY_deserialize(); | ||||
|   } | ||||
|   auto tuple = readArchive("constants").toTuple(); | ||||
|   for (auto constant : tuple->elements()) { | ||||
|     constants_table_.push_back(constant.toTensor()); | ||||
|   } | ||||
|   return script::Module(readArchive("data").toObject()); | ||||
| } | ||||
|  | ||||
| IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive( | ||||
| @ -191,23 +213,24 @@ IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive( | ||||
|   auto ivalue = unpickle( | ||||
|       reinterpret_cast<const char*>(attributes_ptr.get()), | ||||
|       attributes_size, | ||||
|       &tensor_table_, | ||||
|       [&](const c10::QualifiedName& qn) { | ||||
|         importCallback(qn.prefix()); | ||||
|         return c10::StrongTypePtr( | ||||
|             compilation_unit_, compilation_unit_->get_class(qn)); | ||||
|       }); | ||||
|       }, | ||||
|       &constants_table_); | ||||
|   return ivalue; | ||||
| } | ||||
|  | ||||
| void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) { | ||||
| void ScriptModuleDeserializer::LEGACY_loadTensorTable( | ||||
|     torch::ModelDef* model_def) { | ||||
|   std::unordered_map<std::string, at::Storage> storageMap; | ||||
|   for (const torch::TensorDef& tensor : model_def->tensors()) { | ||||
|     tensor_table_.emplace_back(loadTensor(tensor, storageMap)); | ||||
|     constants_table_.emplace_back(LEGACY_loadTensor(tensor, storageMap)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| at::Tensor ScriptModuleDeserializer::loadTensor( | ||||
| at::Tensor ScriptModuleDeserializer::LEGACY_loadTensor( | ||||
|     const torch::TensorDef& tensor_proto, | ||||
|     std::unordered_map<std::string, at::Storage>& storageMap) { | ||||
|   std::vector<int64_t> dims( | ||||
| @ -301,17 +324,18 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) { | ||||
|   std::function<void(const std::string&)> import_callback = | ||||
|       [this](const std::string& qualifier) { importCallback(qualifier); }; | ||||
|   const std::string path = | ||||
|       ImportExportHelpers::qualifierToPath(qualifier, proto_version_); | ||||
|       ImportExportHelpers::qualifierToPath(qualifier, export_prefix_); | ||||
|   at::DataPtr data; | ||||
|   size_t size; | ||||
|   std::tie(data, size) = reader_->getRecord(path); | ||||
|  | ||||
|   std::shared_ptr<ConcreteSourceRangeUnpickler> gen_ranges = nullptr; | ||||
|   if (proto_version_ >= 6) { | ||||
|  | ||||
|   std::string debug_file = path + ".debug_pkl"; | ||||
|   if (reader_->hasRecord(debug_file)) { | ||||
|     at::DataPtr debug_data; | ||||
|     size_t debug_size; | ||||
|     std::tie(debug_data, debug_size) = reader_->getRecord(path + ".debug_pkl"); | ||||
|  | ||||
|     std::tie(debug_data, debug_size) = reader_->getRecord(debug_file); | ||||
|     gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>( | ||||
|         std::move(debug_data), debug_size); | ||||
|   } | ||||
| @ -323,10 +347,10 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) { | ||||
|       gen_ranges); | ||||
|  | ||||
|   script::import_libs( | ||||
|       compilation_unit_, qualifier, src, tensor_table_, import_callback); | ||||
|       compilation_unit_, qualifier, src, constants_table_, import_callback); | ||||
| } | ||||
|  | ||||
| void ScriptModuleDeserializer::moduleSetState( | ||||
| void ScriptModuleDeserializer::LEGACY_moduleSetState( | ||||
|     const script::Module& module, | ||||
|     IValue state) { | ||||
|   auto setstate = module.find_method("__setstate__"); | ||||
| @ -359,10 +383,10 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule( | ||||
|   const auto atoms = c10::QualifiedName(module_def.name()).atoms(); | ||||
|   const size_t numPushed = atoms.size(); | ||||
|   for (const auto& atom : atoms) { | ||||
|     moduleStack_.emplace_back(atom); | ||||
|     LEGACY_moduleStack_.emplace_back(atom); | ||||
|   } | ||||
|   auto module = | ||||
|       script::Module(c10::QualifiedName(moduleStack_), compilation_unit_); | ||||
|   auto module = script::Module( | ||||
|       c10::QualifiedName(LEGACY_moduleStack_), compilation_unit_); | ||||
|   for (int i = 0; i < module_def.submodules_size(); ++i) { | ||||
|     const torch::ModuleDef& sub_def = module_def.submodules(i); | ||||
|     auto submodule = LEGACY_convertModule(sub_def); | ||||
| @ -370,7 +394,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule( | ||||
|   } | ||||
|   for (int i = 0; i < module_def.parameters_size(); ++i) { | ||||
|     const torch::ParameterDef& param_def = module_def.parameters(i); | ||||
|     at::Tensor tensor = tensor_table_.at(param_def.tensor_id()); | ||||
|     at::Tensor tensor = constants_table_.at(param_def.tensor_id()); | ||||
|     if (param_def.is_buffer()) { | ||||
|       module.register_buffer(param_def.name(), tensor); | ||||
|     } else { | ||||
| @ -425,11 +449,12 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule( | ||||
|  | ||||
|     std::function<void(const std::string&)> import_callback = | ||||
|         [&, this](const std::string& qualifier) { importCallback(qualifier); }; | ||||
|     script::LEGACY_import_methods(module, src, tensor_table_, import_callback); | ||||
|     script::LEGACY_import_methods( | ||||
|         module, src, constants_table_, import_callback); | ||||
|   } | ||||
|  | ||||
|   if (module_def.has_get_state_attribute_id()) { | ||||
|     moduleSetState( | ||||
|     LEGACY_moduleSetState( | ||||
|         module, | ||||
|         LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id())); | ||||
|   } | ||||
| @ -450,7 +475,7 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule( | ||||
|   } | ||||
|  | ||||
|   for (size_t i = 0; i < numPushed; i++) { | ||||
|     moduleStack_.pop_back(); | ||||
|     LEGACY_moduleStack_.pop_back(); | ||||
|   } | ||||
|   return module; | ||||
| } | ||||
|  | ||||
| @ -7,36 +7,30 @@ namespace torch { | ||||
| namespace jit { | ||||
| namespace ImportExportHelpers { | ||||
|  | ||||
| static const std::string exportPrefix = "code/"; | ||||
| static const std::string LEGACY_exportPrefix = "libs/"; | ||||
| static const std::string kExportSuffix = "py"; | ||||
|  | ||||
| std::string qualifierToPath( | ||||
|     const std::string& qualifier, | ||||
|     size_t proto_version) { | ||||
|   const auto kExportPrefix = | ||||
|       proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix; | ||||
|     const std::string& export_prefix) { | ||||
|   std::string path = qualifier; | ||||
|   std::replace_if( | ||||
|       path.begin(), path.end(), [](char c) { return c == '.'; }, '/'); | ||||
|   return kExportPrefix + path + "." + kExportSuffix; | ||||
|   return export_prefix + path + "." + kExportSuffix; | ||||
| } | ||||
|  | ||||
| std::string pathToQualifier( | ||||
|     const std::string& classPath, | ||||
|     size_t proto_version) { | ||||
|   const auto kExportPrefix = | ||||
|       proto_version >= 6 ? exportPrefix : LEGACY_exportPrefix; | ||||
|     const std::string& export_prefix) { | ||||
|   // strip input suffix | ||||
|   const auto end = classPath.rfind(kExportSuffix); | ||||
|   AT_ASSERT(end != std::string::npos); | ||||
|  | ||||
|   // strip input suffix | ||||
|   size_t libs_idx = classPath.find(kExportPrefix); | ||||
|   size_t libs_idx = classPath.find(export_prefix); | ||||
|   AT_ASSERT(libs_idx == 0); | ||||
|  | ||||
|   AT_ASSERT(classPath.size() > kExportPrefix.size()); | ||||
|   const auto start = kExportPrefix.size(); | ||||
|   AT_ASSERT(classPath.size() > export_prefix.size()); | ||||
|   const auto start = export_prefix.size(); | ||||
|  | ||||
|   std::string class_qualifier = classPath.substr(start, end - start); | ||||
|   std::replace_if( | ||||
|  | ||||
| @ -13,13 +13,17 @@ namespace ImportExportHelpers { | ||||
| // | ||||
| // Qualifier is like: foo.bar.baz | ||||
| // Returns: libs/foo/bar/baz.py | ||||
| std::string qualifierToPath(const std::string& qualifier, size_t proto_version); | ||||
| std::string qualifierToPath( | ||||
|     const std::string& qualifier, | ||||
|     const std::string& export_prefix); | ||||
|  | ||||
| // Convert a source file path to a class type's qualifier name. | ||||
| // | ||||
| // Path is like: libs/foo/bar/baz.py | ||||
| // Returns: foo.bar.baz | ||||
| std::string pathToQualifier(const std::string& classPath, size_t proto_version); | ||||
| std::string pathToQualifier( | ||||
|     const std::string& classPath, | ||||
|     const std::string& export_prefix); | ||||
|  | ||||
| } // namespace ImportExportHelpers | ||||
| } // namespace jit | ||||
|  | ||||
| @ -102,7 +102,8 @@ struct ConstantTableValue : public SugaredValue { | ||||
|                              << " is out of bounds (constant table has " | ||||
|                              << constants_.size() << " entries)"; | ||||
|     } | ||||
|     Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc); | ||||
|     Value* value = | ||||
|         m.graph()->insertConstant(constants_.at(offset), nullptr, loc); | ||||
|  | ||||
|     // specializing tensor type on compilation messes up typing relations | ||||
|     value->setType(unshapedType(value->type())); | ||||
|  | ||||
| @ -47,18 +47,18 @@ std::vector<char> pickle( | ||||
|  | ||||
| IValue unpickle( | ||||
|     std::function<bool(char*, size_t)> reader, | ||||
|     const std::vector<at::Tensor>* tensor_table, | ||||
|     ClassResolver class_resolver) { | ||||
|     ClassResolver class_resolver, | ||||
|     const std::vector<at::Tensor>* tensor_table) { | ||||
|   Unpickler unpickler( | ||||
|       std::move(reader), tensor_table, std::move(class_resolver)); | ||||
|       std::move(reader), std::move(class_resolver), tensor_table); | ||||
|   return unpickler.parse_ivalue(); | ||||
| } | ||||
|  | ||||
| IValue unpickle( | ||||
|     const char* data, | ||||
|     size_t size, | ||||
|     const std::vector<at::Tensor>* tensor_table, | ||||
|     ClassResolver class_resolver) { | ||||
|     ClassResolver class_resolver, | ||||
|     const std::vector<at::Tensor>* tensor_table) { | ||||
|   size_t bytes_read = 0; | ||||
|   return unpickle( | ||||
|       [&](char* buffer, size_t len) { | ||||
| @ -71,8 +71,8 @@ IValue unpickle( | ||||
|         bytes_read += len; | ||||
|         return true; | ||||
|       }, | ||||
|       tensor_table, | ||||
|       std::move(class_resolver)); | ||||
|       std::move(class_resolver), | ||||
|       tensor_table); | ||||
| } | ||||
|  | ||||
| } // namespace jit | ||||
|  | ||||
| @ -50,18 +50,13 @@ TORCH_API void pickle( | ||||
|     std::vector<at::Tensor>* tensor_table = nullptr); | ||||
|  | ||||
| /// `reader` is a function that takes in a size to read from some pickled | ||||
| /// binary. `reader` should remember where it last read. | ||||
| /// | ||||
| /// `bounds_checker` is a function that returns `true` if the reader can read | ||||
| /// more data, and `false` if it cannot (i.e. if a stream has hit its end of | ||||
| /// file) | ||||
| /// | ||||
| /// binary. `reader` should remember where it last read, and return | ||||
| /// false if the read was not successful. | ||||
| /// See `torch::pickle` for details. | ||||
| TORCH_API IValue unpickle( | ||||
|     std::function<const char*(size_t)> reader, | ||||
|     std::function<bool()> bounds_chcker, | ||||
|     const std::vector<at::Tensor>* tensor_table = nullptr, | ||||
|     ClassResolver class_resolver = nullptr); | ||||
|     std::function<bool(char*, size_t)> reader, | ||||
|     ClassResolver class_resolver, | ||||
|     const std::vector<at::Tensor>* tensor_table); | ||||
|  | ||||
| /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. | ||||
| /// | ||||
| @ -72,8 +67,8 @@ TORCH_API IValue unpickle( | ||||
| TORCH_API IValue unpickle( | ||||
|     const char* data, | ||||
|     size_t size, | ||||
|     const std::vector<at::Tensor>* tensor_table = nullptr, | ||||
|     ClassResolver class_resolver = nullptr); | ||||
|     ClassResolver class_resolver = nullptr, | ||||
|     const std::vector<at::Tensor>* tensor_table = nullptr); | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -151,9 +151,9 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { | ||||
|   } else if (ivalue.isTuple()) { | ||||
|     pushTuple(ivalue); | ||||
|   } else if (ivalue.isDouble()) { | ||||
|     pushDouble(ivalue); | ||||
|     pushDouble(ivalue.toDouble()); | ||||
|   } else if (ivalue.isInt()) { | ||||
|     pushInt(ivalue); | ||||
|     pushInt(ivalue.toInt()); | ||||
|   } else if (ivalue.isBool()) { | ||||
|     if (ivalue.toBool()) { | ||||
|       push<OpCode>(OpCode::NEWTRUE); | ||||
| @ -244,8 +244,7 @@ void Pickler::pushIValue(const IValue& ivalue) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| void Pickler::pushInt(const IValue& ivalue) { | ||||
|   auto n = ivalue.toInt(); | ||||
| void Pickler::pushInt(int64_t n) { | ||||
|   if (n >= std::numeric_limits<int8_t>::min() && | ||||
|       n <= std::numeric_limits<int8_t>::max()) { | ||||
|     push<OpCode>(OpCode::BININT1); | ||||
| @ -311,9 +310,11 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) { | ||||
|   // root_key | ||||
|   pushString(std::to_string(tensor_data_.size())); | ||||
|   // location | ||||
|   pushString("cpu"); | ||||
|   std::stringstream ss; | ||||
|   ss << tensor.device(); | ||||
|   pushString(ss.str()); | ||||
|   // size | ||||
|   pushInt(tensor.numel()); | ||||
|   pushInt(tensor.storage().size()); | ||||
|   // view_metadata | ||||
|   push<OpCode>(OpCode::NONE); | ||||
|   push<OpCode>(OpCode::TUPLE); | ||||
| @ -362,17 +363,18 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { | ||||
|   // The format here is the same one used by `torch.save()`. The code for the | ||||
|   // format can be found in `torch/serialization.py`. | ||||
|   auto tensor = ivalue.toTensor(); | ||||
|  | ||||
|   bool quantized = tensor.is_quantized(); | ||||
|   // The arguments to this function are: | ||||
|   //    storage, storage_offset, size, stride, requires_grad, backward_hooks | ||||
|   pushGlobal("torch._utils", "_rebuild_tensor_v2"); | ||||
|   pushGlobal( | ||||
|       "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2"); | ||||
|  | ||||
|   push<OpCode>(OpCode::MARK); | ||||
|  | ||||
|   pushStorageOfTensor(tensor); | ||||
|  | ||||
|   // storage offset | ||||
|   int64_t storage_offset = 0; | ||||
|   pushInt(storage_offset); | ||||
|   pushInt(tensor.storage_offset()); | ||||
|  | ||||
|   // size | ||||
|   push<OpCode>(OpCode::MARK); | ||||
| @ -388,6 +390,11 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { | ||||
|   } | ||||
|   push<OpCode>(OpCode::TUPLE); | ||||
|  | ||||
|   if (quantized) { | ||||
|     pushDouble(tensor.q_scale()); | ||||
|     pushInt(tensor.q_zero_point()); | ||||
|   } | ||||
|  | ||||
|   // requires_grad | ||||
|   pushIValue(tensor.requires_grad()); | ||||
|  | ||||
| @ -447,8 +454,7 @@ void Pickler::pushSpecializedList( | ||||
|   push<OpCode>(OpCode::REDUCE); | ||||
| } | ||||
|  | ||||
| void Pickler::pushDouble(const IValue& ivalue) { | ||||
|   double value = ivalue.toDouble(); | ||||
| void Pickler::pushDouble(double value) { | ||||
|   AT_ASSERT(sizeof(double) == 8); | ||||
|   char* bytes = reinterpret_cast<char*>(&value); | ||||
|  | ||||
| @ -522,16 +528,6 @@ IValue Unpickler::parse_ivalue() { | ||||
|   return stack_[0]; | ||||
| } | ||||
|  | ||||
| IValue Unpickler::parseModule() { | ||||
|   run(); | ||||
|   TORCH_CHECK( | ||||
|       stack_.size() == 1, | ||||
|       "Unpickler expected 1 element on the stack, but found ", | ||||
|       stack_.size()); | ||||
|  | ||||
|   return stack_[0]; | ||||
| } | ||||
|  | ||||
| double Unpickler::readFloat() { | ||||
|   AT_ASSERT(sizeof(double) == 8); | ||||
|   double big_endian = read<double>(); | ||||
| @ -600,6 +596,12 @@ static IValue toSpecializedList(const IValue& generic) { | ||||
|   return IValue(std::move(specialized)); | ||||
| } | ||||
|  | ||||
| static std::vector<int64_t> tupleToIntList(const IValue& v) { | ||||
|   return fmap(v.toTuple()->elements(), [](const IValue& v) -> int64_t { | ||||
|     return v.toInt(); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| OpCode Unpickler::readInstruction() { | ||||
|   auto opcode = readOpCode(); | ||||
|   switch (opcode) { | ||||
| @ -748,6 +750,63 @@ OpCode Unpickler::readInstruction() { | ||||
|               AT_ERROR("Unknown pickler class id"); | ||||
|           } | ||||
|         }); | ||||
|       } else if ( | ||||
|           module_name == "torch._utils" && | ||||
|           (class_name == "_rebuild_tensor_v2" || | ||||
|            class_name == "_rebuild_qtensor")) { | ||||
|         bool quantized = class_name == "_rebuild_qtensor"; | ||||
|         globals_.emplace_back([this, quantized] { | ||||
|           auto tup = pop(stack_).toTuple(); | ||||
|           const auto& elements = tup->elements(); | ||||
|           size_t idx = 0; | ||||
|           auto storage_tensor = elements.at(idx++).toTensor(); | ||||
|           int64_t storage_offset = elements.at(idx++).toInt(); | ||||
|           std::vector<int64_t> size = tupleToIntList(elements.at(idx++)); | ||||
|           std::vector<int64_t> stride = tupleToIntList(elements.at(idx++)); | ||||
|           double q_scale = 0.; | ||||
|           int64_t q_zero_point = 0; | ||||
|           if (quantized) { | ||||
|             q_scale = elements.at(idx++).toDouble(); | ||||
|             q_zero_point = elements.at(idx++).toInt(); | ||||
|           } | ||||
|           bool requires_grad = elements.at(idx++).toBool(); | ||||
|           // elements[idx++] is empty backwards hooks | ||||
|           at::Tensor result = quantized | ||||
|               ? at::_empty_affine_quantized( | ||||
|                     {}, storage_tensor.options(), q_scale, q_zero_point) | ||||
|               : at::empty({0}, storage_tensor.options()); | ||||
|           at::TensorImpl* impl = result.unsafeGetTensorImpl(); | ||||
|           impl->set_storage(storage_tensor.storage()); | ||||
|           impl->set_storage_offset(storage_offset); | ||||
|           impl->set_sizes_and_strides(size, stride); | ||||
|           result = autograd::make_variable(result, requires_grad); | ||||
|           stack_.push_back(std::move(result)); | ||||
|         }); | ||||
|       } else if (module_name == "collections" && class_name == "OrderedDict") { | ||||
|         globals_.emplace_back([this] { | ||||
|           // drop the Tuple that was argument to OrderedDict, and replace it | ||||
|           // with None OrderedDicts only appear in tensor deserialization and | ||||
|           // their value is never used | ||||
|           stack_.back() = IValue(); | ||||
|         }); | ||||
|       } else if (module_name == "torch") { | ||||
|         c10::optional<c10::ScalarType> scalar_type; | ||||
| #define CHECK_SCALAR(_, name)          \ | ||||
|   if (class_name == #name "Storage") { \ | ||||
|     scalar_type = c10::k##name;        \ | ||||
|   } | ||||
|         AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR) | ||||
| #undef CHECK_SCALAR | ||||
|         // NOTE: this does not put a global into the global table, | ||||
|         // like the other branches here because no REDUCE or BUILD will | ||||
|         // be called on this value. Instead, we just put it on the stack | ||||
|         // and return early | ||||
|         AT_ASSERT( | ||||
|             scalar_type.has_value(), | ||||
|             "class name not understood: torch.", | ||||
|             class_name); | ||||
|         stack_.emplace_back(int64_t(*scalar_type)); | ||||
|         return opcode; | ||||
|       } else { | ||||
|         AT_ASSERT(class_resolver_); | ||||
|         at::StrongTypePtr type = | ||||
| @ -797,12 +856,52 @@ OpCode Unpickler::readInstruction() { | ||||
|       // stack is: <functor_arg> | ||||
|       globals_.at(idx)(); | ||||
|     } break; | ||||
|     default: | ||||
|     case OpCode::BINPERSID: { | ||||
|       auto args = pop(stack_).toTuple()->elements(); | ||||
|       AT_ASSERT( | ||||
|           args.at(0).toStringRef() == "storage", | ||||
|           "unknown PERSID key ", | ||||
|           args.at(0).toStringRef()); | ||||
|       at::ScalarType type = args.at(1).toScalarType(); | ||||
|       const std::string& key = args.at(2).toStringRef(); | ||||
|       at::Device device(args.at(3).toStringRef()); | ||||
|       if (device_) { | ||||
|         device = *device_; | ||||
|       } | ||||
|       at::DataPtr storage_ptr = read_record_(key); | ||||
|       int64_t numel = args.at(4).toInt(); | ||||
|       at::Storage storage( | ||||
|           at::CPU(type).typeMeta(), | ||||
|           numel, | ||||
|           std::move(storage_ptr), | ||||
|           /*allocator=*/nullptr, | ||||
|           /*resizable=*/false); // NB: we didn't set any allocator for the | ||||
|                                 // tensor | ||||
|       auto options = at::CPU(type).options(); | ||||
|       at::Tensor tensor; | ||||
|       if (options.backend() == c10::Backend::QuantizedCPU) { | ||||
|         tensor = at::_empty_affine_quantized({}, options, 0, 0) | ||||
|                      .set_(storage, 0, {}, {}); | ||||
|       } else { | ||||
|         tensor = at::empty({0}, options).set_(storage); | ||||
|       } | ||||
|  | ||||
|       if (device.type() == at::DeviceType::CUDA) { | ||||
|         tensor = tensor.to(device, tensor.scalar_type()); | ||||
|       } else if (device.type() != at::DeviceType::CPU) { | ||||
|         AT_ERROR( | ||||
|             "supported devices include CPU and CUDA, however got ", | ||||
|             at::DeviceTypeName(device.type(), false)); | ||||
|       } | ||||
|       stack_.push_back(std::move(tensor)); | ||||
|     } break; | ||||
|     default: { | ||||
|       AT_ERROR( | ||||
|           "Unknown opcode for unpickling at ", | ||||
|           reinterpret_cast<void*>(opcode), | ||||
|           ": ", | ||||
|           int(static_cast<uint8_t>(opcode))); | ||||
|     } break; | ||||
|   } | ||||
|   return opcode; | ||||
| } | ||||
|  | ||||
| @ -147,12 +147,16 @@ class Pickler { | ||||
|   void startTuple(); | ||||
|   void endTuple(); | ||||
|  | ||||
|   const std::vector<WriteableTensorData>& tensorData() { | ||||
|     return tensor_data_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   void pushIValueImpl(const IValue& ivalue); | ||||
|   void pushDict(const IValue& ivalue); | ||||
|   void pushDouble(const IValue& ivalue); | ||||
|   void pushDouble(double value); | ||||
|   void pushGenericList(const IValue& ivalue); | ||||
|   void pushInt(const IValue& ivalue); | ||||
|   void pushInt(int64_t value); | ||||
|   void pushIntList(const IValue& ivalue); | ||||
|   void pushList(const IValue& ivalue); | ||||
|   void pushLiteralTensor(const IValue& ivalue); | ||||
| @ -233,16 +237,29 @@ class Unpickler { | ||||
|   TH_DISALLOW_COPY_AND_ASSIGN(Unpickler); | ||||
|  | ||||
|  public: | ||||
|   // tensors inside the pickle are references to the tensor_table | ||||
|   Unpickler( | ||||
|       std::function<bool(char*, size_t)> reader, | ||||
|       const std::vector<at::Tensor>* tensor_table, | ||||
|       ClassResolver class_resolver) | ||||
|       ClassResolver class_resolver, | ||||
|       const std::vector<at::Tensor>* tensor_table) | ||||
|       : reader_(reader), | ||||
|         tensor_table_(tensor_table), | ||||
|         class_resolver_(std::move(class_resolver)) {} | ||||
|  | ||||
|   // tensors inside the pickle contain meta-data, the raw tensor | ||||
|   // dead is retrieved by calling `read_record`. | ||||
|   Unpickler( | ||||
|       std::function<bool(char*, size_t)> reader, | ||||
|       ClassResolver class_resolver, | ||||
|       std::function<at::DataPtr(const std::string&)> read_record, | ||||
|       c10::optional<at::Device> device) | ||||
|       : reader_(reader), | ||||
|         tensor_table_(nullptr), | ||||
|         class_resolver_(std::move(class_resolver)), | ||||
|         read_record_(std::move(read_record)), | ||||
|         device_(std::move(device)) {} | ||||
|  | ||||
|   IValue parse_ivalue(); | ||||
|   IValue parseModule(); | ||||
|  | ||||
|  private: | ||||
|   // No arguments ensures that a template arugment must be specified | ||||
| @ -282,6 +299,9 @@ class Unpickler { | ||||
|   // optionally nullptr, needs to be present for creating classes | ||||
|   ClassResolver class_resolver_; | ||||
|   IValue empty_tuple_; | ||||
|  | ||||
|   std::function<at::DataPtr(const std::string&)> read_record_; | ||||
|   c10::optional<at::Device> device_; | ||||
| }; | ||||
|  | ||||
| // returns a (tensor, record_size) for a tensor, converting it to a CPU tensor | ||||
|  | ||||
		Reference in New Issue
	
	Block a user