Files
pytorch/torch/csrc/jit/mobile/model_compatibility.cpp
Scott Wolchok 82f7f8d471 [PyTorch] Adopt IValue::toTupleRef() where obvious (#65505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65505

Generated with

`fastmod -m 'toTuple\(\)(\s*)->' 'toTupleRef()${1}.'`

, followed by

`fastmod '(std::move\(.*)toTupleRef\(\).' '${1}toTuple()->'`

to unbreak 2 callsites.
ghstack-source-id: 142065835

Test Plan: CI

Reviewed By: gchanan

Differential Revision: D31131025

fbshipit-source-id: 54457ae5bbeb38db9c7f196d469b98521c3d3f34
2021-11-02 10:22:18 -07:00

356 lines
14 KiB
C++

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
namespace c10 {
// Forward declaration of c10::parseType, presumably parsing a Python-style
// type string (the parameter is named pythonStr) into a TypePtr.
// NOTE(review): no code in this file appears to call parseType (type strings
// are parsed with at::TypeParser instead); confirm and consider removing.
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
// Reads the pickled archive `archive_name` (for example "bytecode") from
// stream_reader together with the tensors it references, and returns the
// unpickled IValue. Types and objects met while unpickling are resolved via
// typeResolverMobile / objLoaderMobile, backed by compilation units created
// fresh for this call.
c10::IValue readArchive(
    const std::string& archive_name,
    PyTorchStreamReader& stream_reader) {
  // TODO (T90180710): Simplify type_resolver and obj_loader when getting
  // bytecode version from model
  auto compilation_unit = std::make_shared<CompilationUnit>();
  auto mobile_compilation_unit = std::make_shared<mobile::CompilationUnit>();

  auto type_resolver = [&](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, compilation_unit);
  };
  auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
    return objLoaderMobile(type, input, *mobile_compilation_unit);
  };

  // When reading "bytecode" from a model whose bytecode archive does not
  // contain its own tensors (per isTensorInBytecodeArchive), the tensor data
  // is looked up under the "constants/" prefix instead.
  const bool read_tensors_from_constants =
      archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader);
  const std::string tensor_prefix =
      read_tensors_from_constants ? "constants/" : "";

  // Empty optional: no explicit device is requested.
  c10::optional<at::Device> device;
  return torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      tensor_prefix,
      type_resolver,
      obj_loader,
      device,
      stream_reader);
}
// Reads the "bytecode" archive (expected to be a tuple; toTuple() rejects
// anything else) and returns its elements as a vector.
std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
  auto bytecode_tuple = readArchive("bytecode", reader).toTuple();
  // Move out of the tuple so its elements can be transferred, not copied.
  return std::move(*bytecode_tuple).elements().vec();
}
/********************** Bytecode **********************/
// Forward declaration of the overload that extracts the version from
// already-loaded bytecode ivalues (defined below), so the stream/file/adapter
// overloads that follow can delegate to it.
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
uint64_t _get_model_bytecode_version(std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_model_bytecode_version(std::move(rai));
}
uint64_t _get_model_bytecode_version(const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_model_bytecode_version(std::move(rai));
}
uint64_t _get_model_bytecode_version(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
"Failed to open .ptl file please ensure the model was exported for mobile");
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_model_bytecode_version(bytecode_values);
}
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues) {
if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
int64_t model_version = bytecode_ivalues[0].toInt();
TORCH_CHECK(
model_version > 0,
"Expected model bytecode version > 0 got ",
model_version);
return static_cast<uint64_t>(model_version);
}
TORCH_CHECK(false, "Failed to get bytecode version.");
}
/********************** Operators and Info **********************/
// Forward declaration of the overload that collects operators from
// already-loaded bytecode ivalues (defined below), so the stream/file/adapter
// overloads that follow can delegate to it.
// NOTE(review): the vector is taken by value, so callers passing an lvalue
// copy it; switching to const& must be done together with the definition and
// any header declaration to avoid ambiguous overloads.
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::vector<IValue> bytecode_ivalues);
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_model_ops_and_info(std::move(rai));
}
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_model_ops_and_info(std::move(rai));
}
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_WARN("Failed to open zip file for model ops.");
return std::unordered_map<std::string, OperatorInfo>{};
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_model_ops_and_info(bytecode_values);
}
/* Retrieves the root (top-level) operators of a model together with their
 * compatibility info (the number of schema arguments, when the bytecode
 * records it).
 *
 * Root operators can call other operators internally ("traced ops"), and
 * which traced ops a root op calls can depend on code paths inside it. Traced
 * ops are NOT returned: they are an implementation detail of the runtime (and
 * may call further operators themselves), which makes them hard to retrieve,
 * and they differ between runtime versions, so reporting them adds little
 * value. Consequently this API cannot rule out one false positive in a
 * compatibility check: every root op of a model may be present in a target
 * runtime while some traced op is missing, which still prevents the model
 * from running.
 *
 * bytecode_ivalues: elements of the model's "bytecode" archive; element 0 is
 *   the bytecode version and each later element describes one method.
 * Returns a map keyed by "name" or "name.overload_name". Warns and returns an
 *   empty map when bytecode_ivalues is empty; warns (but continues) when the
 *   bytecode version predates schema information.
 **/
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues) {
  std::unordered_map<std::string, OperatorInfo> result;
  // Must run before the version lookup: _get_model_bytecode_version throws on
  // empty input, which would otherwise make this branch unreachable.
  if (bytecode_ivalues.empty()) {
    TORCH_WARN("Failed to get model ops and info.");
    return result;
  }

  constexpr uint64_t min_version_with_schema = 6;
  if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
    TORCH_WARN(
        "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
  }

  // Element 0 is the version; every later element is one method.
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    // Descend to the operator list: the method's element 1 is a table whose
    // element 1 (presumably the ("operators", list) pair) holds the list at
    // its own index 1. References avoid copying the nested IValues.
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    const auto& operators_entry =
        method_tuple.at(1).toTupleRef().elements().at(1);
    const auto& operators = operators_entry.toTupleRef().elements().at(1);
    for (const auto& op_tuple : operators.toTupleRef().elements()) {
      const auto& op = op_tuple.toTupleRef().elements();
      // Build "name" or "name.overload_name".
      std::string op_name = op.at(0).toStringRef();
      const std::string& op_overload_name = op.at(1).toStringRef();
      if (!op_overload_name.empty()) {
        op_name.append(".");
        op_name.append(op_overload_name);
      }
      // A third element, when present, is the schema argument count; without
      // it, fall back to a default OperatorInfo.
      if (op.size() > 2) {
        result.emplace(
            op_name, OperatorInfo{static_cast<int>(op.at(2).toInt())});
      } else {
        result.emplace(op_name, OperatorInfo{});
      }
    }
  }
  return result;
}
/********************** Get Type Table **********************/
// Forward declaration of the overload that collects type names from
// already-loaded bytecode ivalues (defined below), so the stream/file/adapter
// overloads that follow can delegate to it.
std::unordered_set<std::string> _get_mobile_model_contained_types(
const std::vector<IValue>& bytecode_ivalues);
std::unordered_set<std::string> _get_mobile_model_contained_types(
std::istream& in) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _get_mobile_model_contained_types(std::move(rai));
}
std::unordered_set<std::string> _get_mobile_model_contained_types(
const std::string& filename) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _get_mobile_model_contained_types(std::move(rai));
}
std::unordered_set<std::string> _get_mobile_model_contained_types(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
"Failed to open .ptl file please ensure the model was exported for mobile");
}
PyTorchStreamReader reader(std::move(rai));
auto bytecode_values = get_bytecode_ivalues(reader);
return _get_mobile_model_contained_types(bytecode_values);
}
// Returns the deduplicated set of atomic type names (such as str or Tensor)
// used by the type tables of ALL methods in the given bytecode. For example,
//   input type:  "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
//   contributes: {Dict, int, Tuple, Tensor}
//
// bytecode_ivalues: elements of the model's "bytecode" archive; element 0 is
//   the version and each later element describes one method.
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues) {
  std::unordered_set<std::string> contained_types;
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    const auto& type_table_entry =
        method_tuple.at(1).toTupleRef().elements().at(BYTECODE_INDEX_TYPE);
    const auto& type_table =
        type_table_entry.toTupleRef().elements().at(1).toTupleRef().elements();

    // Each type_table entry is a type string, for example
    // "Dict[int, Tuple[Tensor, Tensor, Tensor]]".
    std::vector<std::string> type_name_list;
    type_name_list.reserve(type_table.size());
    for (const auto& type_definition : type_table) {
      type_name_list.emplace_back(type_definition.toStringRef());
    }
    at::TypeParser parser(type_name_list);
    parser.parseList();
    // Merge into the running set; assigning here would keep only the types
    // of the last method.
    const auto& method_types = parser.getContainedTypes();
    contained_types.insert(method_types.begin(), method_types.end());
  }
  return contained_types;
}
/********************** Compatibility Checker **********************/
// Loads compatibility info for a model read from `in`.
ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
  // Wrap the stream in a read adapter and defer to the adapter overload.
  return get(std::make_unique<IStreamAdapter>(&in));
}
// Loads compatibility info for the model stored at `filename`.
ModelCompatibilityInfo ModelCompatibilityInfo::get(
    const std::string& filename) {
  // Wrap the file in a read adapter and defer to the adapter overload.
  return get(std::make_unique<FileAdapter>(filename));
}
// Loads the bytecode of the model behind `rai` once and derives its bytecode
// version, root operators and contained types from it. Throws if `rai` is not
// a zip archive or the bytecode version cannot be read.
ModelCompatibilityInfo ModelCompatibilityInfo::get(
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
  TORCH_CHECK(
      check_zip_file(rai),
      "Failed to open zip file for model compatibility information");
  PyTorchStreamReader reader(std::move(rai));
  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
  uint64_t model_bytecode_version =
      _get_model_bytecode_version(bytecode_values);
  auto model_info = _get_model_ops_and_info(bytecode_values);
  std::unordered_set<std::string> type_table =
      _get_mobile_model_contained_types(bytecode_values);
  // The locals are not used again, so move the containers into the result
  // instead of copying them.
  return ModelCompatibilityInfo{
      model_bytecode_version, std::move(model_info), std::move(type_table)};
}
// Checks whether a model (model_info) can run on a runtime (runtime_info).
// Every problem found is appended to result.errors, and any problem sets
// result.status to ERROR. The checks are:
//  - the model's bytecode version must not exceed the runtime's;
//  - every atomic type in the model's type table must be supported;
//  - every model operator must exist in the runtime with schema info, and
//    must not take more schema arguments than the runtime's version.
ModelCompatCheckResult is_compatible(
    RuntimeCompatibilityInfo runtime_info,
    ModelCompatibilityInfo model_info) {
  ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};

  // Check that the model's bytecode version is less than or equal to the
  // runtime's (kMaxSupportedBytecodeVersion).
  if (model_info.bytecode_version > runtime_info.bytecode_version) {
    result.status = ModelCompatibilityStatus::ERROR;
    std::ostringstream s;
    s << "model bytecode version " << model_info.bytecode_version
      << " is greater than the runtime's " << runtime_info.bytecode_version;
    result.errors.emplace_back(s.str());
  }

  // Check type table.
  const auto& supported_types = runtime_info.supported_types;
  for (const auto& type_name : model_info.type_table) {
    if (supported_types.find(type_name) == supported_types.end()) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Primitive type: '" << type_name
        << "' is not supported in current runtime";
      result.errors.emplace_back(s.str());
    }
  }

  // Check operators.
  for (const auto& op : model_info.operator_info) {
    const std::string& op_name = op.first;
    const OperatorInfo& model_op_info = op.second;

    const auto runtime_op_it = runtime_info.operator_info.find(op_name);
    if (runtime_op_it == runtime_info.operator_info.end()) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Operator '" << op_name << "' missing from runtime (not found)";
      result.errors.emplace_back(s.str());
      continue;
    }

    const OperatorInfo& runtime_op_info = runtime_op_it->second;
    // The runtime lists the operator but without schema information; treat
    // it as not actually usable, i.e. missing.
    if (!runtime_op_info.num_schema_args.has_value()) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Operator '" << op_name
        << "' missing from runtime (missing schema)";
      result.errors.emplace_back(s.str());
      continue;
    }

    // A model operator without schema info comes from bytecode version < 6,
    // so there is nothing more to compare. If the model passes more args than
    // the runtime schema has, the runtime can't interpret them: incompatible.
    // If the runtime has more args, defaults fill the rest, which is fine.
    if (model_op_info.num_schema_args.has_value() &&
        (model_op_info.num_schema_args.value() >
         runtime_op_info.num_schema_args.value())) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Operator schema for '" << op_name << "' has "
        << model_op_info.num_schema_args.value() << " args in model but only "
        << runtime_op_info.num_schema_args.value() << " in the runtime";
      result.errors.emplace_back(s.str());
    }
  }
  return result;
}
} // namespace jit
} // namespace torch