mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Run Black on all of tools/
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/76089 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ae864d4fb9
commit
a11c1bbdd0
@ -44,6 +44,7 @@ selected_mobile_ops_preamble = """#pragma once
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
|
||||
ops = []
|
||||
for (op_name, op) in selective_builder.operators.items():
|
||||
@ -51,18 +52,24 @@ def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
|
||||
ops.append(op_name)
|
||||
return set(ops)
|
||||
|
||||
|
||||
def get_selected_kernel_dtypes_code(
|
||||
selective_builder: SelectiveBuilder,
|
||||
selective_builder: SelectiveBuilder,
|
||||
) -> str:
|
||||
# See https://www.internalfb.com/intern/paste/P153411698/ for an example of the
|
||||
# generated code in case all kernel dtypes are selected and in case some kernel
|
||||
# dtypes are selected (i.e. both cases).
|
||||
#
|
||||
body = "return true;"
|
||||
if selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False:
|
||||
if (
|
||||
selective_builder.include_all_operators is False
|
||||
and selective_builder.include_all_non_op_selectives is False
|
||||
):
|
||||
body_parts = []
|
||||
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
|
||||
conditions = list(map(lambda x: 'scalar_type == at::ScalarType::' + x, dtypes))
|
||||
conditions = list(
|
||||
map(lambda x: "scalar_type == at::ScalarType::" + x, dtypes)
|
||||
)
|
||||
body_parts.append(
|
||||
if_condition_template.substitute(
|
||||
kernel_tag_name=kernel_tag,
|
||||
@ -79,8 +86,8 @@ def get_selected_kernel_dtypes_code(
|
||||
# 1. The selected root operators
|
||||
# 2. The selected kernel dtypes
|
||||
def write_selected_mobile_ops(
|
||||
output_file_path: str,
|
||||
selective_builder: SelectiveBuilder,
|
||||
output_file_path: str,
|
||||
selective_builder: SelectiveBuilder,
|
||||
) -> None:
|
||||
root_ops = extract_root_operators(selective_builder)
|
||||
custom_classes = selective_builder.custom_classes
|
||||
@ -90,16 +97,29 @@ def write_selected_mobile_ops(
|
||||
# This condition checks if we are in selective build.
|
||||
# if these lists are not defined the corresponding selective build macros trivially return the item in question was selected
|
||||
if not selective_builder.include_all_operators:
|
||||
body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
|
||||
body_parts.append(
|
||||
"#define TORCH_OPERATOR_WHITELIST "
|
||||
+ (";".join(sorted(root_ops)))
|
||||
+ ";\n\n"
|
||||
)
|
||||
# This condition checks if we are in tracing based selective build
|
||||
if selective_builder.include_all_non_op_selectives is False:
|
||||
body_parts.append("#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n")
|
||||
body_parts.append("#define TORCH_BUILD_FEATURE_ALLOWLIST " + (";".join(sorted(build_features))) + ";\n\n")
|
||||
body_parts.append(
|
||||
"#define TORCH_CUSTOM_CLASS_ALLOWLIST "
|
||||
+ (";".join(sorted(custom_classes)))
|
||||
+ ";\n\n"
|
||||
)
|
||||
body_parts.append(
|
||||
"#define TORCH_BUILD_FEATURE_ALLOWLIST "
|
||||
+ (";".join(sorted(build_features)))
|
||||
+ ";\n\n"
|
||||
)
|
||||
|
||||
body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
|
||||
header_contents = "".join(body_parts)
|
||||
out_file.write(header_contents.encode("utf-8"))
|
||||
|
||||
|
||||
# root_ops: a set of selected root operators for selective build
|
||||
# Write the file selected_mobile_ops.h with optionally:
|
||||
# 1. The selected root operators from root_ops
|
||||
@ -110,7 +130,9 @@ def write_selected_mobile_ops_with_all_dtypes(
|
||||
) -> None:
|
||||
with open(output_file_path, "wb") as out_file:
|
||||
body_parts = [selected_mobile_ops_preamble]
|
||||
body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
|
||||
body_parts.append(
|
||||
"#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n"
|
||||
)
|
||||
|
||||
selective_builder = SelectiveBuilder.get_nop_selector()
|
||||
body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
|
||||
@ -118,17 +140,25 @@ def write_selected_mobile_ops_with_all_dtypes(
|
||||
header_contents = "".join(body_parts)
|
||||
out_file.write(header_contents.encode("utf-8"))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate selected_mobile_ops.h for selective build."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--yaml_file_path", type=str, required=True, help="Path to the yaml"
|
||||
" file with a list of operators used by the model."
|
||||
"-p",
|
||||
"--yaml_file_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the yaml" " file with a list of operators used by the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_file_path", type=str, required=True, help="Path to destination"
|
||||
"folder where selected_mobile_ops.h will be written."
|
||||
"-o",
|
||||
"--output_file_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to destination"
|
||||
"folder where selected_mobile_ops.h will be written.",
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
model_file_name = parsed_args.yaml_file_path
|
||||
@ -138,12 +168,13 @@ def main() -> None:
|
||||
with open(model_file_name, "rb") as model_file:
|
||||
loaded_model = yaml.load(model_file, Loader=Loader)
|
||||
|
||||
|
||||
root_operators_set = set(loaded_model)
|
||||
print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
|
||||
write_selected_mobile_ops_with_all_dtypes(
|
||||
os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
|
||||
root_operators_set)
|
||||
root_operators_set,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user