Files
pytorch/tools/code_analyzer/gen_oplist.py
Edward Yang 36420b5e8c Rename tools/codegen to torchgen (#76275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275

In preparation for addressing
https://github.com/pytorch/pytorch/issues/73212

The diff was generated with:

```
git mv tools/codegen torchgen
git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g'
sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt
```

and manual edits to:

* tools/test/test_gen_backend_stubs.py
* torchgen/build.bzl
* torchgen/gen_backend_stubs.py

aka this diff:

```
 diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 3dc26c6d2d..104054575e 100644
 --- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run
 from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE  # noqa: F401

 path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py')

 # gen_backend_stubs.py is an integration point that is called directly by external backends.
 # The tests here are to confirm that badly formed inputs result in reasonable error messages.
 diff --git a/torchgen/build.bzl b/torchgen/build.bzl
index ed04e35a43..d00078a3cf 100644
 --- a/torchgen/build.bzl
+++ b/torchgen/build.bzl
@@ -1,6 +1,6 @@
 def define_targets(rules):
     rules.py_library(
-        name = "codegen",
+        name = "torchgen",
         srcs = rules.glob(["**/*.py"]),
         deps = [
             rules.requirement("PyYAML"),
@@ -11,6 +11,6 @@ def define_targets(rules):

     rules.py_binary(
         name = "gen",
-        srcs = [":codegen"],
+        srcs = [":torchgen"],
         visibility = ["//visibility:public"],
     )
 diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index c1a672a655..beee7a15e0 100644
 --- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -474,7 +474,7 @@ def run(
 ) -> None:

     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
-    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
     template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

     def make_file_manager(install_dir: str) -> FileManager:
```

run_all_fbandroid_tests

Test Plan: sandcastle

Reviewed By: albanD, ngimel

Differential Revision: D35770317

fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf
(cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
2022-04-25 01:38:06 +00:00

190 lines
6.3 KiB
Python

#!/usr/bin/env python3
import argparse
import json
import os
import sys
from functools import reduce
from typing import Set, List, Any
import yaml
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
)
from tools.lite_interpreter.gen_selected_mobile_ops_header import (
write_selected_mobile_ops,
)
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of every operator recorded in ``selective_builder``.

    Args:
        selective_builder: Builder whose ``operators`` mapping is keyed by
            operator name.

    Returns:
        The set of keys of ``selective_builder.operators``.
    """
    # Iterating a mapping yields its keys; the per-operator values are not
    # needed here, so there is no reason to walk ``.items()``.
    return set(selective_builder.operators)
def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of the operators flagged as used for training.

    Args:
        selective_builder: Builder whose ``operators`` mapping goes from
            operator name to an object exposing ``is_used_for_training``.

    Returns:
        The set of operator names whose ``is_used_for_training`` is truthy.
    """
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_used_for_training
    }
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    """Reject operators that were selected together with all of their overloads.

    In this script, ``main`` runs this check only when the
    ``--allow_include_all_overloads`` flag was NOT given.

    Args:
        selective_builder: Builder whose ``operators`` mapping goes from
            operator name to an object exposing ``include_all_overloads``.

    Raises:
        Exception: if any operator has ``include_all_overloads`` set. The
            message lists the offending names comma-separated, in the
            iteration order of ``selective_builder.operators``. The plain
            ``Exception`` type is kept so existing callers still catch it.
    """
    offending = [
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.include_all_overloads
    ]
    if offending:
        joined = ", ".join(offending)
        # The check only runs when the flag is absent, so say "not specified".
        raise Exception(
            "Operators that include all overloads are not allowed since "
            f"--allow_include_all_overloads was not specified: {joined}"
        )
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
    """Write ``SupportedMobileModelsRegistration.cpp`` into ``output_dir``.

    The generated C++ defines a global registry object. Its constructor passes
    the collected MD5 hashes to ``set_supported_md5_hashes`` on
    ``facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton()``.

    Which hashes are collected:
    - Only model dicts that contain a ``"debug_info"`` entry are considered.
    - The first element of that entry is parsed as JSON.
    - If its ``is_new_style_rule`` field is truthy, the ``md5_hash`` of every
      value in its ``asset_info`` mapping is added.

    Hashes are de-duplicated and written in sorted order, so the generated
    file is reproducible.

    Args:
        model_dicts: Parsed model YAML dicts.
        output_dir: Directory the ``.cpp`` file is written into (overwritten
            if it already exists).
    """
    supported_mobile_models_source = """/*
* Generated by gen_oplist.py
*/
#include "fb/supported_mobile_models/SupportedMobileModels.h"
struct SupportedMobileModelCheckerRegistry {{
SupportedMobileModelCheckerRegistry() {{
auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
{supported_hashes_template}
}});
}}
}};
// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;
"""
    # Generate SupportedMobileModelsRegistration.cpp
    md5_hashes = set()
    for model_dict in model_dicts:
        if "debug_info" not in model_dict:
            continue
        debug_info = json.loads(model_dict["debug_info"][0])
        if not debug_info["is_new_style_rule"]:
            continue
        for asset_info in debug_info["asset_info"].values():
            # NOTE(review): set.update() iterates its argument, so this assumes
            # md5_hash is a collection of hash strings (a bare string would be
            # split into single characters). Confirm against the model YAML schema.
            md5_hashes.update(asset_info["md5_hash"])

    # Sort so the generated file is byte-for-byte reproducible. Iterating a
    # set of str varies between interpreter runs because of hash randomization.
    # key=str keeps the sort total even if a non-str value slips in.
    supported_hashes = "".join(
        '"{}",\n'.format(md5) for md5 in sorted(md5_hashes, key=str)
    )
    source = supported_mobile_models_source.format(
        supported_hashes_template=supported_hashes
    )
    with open(
        os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb"
    ) as out_file:
        out_file.write(source.encode("utf-8"))
