mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275 In preparation for addressing https://github.com/pytorch/pytorch/issues/73212 Diff was generated with: ``` git mv tools/codegen torchgen git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g' sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt ``` and a manual edits to: * tools/test/test_gen_backend_stubs.py * torchgen/build.bzl * torchgen/gen_backend_stubs.py aka this diff: ``` diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 3dc26c6d2d..104054575e 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) -gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py') +gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py') # gen_backend_stubs.py is an integration point that is called directly by external backends. # The tests here are to confirm that badly formed inputs result in reasonable error messages. 
diff --git a/torchgen/build.bzl b/torchgen/build.bzl index ed04e35a43..d00078a3cf 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -1,6 +1,6 @@ def define_targets(rules): rules.py_library( - name = "codegen", + name = "torchgen", srcs = rules.glob(["**/*.py"]), deps = [ rules.requirement("PyYAML"), @@ -11,6 +11,6 @@ def define_targets(rules): rules.py_binary( name = "gen", - srcs = [":codegen"], + srcs = [":torchgen"], visibility = ["//visibility:public"], ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c1a672a655..beee7a15e0 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -474,7 +474,7 @@ def run( ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py - pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute() + pytorch_root = pathlib.Path(__file__).parent.parent.absolute() template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") def make_file_manager(install_dir: str) -> FileManager: ``` run_all_fbandroid_tests Test Plan: sandcastle Reviewed By: albanD, ngimel Differential Revision: D35770317 fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf (cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
181 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import os
|
|
from typing import Set
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
from torchgen.code_template import CodeTemplate
|
|
|
|
import yaml
|
|
|
|
# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    # Fall back to the pure-Python loader when PyYAML was built without
    # the libyaml C bindings; behavior is identical, only slower.
    from yaml import SafeLoader as Loader  # type: ignore[misc]
|
|
|
|
|
|
# Template for a single dtype-check branch in the generated
# should_include_kernel_dtype() function: $kernel_tag_name is an operator
# kernel tag and $dtype_checks is an "||"-joined list of scalar-type
# comparisons for that tag.
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
  return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)

# Template for the generated C++ header: defines
# at::should_include_kernel_dtype(), which returns whether a given
# (kernel tag, scalar type) pair was selected for this build. $body is
# filled with either "return true;" (everything selected) or a chain of
# if_condition_template branches.
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>

namespace at {
inline constexpr bool should_include_kernel_dtype(
  const char *kernel_tag_str,
  at::ScalarType scalar_type
) {
  c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
  $body
  return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
|
|
|
|
selected_mobile_ops_preamble = """#pragma once
|
|
/**
|
|
* Generated by gen_selected_mobile_ops_header.py
|
|
*/
|
|
|
|
"""
|
|
|
|
|
|
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of all selected root operators.

    Args:
        selective_builder: Selective-build configuration whose ``operators``
            mapping (op name -> op info) is scanned.

    Returns:
        The set of operator names whose entry has ``is_root_operator`` set.
    """
    # Build the set directly instead of appending to a list and
    # converting afterwards.
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_root_operator
    }
|
|
|
|
|
|
def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Generate the body of the selected_mobile_ops.h dtype-selection header.

    Args:
        selective_builder: Selective-build configuration providing the
            ``kernel_metadata`` mapping (kernel tag -> list of dtype names)
            and the ``include_all_*`` flags.

    Returns:
        The full C++ header text produced from
        ``selected_kernel_dtypes_h_template``.
    """
    # See https://www.internalfb.com/intern/paste/P153411698/ for an example of the
    # generated code in case all kernel dtypes are selected and in case some kernel
    # dtypes are selected (i.e. both cases).
    #
    # Default: every dtype is considered selected.
    body = "return true;"
    # Only a tracing-based selective build (both flags explicitly False)
    # narrows the dtype set per kernel tag.
    if (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    ):
        body_parts = []
        for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
            # One comparison per selected dtype; replaces the original
            # list(map(lambda ...)) with an equivalent comprehension.
            conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
            body_parts.append(
                if_condition_template.substitute(
                    kernel_tag_name=kernel_tag,
                    dtype_checks=" || ".join(conditions),
                ),
            )
        body = " else ".join(body_parts)

    header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    return header_contents
|
|
|
|
|
|
# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators
# 2. The selected kernel dtypes
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write selected_mobile_ops.h for a selective build.

    The file starts with the standard preamble, then (when this is a
    selective build) the operator allowlist, then (when this is a
    tracing-based selective build) the custom-class and build-feature
    allowlists, and finally the dtype-selection header body.

    Args:
        output_file_path: Full path of the header file to write.
        selective_builder: Selective-build configuration to render.
    """
    selected_root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features
    with open(output_file_path, "wb") as out_file:
        sections = [selected_mobile_ops_preamble]
        # Only emit the allowlist for selective builds; when the macro is
        # absent, the corresponding selective-build macros trivially report
        # every item as selected.
        if not selective_builder.include_all_operators:
            op_list = ";".join(sorted(selected_root_ops))
            sections.append(
                "#define TORCH_OPERATOR_WHITELIST " + op_list + ";\n\n"
            )
        # Tracing-based selective build: also gate custom classes and
        # build features.
        if selective_builder.include_all_non_op_selectives is False:
            class_list = ";".join(sorted(custom_classes))
            sections.append(
                "#define TORCH_CUSTOM_CLASS_ALLOWLIST " + class_list + ";\n\n"
            )
            feature_list = ";".join(sorted(build_features))
            sections.append(
                "#define TORCH_BUILD_FEATURE_ALLOWLIST " + feature_list + ";\n\n"
            )

        sections.append(get_selected_kernel_dtypes_code(selective_builder))
        out_file.write("".join(sections).encode("utf-8"))
|
|
|
|
|
|
# root_ops: a set of selected root operators for selective build
# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators from root_ops
# 2. All kernel dtypes
def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: Set[str],
) -> None:
    """Write selected_mobile_ops.h with the given root operators and all dtypes.

    Unlike write_selected_mobile_ops, this always emits the operator
    allowlist and uses a no-op selector so every kernel dtype is kept.

    Args:
        output_file_path: Full path of the header file to write.
        root_ops: Names of the selected root operators.
    """
    with open(output_file_path, "wb") as out_file:
        op_list = ";".join(sorted(root_ops))
        sections = [
            selected_mobile_ops_preamble,
            "#define TORCH_OPERATOR_WHITELIST " + op_list + ";\n\n",
        ]

        # A nop selector selects everything, so the generated dtype check
        # is simply "return true;".
        nop_selector = SelectiveBuilder.get_nop_selector()
        sections.append(get_selected_kernel_dtypes_code(nop_selector))

        out_file.write("".join(sections).encode("utf-8"))
|
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments and generate selected_mobile_ops.h.

    Reads a YAML file listing the root operators used by the model and
    writes selected_mobile_ops.h (with all kernel dtypes enabled) into
    the requested output directory.
    """
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output_file_path",
        type=str,
        required=True,
        # Fixed: the original concatenated "Path to destination" and
        # "folder..." with no separating space, rendering as
        # "destinationfolder" in --help output.
        help="Path to destination folder where selected_mobile_ops.h "
        "will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    with open(model_file_name, "rb") as model_file:
        # Loader is the C-accelerated safe loader when available (see the
        # try/except at the top of the file).
        loaded_model = yaml.load(model_file, Loader=Loader)

    # The YAML document is expected to be a list of operator names;
    # set() deduplicates it.
    root_operators_set = set(loaded_model)
    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )
|
|
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|