Files
pytorch/tools/lite_interpreter/gen_selected_mobile_ops_header.py
Nikita Shulga 634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e. when a default argument is set to None, but its type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00

181 lines
5.9 KiB
Python

#!/usr/bin/env python3
import argparse
import os
from typing import Set
import yaml
from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
# C++ snippet emitted once per kernel tag by get_selected_kernel_dtypes_code():
# when the tag matches $kernel_tag_name, the generated function returns the
# value of the $dtype_checks expression.
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)
# C++ text defining at::should_include_kernel_dtype(). $body is filled in by
# get_selected_kernel_dtypes_code(); the trailing `return false;` is reached
# whenever $body does not return.
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>
namespace at {
inline constexpr bool should_include_kernel_dtype(
const char *kernel_tag_str,
at::ScalarType scalar_type
) {
c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
$body
return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
# First chunk of every generated header: the include guard pragma and a banner
# naming this generator script.
selected_mobile_ops_preamble = """#pragma once
/**
* Generated by gen_selected_mobile_ops_header.py
*/
"""
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of the root operators held by ``selective_builder``.

    Args:
        selective_builder: Selector whose ``operators`` mapping (operator
            name -> operator entry) is scanned.

    Returns:
        The set of operator names whose entry has a truthy
        ``is_root_operator`` attribute; empty if there are none.
    """
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_root_operator
    }
def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Render the C++ definition of ``at::should_include_kernel_dtype``.

    The function body is ``return true;`` unless both
    ``selective_builder.include_all_operators`` and
    ``selective_builder.include_all_non_op_selectives`` are exactly ``False``.
    In that case one ``if`` block is generated per entry of
    ``selective_builder.kernel_metadata`` (kernel tag -> dtype names), and the
    blocks are chained with ``else``.

    Args:
        selective_builder: Selector providing the two flags and the
            ``kernel_metadata`` mapping.

    Returns:
        ``selected_kernel_dtypes_h_template`` with ``$body`` substituted.
    """
    # See https://www.internalfb.com/intern/paste/P153411698/ for an example of the
    # generated code in case all kernel dtypes are selected and in case some kernel
    # dtypes are selected (i.e. both cases).
    body = "return true;"
    if (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    ):
        body_parts = []
        for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
            dtype_checks = " || ".join(
                "scalar_type == at::ScalarType::" + dtype for dtype in dtypes
            )
            body_parts.append(
                if_condition_template.substitute(
                    kernel_tag_name=kernel_tag,
                    # A tag with no dtypes would otherwise render `return ;`,
                    # which does not compile in a function returning bool.
                    dtype_checks=dtype_checks or "false",
                )
            )
        body = " else ".join(body_parts)
    return selected_kernel_dtypes_h_template.substitute(body=body)
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write the file selected_mobile_ops.h for ``selective_builder``.

    The header always starts with ``selected_mobile_ops_preamble`` and ends
    with the code from ``get_selected_kernel_dtypes_code``. In between it
    optionally contains:

    1. The selected root operators (``TORCH_OPERATOR_WHITELIST``), when
       ``include_all_operators`` is falsy.
    2. The custom class and build feature allowlists, when additionally
       ``include_all_non_op_selectives`` is exactly ``False``.

    Args:
        output_file_path: Path of the header file to create or overwrite.
        selective_builder: Selector describing what was selected.
    """
    root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features

    body_parts = [selected_mobile_ops_preamble]
    # This condition checks if we are in selective build.
    # If these lists are not defined, the corresponding selective build macros
    # trivially return that the item in question was selected.
    if not selective_builder.include_all_operators:
        body_parts.append(
            "#define TORCH_OPERATOR_WHITELIST "
            + (";".join(sorted(root_ops)))
            + ";\n\n"
        )
        # This condition checks if we are in tracing based selective build.
        if selective_builder.include_all_non_op_selectives is False:
            body_parts.append(
                "#define TORCH_CUSTOM_CLASS_ALLOWLIST "
                + (";".join(sorted(custom_classes)))
                + ";\n\n"
            )
            body_parts.append(
                "#define TORCH_BUILD_FEATURE_ALLOWLIST "
                + (";".join(sorted(build_features)))
                + ";\n\n"
            )

    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
    header_contents = "".join(body_parts)

    # Open the output only once the full contents exist, so an error while
    # generating them cannot leave a truncated header behind. Binary mode
    # writes the encoded bytes without newline translation.
    with open(output_file_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))
def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: Set[str],
) -> None:
    """Write the file selected_mobile_ops.h from an explicit operator set.

    The header contains ``selected_mobile_ops_preamble``, a
    ``TORCH_OPERATOR_WHITELIST`` define listing ``root_ops`` in sorted order,
    and the kernel dtype code generated for the selector returned by
    ``SelectiveBuilder.get_nop_selector()`` (intended to select all kernel
    dtypes).

    Args:
        output_file_path: Path of the header file to create or overwrite.
        root_ops: A set of selected root operators for selective build.
    """
    body_parts = [selected_mobile_ops_preamble]
    body_parts.append(
        "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n"
    )
    selective_builder = SelectiveBuilder.get_nop_selector()
    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
    header_contents = "".join(body_parts)

    # Open the output only once the full contents exist, so an error while
    # generating them cannot leave a truncated header behind. Binary mode
    # writes the encoded bytes without newline translation.
    with open(output_file_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))
def main() -> None:
    """Command-line entry point: generate selected_mobile_ops.h from a YAML file.

    Parses ``--yaml-file-path`` and ``--output-file-path``, loads the YAML
    document, treats its top-level items as the set of root operators, and
    writes ``selected_mobile_ops.h`` into the output folder via
    ``write_selected_mobile_ops_with_all_dtypes``.
    """
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml-file-path",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output-file-path",
        "--output_file_path",
        type=str,
        required=True,
        # The two literal fragments used to be joined without a space
        # ("destinationfolder"); the help text now reads correctly.
        help="Path to destination folder where selected_mobile_ops.h will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    # Loader is CSafeLoader or SafeLoader (see the module-level import), so
    # only plain YAML data is constructed here.
    with open(model_file_name, "rb") as model_file:
        loaded_model = yaml.load(model_file, Loader=Loader)

    root_operators_set = set(loaded_model)
    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )


if __name__ == "__main__":
    main()