Files
pytorch/torch/csrc/jit/serialization/import.cpp
cyy e9e93c5350 [Reland] Move torch::make_unique to std::make_unique (#109780)
We can first try to move torch::make_unique to std::make_unique despite the revert of #108866.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109780
Approved by: https://github.com/ezyang
2023-09-21 18:30:21 +00:00

632 lines
23 KiB
C++

#include <ATen/core/interned_strings.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <ATen/core/functional.h>
#include <ATen/core/ivalue_inl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#if !defined(C10_MOBILE) && !defined(C10_DISABLE_LEGACY_IMPORT)
#include <torch/csrc/jit/serialization/import_legacy.h>
#endif
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/graph_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <ATen/ATen.h>
#include <fmt/format.h>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch::jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
// Sanity check run on an object right after '__setstate__' has been invoked
// on it: every attribute whose declared type is not Union, Optional or None
// must hold a non-None value, otherwise a TORCH_CHECK failure naming the
// attribute and its expected type is raised.
static void postSetStateValidate(const IValue& v) {
  auto object = v.toObject();
  const auto& class_type = object->type();
  const auto attribute_count = class_type->numAttributes();
  for (const auto index : c10::irange(attribute_count)) {
    const auto& declared_type = class_type->getAttribute(index);
    const auto& attribute_name = class_type->getAttributeName(index);
    const auto& stored_value = object->getSlot(index);
    // Verify that all the non-optional attributes have been initialized
    // TODO: Issue #20497
    const auto kind = declared_type->kind();
    const bool may_hold_none = kind == TypeKind::UnionType ||
        kind == TypeKind::OptionalType || kind == TypeKind::NoneType;
    if (may_hold_none) {
      continue;
    }
    TORCH_CHECK(
        !stored_value.isNone(),
        fmt::format(
            "The field '{}' was left uninitialized after '__setstate__', "
            "but expected a value of type '{}'",
            attribute_name,
            declared_type->repr_str()));
  }
}
namespace {
// This is a deserializer class which loads script modules from pt files.
// Content of the file is written using PyTorchStreamWriter, for details please
// check caffe2/serialize/inline_container.h.
// The module is saved in pickle. readArchive() is called to parse and construct
// the constant table and the script module.
class ScriptModuleDeserializer final {
 public:
  // Deserializer for a standalone archive: source is looked up under "code/"
  // and the pickle/tensor directory prefixes are empty.
  // `reader` must be non-null, its version() is read during construction.
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        code_prefix_("code/"),
        pickle_dir_prefix_(""),
        tensor_dir_prefix_(""),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // Deserializer with caller-chosen pickle/tensor directory prefixes and a
  // storage context; source is looked up under ".data/ts_code/code/".
  // `reader` must be non-null, its version() is read during construction.
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader,
      std::string pickle_dir_prefix,
      std::string tensor_dir_prefix,
      std::shared_ptr<DeserializationStorageContext> storage_context)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        storage_context_(std::move(storage_context)),
        code_prefix_(".data/ts_code/code/"),
        pickle_dir_prefix_(std::move(pickle_dir_prefix)),
        tensor_dir_prefix_(std::move(tensor_dir_prefix)),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // source_importer_ is handed a callback capturing `this` as well as the
  // address of constants_table_. A copied or moved deserializer would keep
  // referring to the instance it was created from, so both are disabled.
  ScriptModuleDeserializer(const ScriptModuleDeserializer&) = delete;
  ScriptModuleDeserializer& operator=(const ScriptModuleDeserializer&) =
      delete;
  ScriptModuleDeserializer(ScriptModuleDeserializer&&) = delete;
  ScriptModuleDeserializer& operator=(ScriptModuleDeserializer&&) = delete;

  // Loads the module from the archive. Entries already present in
  // `extra_files` are filled in from the matching "extra/<name>" records.
  Module deserialize(
      c10::optional<at::Device> device,
      ExtraFilesMap& extra_files,
      bool restore_shapes = false);

 private:
  // Unpickles the archive called `archive_name` from the reader.
  IValue readArchive(const std::string& archive_name);

  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::shared_ptr<PyTorchStreamReader> reader_;
  std::shared_ptr<DeserializationStorageContext> storage_context_;
  // Device passed to deserialize(); forwarded when reading archives.
  c10::optional<at::Device> device_;
  // Filled from the "constants" archive; source_importer_ holds its address.
  std::vector<at::IValue> constants_table_;
  // Prefix under which source files are looked up in the archive.
  std::string code_prefix_;
  std::string pickle_dir_prefix_;
  std::string tensor_dir_prefix_;
  // Declared last: its constructor uses the members above.
  SourceImporter source_importer_;
};
// Unpickles the archive `archive_name` via readArchiveAndTensors, using the
// configured pickle/tensor prefixes, device and storage context. Class types
// are resolved through source_importer_, and objects are rebuilt either via
// their '__setstate__' method or slot by slot from a dict of attributes.
IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
  // Maps a qualified name to a type owned by our compilation unit.
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    auto cls = source_importer_.loadType(qn);
    return c10::StrongTypePtr(compilation_unit_, std::move(cls));
  };

  // Decouple how to get obj from type. In this file it's dependent on
  // Method.run() and graph executor, etc.
  // For bytecode import we need to decouple these dependencies.
  auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) {
    auto cls = type.type_->expect<at::ClassType>();
    const size_t n = cls->numAttributes();
    if (checkHasValidSetGetState(cls)) {
      auto obj = c10::ivalue::Object::create(type, n);
      // XXX: Do not optimize __setstate__, so that we don't try to
      // specialize the class before it is initialized.
      GraphOptimizerEnabledGuard guard(false);
      Function& set_state = cls->getMethod("__setstate__");
      // since we are in the middle of unpickling we might still have lists and
      // dicts that do not have accurate tags (e.g. they report they are
      // List[Any]). But we need to run __setstate__ which will check the input
      // type and may access the tags. Since setstate has a known input type, we
      // can correctly restore the tags now by applying the input type of
      // set_state to the state object being passed.
      // TODO: Remove once [serialization type tags] is landed
      restoreAccurateTypeTags(
          input, set_state.getSchema().arguments().at(1).type());
      set_state({obj, input});
      postSetStateValidate(obj);
      return obj;
    }
    // No usable __setstate__/__getstate__ pair: the state is a dict keyed by
    // attribute name, copied into the object's slots in attribute order.
    auto dict = std::move(input).toGenericDict();
    auto obj = c10::ivalue::Object::create(type, n);
    for (const auto i : c10::irange(n)) {
      obj->setSlot(i, dict.at(cls->getAttributeName(i)));
    }
    return obj;
  };

  return readArchiveAndTensors(
      /*archive_name=*/archive_name,
      /*pickle_prefix=*/pickle_dir_prefix_,
      /*tensor_prefix=*/tensor_dir_prefix_,
      type_resolver,
      obj_loader,
      device_,
      *reader_,
      nullptr,
      storage_context_);
}
// Backward-compatibility rewrite for quantized convolutions. In every method
// graph of `module` and, recursively, of its children, calls to
// quantized::conv2d / conv2d_relu / conv3d / conv3d_relu that pass stride,
// padding, dilation and groups explicitly are replaced by the form taking only
// the input, the packed params, the output scale and the output zero point.
void rewriteQuantizedConvForBC(const Module& module) {
  // {pattern, replacement} pairs. Built once from the literals themselves so
  // that no pattern strings are reconstructed on the (recursive) calls.
  static const std::vector<std::pair<std::string, std::string>>
      patterns_and_replacements = {
          // quantized::conv2d
          {R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv2d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
return (%r) )",
           R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv2d(%x, %packed_params, %r_scale, %r_zero_point)
return (%r) )"},
          // quantized::conv2d_relu
          {R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv2d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
return (%r) )",
           R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv2d_relu(%x, %packed_params, %r_scale, %r_zero_point)
return (%r) )"},
          // quantized::conv3d
          {R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv3d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
return (%r) )",
           R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv3d(%x, %packed_params, %r_scale, %r_zero_point)