def _load_model_dicts(model_file_list_path: str) -> List[Any]:
    """Parse the model YAML dict(s) named by ``--model_file_list_path``.

    If the path is an existing file, it is parsed as a single model YAML.
    Otherwise the path must start with ``@``. The prefix is stripped, and the
    remaining path is read as a whitespace-separated list of model YAML files,
    each parsed in order.

    Raises:
        ValueError: if the path is neither an existing file nor ``@``-prefixed.
    """
    if os.path.isfile(model_file_list_path):
        print("Processing model file: ", model_file_list_path)
        # Use a context manager so the handle is closed (it previously leaked).
        with open(model_file_list_path) as model_file:
            return [yaml.safe_load(model_file)]

    print("Processing model directory: ", model_file_list_path)
    # Use an explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and indexing [0] raised IndexError on an empty string.
    if not model_file_list_path.startswith("@"):
        raise ValueError(
            "--model_file_list_path must be an existing model YAML file or an "
            f"'@'-prefixed list file, got: {model_file_list_path!r}"
        )
    with open(model_file_list_path[1:]) as model_list_file:
        model_file_names = model_list_file.read().split()

    model_dicts = []
    for model_file_name in model_file_names:
        with open(model_file_name, "rb") as model_file:
            model_dicts.append(yaml.safe_load(model_file))
    return model_dicts


def main(argv: List[Any]) -> None:
    """Generate the selective-build artifacts for a set of models.

    This binary generates 3 files in ``--output_dir``:
    1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
       dtypes captured by tracing
    2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
    3. SupportedMobileModelsRegistration.cpp: Registers the MD5 hashes of the supported mobile models

    Args:
        argv: Currently unused. ``parser.parse_args()`` is called without
            arguments, so argparse reads ``sys.argv`` directly.

    Raises:
        ValueError: if ``--model_file_list_path`` is neither an existing file
            nor ``@``-prefixed.
        Exception: if an operator includes all overloads and
            ``--allow_include_all_overloads`` was not given.
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--output_dir",
        help=(
            "The directory to store the output files (selected_mobile_ops.h, "
            "selected_operators.yaml, SupportedMobileModelsRegistration.cpp)"
        ),
        required=True,
    )
    parser.add_argument(
        "--model_file_list_path",
        help=(
            "Path to a file that contains the locations of individual "
            "model YAML files that contain the set of used operators. This "
            "file path must have a leading @-symbol, which will be stripped "
            "out before processing. Alternatively, the path of a single "
            "model YAML file may be given directly."
        ),
        required=True,
    )
    parser.add_argument(
        "--allow_include_all_overloads",
        help=(
            "Flag to allow operators that include all overloads. "
            "If not set, operators registered without using the traced style will "
            "break the build."
        ),
        action="store_true",
        default=False,
        required=False,
    )
    options = parser.parse_args()

    model_dicts = _load_model_dicts(options.model_file_list_path)
    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

    # While we have the model_dicts generate the supported mobile models api
    gen_supported_mobile_models(model_dicts, options.output_dir)

    # We may have 0 selective builders since there may not be any viable
    # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
    # This is potentially an error, and we should probably raise an assertion
    # failure here. However, this needs to be investigated further.
    if selective_builders:
        selective_builder = reduce(combine_selective_builders, selective_builders)
    else:
        selective_builder = SelectiveBuilder.from_yaml_dict({})

    if not options.allow_include_all_overloads:
        throw_if_any_op_includes_overloads(selective_builder)

    with open(
        os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
    ) as out_file:
        out_file.write(
            yaml.safe_dump(
                selective_builder.to_dict(), default_flow_style=False
            ).encode("utf-8"),
        )

    write_selected_mobile_ops(
        os.path.join(options.output_dir, "selected_mobile_ops.h"),
        selective_builder,
    )
if __name__ == "__main__":
    # Script entry point: pass the raw command line through to main().
    main(argv=sys.argv)