// pytorch/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // can be removed once a simple type_resolver/obj_loader is used
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h> // can be removed once a simple type_resolver/obj_loader is used
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <fstream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch::jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
c10::IValue readArchive(
const std::string& archive_name,
PyTorchStreamReader& stream_reader) {
std::optional<at::Device> device;
std::shared_ptr<CompilationUnit> compilation_unit =
std::make_shared<CompilationUnit>();
// TODO (T90180710): Simplify type_resolver and obj_loader when getting
// bytecode version from model
auto type_resolver = [&](const c10::QualifiedName& qn) {
return typeResolverMobile(qn, compilation_unit);
};
std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
std::make_shared<mobile::CompilationUnit>();
auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
return objLoaderMobile(type, input, *mobile_compilation_unit);
};
bool bytecode_tensor_in_constants_archive =
(archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader));
auto ivalues = torch::jit::readArchiveAndTensors(
archive_name,
/*pickle_prefix=*/"",
/*tensor_prefix=*/
bytecode_tensor_in_constants_archive ? "constants/" : "",
type_resolver,
obj_loader,
device,
stream_reader,
nullptr);
return ivalues;
}
std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
return std::move(*readArchive("bytecode", reader).toTuple()).elements().vec();
}
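
// A hedged sketch of the decoded "bytecode" archive layout, reconstructed from
// the accessors used in this file (not an authoritative format spec):
//
//   (
//     <bytecode version : int>,                  // bytecode_ivalues[0]
//     (<qualified method name>,                  // one tuple per method
//      (("instructions", (...)),
//       ("operators", ((<name>, <overload>[, <num schema args>]), ...)),
//       ("constants", (...)),
//       ("types", (<type string>, ...)),
//       ...)),
//     ...
//   )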
/********************** Bytecode **********************/
// Forward declare
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
uint64_t _get_model_bytecode_version(std::istream& in) {
auto orig_pos = in.tellg();
in.seekg(0, in.beg);
auto [data, size] = get_stream_content(in);
in.seekg(orig_pos, in.beg);
return _get_model_bytecode_version_from_bytes(data.get(), size);
}
uint64_t _get_model_bytecode_version(const std::string& filename) {
  std::ifstream ifile(filename, std::ios::binary);
return _get_model_bytecode_version(ifile);
}
uint64_t _get_model_bytecode_version(
const std::shared_ptr<ReadAdapterInterface>& rai) {
auto [data, size] = get_rai_content(rai.get());
return _get_model_bytecode_version_from_bytes(data.get(), size);
}
static uint64_t _get_model_bytecode_version_zip(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
"Failed to open .ptl file please ensure the model was exported for mobile");
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_model_bytecode_version(bytecode_values);
}
static uint64_t _get_model_bytecode_version_from_bytes(
    char* data,
    size_t size) {
TORCH_CHECK(data != nullptr, "Pointer to bytes is null.");
TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
auto format = getFileFormat(data);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
return get_bytecode_version_from_bytes(data);
}
case FileFormat::ZipFileFormat: {
auto rai =
std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
auto version = _get_model_bytecode_version_zip(std::move(rai));
return version;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues) {
if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
int64_t model_version = bytecode_ivalues[0].toInt();
TORCH_CHECK(
model_version > 0,
"Expected model bytecode version > 0 got ",
model_version);
return static_cast<uint64_t>(model_version);
}
TORCH_CHECK(false, "Failed to get bytecode version.");
}
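
// Usage sketch for the overload set above (illustrative only; the model path
// is hypothetical):
//
//   uint64_t version =
//       torch::jit::_get_model_bytecode_version("path/to/model.ptl");
//   // The same query works on an already-open stream:
//   std::ifstream in("path/to/model.ptl", std::ios::binary);
//   uint64_t stream_version = torch::jit::_get_model_bytecode_version(in);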
/********************** Operator Version **********************/
uint64_t _get_model_operator_version(
PyTorchStreamReader& reader); // Forward Declare
uint64_t _get_model_operator_version(std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_model_operator_version(std::move(rai));
}
uint64_t _get_model_operator_version(const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_model_operator_version(std::move(rai));
}
uint64_t _get_model_operator_version(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
"Failed to open .ptl file please ensure the model was exported for mobile");
}
PyTorchStreamReader reader(std::move(rai));
return _get_model_operator_version(reader);
}
uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
return reader.version();
}
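
// Usage sketch (illustrative only; the path is hypothetical). Note that the
// operator version is the archive version reported by PyTorchStreamReader and
// is distinct from the bytecode version above:
//
//   uint64_t op_version =
//       torch::jit::_get_model_operator_version("path/to/model.ptl");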
/********************** Operators and Info **********************/
// Forward declare
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::vector<IValue> bytecode_ivalues);
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_model_ops_and_info(std::move(rai));
}
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_model_ops_and_info(std::move(rai));
}
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_WARN("Failed to open zip file for model ops.");
return std::unordered_map<std::string, OperatorInfo>{};
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_model_ops_and_info(bytecode_values);
}
/* Retrieves the root (top-level) operators of a model and their corresponding
 * compatibility info. A root operator can call other operators internally
 * (traced ops), and may call different traced ops depending on which internal
 * code path it takes. Those traced ops are not returned by this function:
 * they are an implementation detail of the runtime (and may themselves call
 * further operators), which makes them difficult to retrieve and of little
 * value here, since they vary with the runtime version the model runs on.
 * As a result, this API cannot prevent one false positive in a compatibility
 * check: all of a model's root ops may be present in a target runtime while
 * some traced ops are not, which still prevents the model from running.
 **/
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::vector<IValue> bytecode_ivalues) {
constexpr uint64_t min_version_with_schema = 6;
if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
TORCH_WARN(
"Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
}
std::unordered_map<std::string, OperatorInfo> result;
if (bytecode_ivalues.empty()) {
TORCH_WARN("Failed to get model ops and info.");
return result;
}
// loop over all the functions in the bytecode
for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
// descend to the operators list
const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1];
auto operators = operators_tuple.toTupleRef().elements()[1];
for (auto& op_tuple : operators.toTupleRef().elements()) {
const auto& op = op_tuple.toTupleRef().elements();
// grab name
std::string op_name = op.at(0).toStringRef();
std::string op_overload_name = op.at(1).toStringRef();
if (!op_overload_name.empty()) {
op_name.append(".");
op_name.append(op_overload_name);
}
// grab schema size
if (op.size() > 2) {
        result.emplace(
            op_name, OperatorInfo{static_cast<int>(op.at(2).toInt())});
      } else { // no schema information, use default
result.emplace(op_name, OperatorInfo{});
}
}
}
return result;
}
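
// Usage sketch (illustrative only): enumerate a model's root operators and the
// number of schema args each was exported with (nullopt for bytecode
// version < 6):
//
//   auto ops = torch::jit::_get_model_ops_and_info("path/to/model.ptl");
//   for (const auto& [name, info] : ops) {
//     if (info.num_schema_args.has_value()) {
//       std::cout << name << ": " << *info.num_schema_args << " args\n";
//     } else {
//       std::cout << name << ": no schema info\n";
//     }
//   }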
/********************** Get Type Table **********************/
// Forward declare
std::unordered_set<std::string> _get_mobile_model_contained_types(
const std::vector<IValue>& bytecode_ivalues);
std::unordered_set<std::string> _get_mobile_model_contained_types(
std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_mobile_model_contained_types(std::move(rai));
}
std::unordered_set<std::string> _get_mobile_model_contained_types(
const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_mobile_model_contained_types(std::move(rai));
}
std::unordered_set<std::string> _get_mobile_model_contained_types(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
"Failed to open .ptl file please ensure the model was exported for mobile");
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_mobile_model_contained_types(bytecode_values);
}
// Get the deduplicated type table for the given bytecode. Each string is an
// atomic type, like str, Tensor, etc. For example,
// input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// output: {Dict, int, Tuple, Tensor}
std::unordered_set<std::string> _get_mobile_model_contained_types(
const std::vector<IValue>& bytecode_ivalues) {
  std::unordered_set<std::string> contained_types;
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    auto type_table_tuple =
        method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
    const auto& type_table =
        type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();
    // type_table is a list of IValue, and each IValue is a string,
    // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
    std::vector<std::string> type_name_list;
    for (const auto& type_definition : type_table) {
      type_name_list.emplace_back(type_definition.toStringRef());
    }
    // Parse this method's type table and merge its atomic types into the
    // result, so that types from every method accumulate instead of the last
    // method's types overwriting the earlier ones.
    at::TypeParser parser(type_name_list);
    parser.parseList();
    auto parsed_types = parser.getContainedTypes();
    contained_types.insert(parsed_types.begin(), parsed_types.end());
  }
return contained_types;
}
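
// Usage sketch (illustrative only): list the atomic types a model's bytecode
// refers to, e.g. {"Dict", "int", "Tuple", "Tensor"}:
//
//   auto types =
//       torch::jit::_get_mobile_model_contained_types("path/to/model.ptl");
//   for (const auto& type_name : types) {
//     std::cout << type_name << "\n";
//   }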
/********************** Compatibility Checker **********************/
ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return get(std::move(rai));
}
ModelCompatibilityInfo ModelCompatibilityInfo::get(
const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return get(std::move(rai));
}
ModelCompatibilityInfo ModelCompatibilityInfo::get(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false, "Failed to open zip file for model compatibility information");
}
PyTorchStreamReader reader(std::move(rai));
std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
uint64_t model_bytecode_version =
_get_model_bytecode_version(bytecode_values);
auto model_info = _get_model_ops_and_info(bytecode_values);
std::unordered_set<std::string> type_table =
_get_mobile_model_contained_types(bytecode_values);
uint64_t operator_version = _get_model_operator_version(reader);
return ModelCompatibilityInfo{
model_bytecode_version, model_info, type_table, operator_version};
}
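
// Usage sketch (illustrative only): gather all compatibility-relevant facts
// about a model in one pass:
//
//   auto model_info =
//       torch::jit::ModelCompatibilityInfo::get("path/to/model.ptl");
//   std::cout << "bytecode version: " << model_info.bytecode_version << "\n"
//             << "operator version: " << model_info.operator_version << "\n";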
ModelCompatCheckResult is_compatible(
RuntimeCompatibilityInfo runtime_info,
const ModelCompatibilityInfo& model_info) {
ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
  // Check that the model's bytecode version is less than or equal to
// kMaxSupportedBytecodeVersion from the runtime
if (model_info.bytecode_version >
runtime_info.min_max_supported_bytecode_version.second) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "model bytecode version " << model_info.bytecode_version
<< "is greater than the max supported bytecode version in runtimes "
<< runtime_info.min_max_supported_bytecode_version.second;
result.errors.emplace_back(s.str());
} else if (
model_info.bytecode_version <
runtime_info.min_max_supported_bytecode_version.first) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "model bytecode version " << model_info.bytecode_version
<< "is less than the minimum supported bytecode version in runtime "
<< runtime_info.min_max_supported_bytecode_version.first;
result.errors.emplace_back(s.str());
}
  const std::unordered_set<std::string>& supported_type =
      runtime_info.supported_types;
// Check type table
for (const auto& type_name : model_info.type_table) {
if (supported_type.find(type_name) == supported_type.end()) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "Primitive type: '" << type_name
<< "' is not supported in current runtime";
result.errors.push_back(s.str());
}
}
// Check operators
  const std::unordered_map<std::string, OperatorInfo>& operator_info =
      model_info.operator_info;
  for (const auto& op : operator_info) {
    const std::string& op_name = op.first;
    const OperatorInfo& model_op_info = op.second;
// Check if operator not present in runtime
if (runtime_info.operator_info.find(op_name) ==
runtime_info.operator_info.end()) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "Operator '" << op_name << "' missing from runtime (not found)";
result.errors.push_back(s.str());
} else {
      const OperatorInfo& runtime_op_info =
          runtime_info.operator_info.at(op_name);
      // If the runtime op has no schema information, it's a false alarm and
      // the op isn't actually usable
if (!runtime_op_info.num_schema_args.has_value()) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "Operator '" << op_name
<< "' missing from runtime (missing schema)";
result.errors.push_back(s.str());
} else {
// Check if the model operator has schema information. If it doesn't
// then the model is from a bytecode version < 6 and we are done. If the
// model has more args than the runtime, then the runtime can't know
// what to do so we aren't compatible. If the runtime has more args than
// the model then we can just use default values and be fine.
if (model_op_info.num_schema_args.has_value() &&
(model_op_info.num_schema_args.value() >
runtime_op_info.num_schema_args.value())) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "Operator schema for'" << op_name << "' has "
<< model_op_info.num_schema_args.value()
<< " args in model but only "
<< runtime_op_info.num_schema_args.value() << " in the runtime";
result.errors.push_back(s.str());
}
}
}
}
// Check Operator Versions
if (model_info.operator_version <
runtime_info.min_max_supported_opperator_versions.first ||
model_info.operator_version >
runtime_info.min_max_supported_opperator_versions.second) {
result.status = ModelCompatibilityStatus::ERROR;
std::ostringstream s;
s << "Model Operator Version " << model_info.operator_version
<< "is not within supported version range of the runtime "
<< runtime_info.min_max_supported_opperator_versions.first << " to "
<< runtime_info.min_max_supported_opperator_versions.second;
result.errors.push_back(s.str());
}
return result;
}
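
// End-to-end usage sketch (illustrative only; RuntimeCompatibilityInfo::get()
// is declared in runtime_compatibility.h and describes the current runtime):
//
//   auto runtime_info = torch::jit::RuntimeCompatibilityInfo::get();
//   auto model_info =
//       torch::jit::ModelCompatibilityInfo::get("path/to/model.ptl");
//   auto result = torch::jit::is_compatible(runtime_info, model_info);
//   if (result.status != torch::jit::ModelCompatibilityStatus::OK) {
//     for (const auto& error : result.errors) {
//       std::cerr << error << "\n";
//     }
//   }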
} // namespace torch::jit