return (%r) )"},
          // quantized::conv3d_relu
          {R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv3d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
return (%r) )",
           R"(
graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
%r = quantized::conv3d_relu(%x, %packed_params, %r_scale, %r_zero_point)
return (%r) )"},
      };

  SubgraphRewriter rewriter;
  for (const auto& item : patterns_and_replacements) {
    rewriter.RegisterRewritePattern(item.first, item.second);
  }
  rewriter.runOnModule(module);

  // Submodules own their own graphs, so each one is rewritten as well.
  for (const Module& child : module.children()) {
    rewriteQuantizedConvForBC(child);
  }
}
// Loads the script module stored in the archive behind reader_.
//
// device:         remembered in device_ and forwarded when archives are read.
// extra_files:    in/out. For every key already present, the record
//                 "extra/<key>" (if the archive has one) is copied into the
//                 mapped value; keys without a record are left untouched.
// restore_shapes: when true and the archive has "traced_inputs.pkl", the
//                 stored inputs are used to set input tensor types on the
//                 corresponding method graphs and shapes are propagated;
//                 when true but no such record exists, a warning is emitted.
//
// Archives containing "model.json" (with the default "code/" prefix) are
// handed to the legacy deserializer, or rejected when it is compiled out.
Module ScriptModuleDeserializer::deserialize(
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool restore_shapes) {
  // we populate the upgraders map before any load starts
  populate_upgraders_graph_map();
  C10_LOG_API_USAGE_ONCE("torch.script.load");
  device_ = device;

  // Load extra files. The mapped value is written through the iterator, so
  // the map is neither looked up again nor restructured while iterating.
  for (auto& kv : extra_files) {
    const std::string key = "extra/" + kv.first;
    if (reader_->hasRecord(key)) {
      auto [meta_ptr, meta_size] = reader_->getRecord(key);
      kv.second =
          std::string(static_cast<char*>(meta_ptr.get()), meta_size);
    }
  }

  if (reader_->hasRecord("model.json") && code_prefix_ == "code/") {
#if !defined(C10_MOBILE) && !defined(C10_DISABLE_LEGACY_IMPORT)
    return torch::jit::LEGACY_deserialize(compilation_unit_, reader_, device_);
#else
    AT_ERROR("Legacy model format is not supported on mobile.");
#endif
  }

  // The constants table must be populated before "data" is read, because
  // source_importer_ was given its address at construction.
  auto tuple = readArchive("constants").toTuple();
  for (const auto& constant : tuple->elements()) {
    constants_table_.push_back(constant.toIValue());
  }

  auto m_ivalue = readArchive("data");
  auto m = Module(m_ivalue.toObject());
  rewriteQuantizedConvForBC(m);

  // Checking for and loading saved traced inputs
  if (restore_shapes && reader_->hasRecord("traced_inputs.pkl")) {
    // Keys are method names, values are the lists of inputs recorded for them.
    auto dict = readArchive("traced_inputs").toGenericDict();
    for (const auto& entry : dict) {
      auto inputs = entry.value().toList().vec();
      auto g =
          toGraphFunction(m.get_method(entry.key().toStringRef()).function())
              .graph();
      Stack stack(inputs.begin(), inputs.end());
      // Added the module as the first input if we are missing
      // an input as traced modules refer to self as an additional input
      if (g->inputs().size() == stack.size() + 1) {
        stack.insert(stack.begin(), m_ivalue);
      }
      setInputTensorTypes(*g, stack, /*complete=*/true);
      PropagateInputShapes(g);
    }
  } else if (restore_shapes) {
    TORCH_WARN("Cannot restore shapes as no traced inputs were stored");
  }

  c10::LogAPIUsageMetadata(
      "torch.script.load.metadata",
      {{"serialization_id", reader_->serializationId()}});
  return m;
}
} // namespace
// Stream-based import for callers that do not need extra files: a throwaway
// ExtraFilesMap is supplied and the overload accepting one does the work.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    c10::optional<at::Device> device,
    bool load_debug_files) {
  ExtraFilesMap discarded_extra_files;
  return import_ir_module(
      std::move(cu), in, device, discarded_extra_files, load_debug_files);
}
// Forward declaration; the definition appears further down in this file.
// Loads a module from an in-memory buffer of `size` bytes held by `data`,
// dispatching on the file format detected from the buffer.
static Module _load_jit_module_from_bytes(
std::shared_ptr<char> data,
size_t size,
std::shared_ptr<CompilationUnit> cu,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
bool restore_shapes);
// Builds a JIT Module from an in-memory buffer.
//
// The buffer is first parsed with parse_and_initialize_mobile_module_for_jit,
// which yields a mobile module plus the JIT source files and constants; those
// are then turned into a Module by jitModuleFromSourceAndConstants. The
// buffer is finally handed to the module through set_delete_memory.
//
// data/size:   the buffer and its length in bytes.
// extra_files: passed by address to the mobile parser.
// device:      forwarded to the mobile parser.
Module parse_and_initialize_jit_module(
    std::shared_ptr<char> data,
    size_t size,
    ExtraFilesMap& extra_files,
    c10::optional<at::Device> device) {
  populate_upgraders_graph_map();
  ExtraFilesMap jit_files;
  std::vector<IValue> jit_constants;
  mobile::Module mobilem = parse_and_initialize_mobile_module_for_jit(
      data.get(), size, jit_files, jit_constants, device, &extra_files);

  Module m = jitModuleFromSourceAndConstants(
      mobilem._ivalue(),
      jit_files,
      jit_constants,
      static_cast<int32_t>(mobilem.bytecode_version()));
  // Last use of `data`: transfer our reference instead of copying it.
  m.set_delete_memory(std::move(data));
  return m;
}
// Reads the whole file `filename` into memory and builds a JIT module from
// that buffer via parse_and_initialize_jit_module.
Module load_jit_module_from_file(
    const std::string& filename,
    ExtraFilesMap& extra_files,
    c10::optional<at::Device> device) {
  auto [buffer, buffer_size] = get_file_content(filename.c_str());
  return parse_and_initialize_jit_module(
      std::move(buffer), buffer_size, extra_files, device);
}
// Reads the content of stream `in` into memory and builds a JIT module from
// that buffer via parse_and_initialize_jit_module.
Module load_jit_module_from_stream(
    std::istream& in,
    ExtraFilesMap& extra_files,
    c10::optional<at::Device> device) {
  auto [buffer, buffer_size] = get_stream_content(in);
  return parse_and_initialize_jit_module(
      std::move(buffer), buffer_size, extra_files, device);
}
// Imports a module from stream `in`, which is rewound to its beginning first.
//
// Anything that is not detected as the flatbuffer format is read directly
// from the stream through a PyTorchStreamReader; flatbuffer content is read
// fully into memory and passed to _load_jit_module_from_bytes.
// `load_debug_files` is only applied on the stream-reader path.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files,
    bool restore_shapes) {
  in.seekg(0, in.beg);
  // NOTE: Zipformat can be large files. So using stream version directly
  // instead of reading the file all at once.
  if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
    auto reader = std::make_unique<PyTorchStreamReader>(&in);
    reader->setShouldLoadDebugSymbol(load_debug_files);
    ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
    return deserializer.deserialize(device, extra_files, restore_shapes);
  }

  auto [data, size] = get_stream_content(in);
  // Both are at their last use here, so ownership is transferred rather than
  // shared (matching how `cu` is handled on the path above).
  return _load_jit_module_from_bytes(
      std::move(data),
      size,
      std::move(cu),
      device,
      extra_files,
      restore_shapes);
}
// For reading unified serialization format from torch.Package.
//
// Pickles are looked up under ".data/ts_code/<ts_id>/" and tensors under
// ".data/"; `storage_context` is passed on to the deserializer. Extra files
// are not exposed by this overload, and restore_shapes keeps its default.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<PyTorchStreamReader> reader,
    std::shared_ptr<DeserializationStorageContext> storage_context,
    c10::optional<at::Device> device,
    std::string ts_id) {
  ScriptModuleDeserializer deserializer(
      std::move(cu),
      std::move(reader),
      /* pickle_dir_prefix = */ ".data/ts_code/" + ts_id + "/",
      /* tensor_dir_prefix = */ ".data/",
      // Last use: moved like the other owning arguments instead of copied.
      std::move(storage_context));
  ExtraFilesMap extra_files;
  return deserializer.deserialize(device, extra_files);
}
// File-based import for callers that do not need extra files: a throwaway
// ExtraFilesMap is supplied and the overload accepting one does the work.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& filename,
    c10::optional<at::Device> device,
    bool load_debug_files) {
  ExtraFilesMap discarded_extra_files;
  return import_ir_module(
      std::move(cu),
      filename,
      device,
      discarded_extra_files,
      load_debug_files);
}
// Imports a module from the file `filename`.
//
// Anything that is not detected as the flatbuffer format is read through a
// PyTorchStreamReader opened on the file; flatbuffer files are read fully
// into memory and passed to _load_jit_module_from_bytes.
// `load_debug_files` is only applied on the stream-reader path.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files,
    bool restore_shapes) {
  // NOTE: Zipformat can be large files. So using stream version directly
  // instead of reading the file all at once.
  if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
    auto reader = std::make_unique<PyTorchStreamReader>(filename);
    reader->setShouldLoadDebugSymbol(load_debug_files);
    ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
    return deserializer.deserialize(device, extra_files, restore_shapes);
  }

  auto [data, size] = get_file_content(filename.c_str());
  // Both are at their last use here, so ownership is transferred rather than
  // shared (matching how `cu` is handled on the path above).
  return _load_jit_module_from_bytes(
      std::move(data),
      size,
      std::move(cu),
      device,
      extra_files,
      restore_shapes);
}
// Read-adapter import for callers that do not need extra files: a throwaway
// ExtraFilesMap is supplied and the overload accepting one does the work.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    bool load_debug_files) {
  ExtraFilesMap discarded_extra_files;
  return import_ir_module(
      std::move(cu),
      std::move(rai),
      device,
      discarded_extra_files,
      load_debug_files);
}
// Adapts a uniquely-owned read adapter to the shared_ptr overload: ownership
// of `rai` is transferred into a shared_ptr, which is then passed on.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files) {
  std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
  // Neither pointer is used again here, so both are moved instead of copied.
  return import_ir_module(
      std::move(cu),
      std::move(rai_shared),
      device,
      extra_files,
      load_debug_files);
}
// Imports a module through a shared read adapter: wraps the adapter in a
// PyTorchStreamReader, applies the debug-symbol setting, and deserializes.
// restore_shapes is not exposed here and keeps its default value.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files) {
  auto stream_reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
  stream_reader->setShouldLoadDebugSymbol(load_debug_files);

  ScriptModuleDeserializer loader(std::move(cu), std::move(stream_reader));
  return loader.deserialize(device, extra_files);
}
// Loads a module from a stream into a freshly created CompilationUnit.
Module load(
    std::istream& in,
    c10::optional<at::Device> device,
    bool load_debug_files) {
  return import_ir_module(
      std::make_shared<CompilationUnit>(), in, device, load_debug_files);
}
// Loads a module from a stream into a freshly created CompilationUnit,
// passing `extra_files` through to the import.
Module load(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files) {
  return import_ir_module(
      std::make_shared<CompilationUnit>(),
      in,
      device,
      extra_files,
      load_debug_files);
}
// Loads a module from a file into a freshly created CompilationUnit.
Module load(
    const std::string& filename,
    c10::optional<at::Device> device,
    bool load_debug_files) {
  return import_ir_module(
      std::make_shared<CompilationUnit>(), filename, device, load_debug_files);
}
// Loads a module from a file into a freshly created CompilationUnit,
// passing `extra_files` through to the import.
Module load(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files) {
  return import_ir_module(
      std::make_shared<CompilationUnit>(),
      filename,
      device,
      extra_files,
      load_debug_files);
}
// Loads a module through a shared read adapter into a freshly created
// CompilationUnit. Extra files are not exposed by this overload; a local,
// discarded map is supplied instead.
Module load(
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    bool load_debug_files) {
  ExtraFilesMap discarded_extra_files;
  return import_ir_module(
      std::make_shared<CompilationUnit>(),
      std::move(rai),
      device,
      discarded_extra_files,
      load_debug_files);
}
// Loads a module through a shared read adapter into a freshly created
// CompilationUnit, passing `extra_files` through to the import.
Module load(
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files) {
  return import_ir_module(
      std::make_shared<CompilationUnit>(),
      std::move(rai),
      device,
      extra_files,
      load_debug_files);
}
// Loads a module from the in-memory buffer `data` of `size` bytes.
//
// The buffer must be at least kFileFormatHeaderSize bytes long; its format is
// detected from the buffer itself:
//  - flatbuffer: handled by parse_and_initialize_jit_module (`cu` and
//    `restore_shapes` are not used on this path);
//  - zip: read through a MemoryReadAdapter over the buffer and deserialized
//    with ScriptModuleDeserializer;
//  - anything else fails a TORCH_CHECK with "Unrecognized data format".
Module _load_jit_module_from_bytes(
    std::shared_ptr<char> data,
    size_t size,
    std::shared_ptr<CompilationUnit> cu,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool restore_shapes) {
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
  auto format = getFileFormat(data.get());
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
      // Last use of `data` on this path: transfer ownership, don't copy.
      return parse_and_initialize_jit_module(
          std::move(data), size, extra_files, device);
    }
    case FileFormat::ZipFileFormat: {
      // The adapter only views the bytes; `data` stays alive in this frame
      // for the whole deserialization.
      auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
      auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
      return deserializer.deserialize(device, extra_files, restore_shapes);
    }
    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
}
// Replace object with a newly created but equivalent object.
// The goal is to replace the object's methods. However, since an object's
// methods are attached to its type, we need to replace its type.
// Non-objects are unchanged; however, nested structures such as list, dict
// and tuple are also reconstructed because they might contain an object.
//
// ivalue:   the value to rebuild.
// resolver: maps the qualified name of an object's old class type to the
//           type the new object is created with. Taken by const reference so
//           the callable is not copied on every recursive call.
static IValue recreateObject(IValue ivalue, const TypeResolver& resolver) {
  if (ivalue.isObject()) {
    auto obj = ivalue.toObject();
    auto classtype_old = obj->type();
    auto newtype = resolver(*classtype_old->name());
    size_t n = classtype_old->numAttributes();
    auto newobj = c10::ivalue::Object::create(newtype, n);
    // Slots are rebuilt recursively since attributes may be objects too.
    for (const auto i : c10::irange(n)) {
      newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
    }
    return newobj;
  } else if (ivalue.isList()) {
    auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
    for (const auto& ival : ivalue.toList()) {
      res.emplace_back(recreateObject(ival, resolver));
    }
    return res;
  } else if (ivalue.isGenericDict()) {
    auto result = c10::impl::GenericDict(
        ivalue.type()->containedType(0), ivalue.type()->containedType(1));
    for (const auto& kv : ivalue.toGenericDict()) {
      result.insert_or_assign(
          recreateObject(kv.key(), resolver),
          recreateObject(kv.value(), resolver));
    }
    return result;
  } else if (ivalue.isTuple()) {
    auto tuple = ivalue.toTuple();
    const auto& elements = tuple->elements();
    std::vector<IValue> res;
    res.reserve(elements.size());
    for (const auto& ival : elements) {
      res.push_back(recreateObject(ival, resolver));
    }
    // NOTE(review): the tuple is created from its elements only; confirm
    // that no explicit tuple type needs to be carried over here.
    return c10::ivalue::Tuple::create(std::move(res));
  }
  // Leaf types are returned verbatim.
  return ivalue;
}
// Builds a Module out of an existing object value plus its source code.
//
// ivalue:    the object to convert; it is rebuilt with recreateObject so that
//            its class types come from a new CompilationUnit.
// source:    map from qualifier to source text, consulted when a type has to
//            be loaded; a qualifier that is not in the map yields no source.
// constants: constants table handed to the SourceImporter by address.
// version:   forwarded to the SourceImporter.
//
// The resulting module also gets the quantized-conv backward-compatibility
// rewrite applied before it is returned.
Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version) {
  auto cu = std::make_shared<CompilationUnit>();

  // Source lookup callback used by the importer.
  auto find_source =
      [&source](const std::string& qualifier) -> std::shared_ptr<Source> {
    const auto found = source.find(qualifier);
    if (found != source.end()) {
      return std::make_shared<Source>(
          found->second, qualifier, 1, nullptr, Source::COPIES_STRING);
    }
    return nullptr;
  };
  SourceImporter importer(cu, &constants, find_source, version);

  // Resolves a qualified name to a type owned by the new compilation unit.
  auto resolve_type = [&](const c10::QualifiedName& qn) {
    auto loaded = importer.loadType(qn);
    return c10::StrongTypePtr(cu, std::move(loaded));
  };

  auto rebuilt = recreateObject(ivalue, resolve_type).toObject();
  Module result(rebuilt);
  rewriteQuantizedConvForBC(result);
  return result;
}
} // namespace torch::jit