From b62d39eda07ab6af4dca5d760d99db638ecbc726 Mon Sep 17 00:00:00 2001 From: Linbin Yu Date: Wed, 29 Jun 2022 23:27:47 +0000 Subject: [PATCH] Consolidate all python targets in the tools folder (#80408) Summary: All buck targets that points to caffe2/tools folder are now moved to tools/BUCK. This also eliminates all python library/binary import in pt_defs.bzl, which caused T124308913. Test Plan: CI Differential Revision: D37468313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80408 Approved by: https://github.com/seemethere, https://github.com/malfet --- .github/workflows/_buck-build-test.yml | 10 +- .lintrunner.toml | 9 + buckbuild.bzl | 104 +------- cmake/VulkanCodegen.cmake | 2 +- pt_ops.bzl | 4 +- tools/BUCK.bzl | 263 +++++++++++++++++++ tools/BUCK.oss | 10 + tools/autograd/BUCK.oss | 35 --- tools/build_defs/fb_python_binary.bzl | 9 - tools/build_defs/fb_python_library.bzl | 9 - {aten/src/ATen => tools}/gen_vulkan_spv.py | 0 tools/jit/BUCK.oss | 12 - tools/lite_interpreter/BUCK.oss | 6 - tools/setup_helpers/BUCK.oss | 41 --- tools/test/gen_operators_yaml_test.py | 190 ++++++++++++++ tools/test/gen_oplist_test.py | 35 +++ tools/test/test_selective_build.py | 281 +++++++++++++++++++++ 17 files changed, 807 insertions(+), 213 deletions(-) create mode 100644 tools/BUCK.bzl create mode 100644 tools/BUCK.oss delete mode 100644 tools/autograd/BUCK.oss delete mode 100644 tools/build_defs/fb_python_binary.bzl delete mode 100644 tools/build_defs/fb_python_library.bzl rename {aten/src/ATen => tools}/gen_vulkan_spv.py (100%) delete mode 100644 tools/jit/BUCK.oss delete mode 100644 tools/lite_interpreter/BUCK.oss delete mode 100644 tools/setup_helpers/BUCK.oss create mode 100644 tools/test/gen_operators_yaml_test.py create mode 100644 tools/test/gen_oplist_test.py create mode 100644 tools/test/test_selective_build.py diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml index 2d1e563ed0ee..ae7f7517e2ed 100644 --- a/.github/workflows/_buck-build-test.yml +++ b/.github/workflows/_buck-build-test.yml @@ -62,7 +62,15 @@ jobs: command: | sh scripts/buck_setup.sh - - name: Build C10 + - name: Build tools + run: | + buck build tools: --keep-going + + - name: Run tools tests + run: | + buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test + + - name: Build c10 run: | buck build c10:c10 diff --git a/.lintrunner.toml b/.lintrunner.toml index 1b98226e96c8..302ff02e4f84 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -16,6 +16,7 @@ exclude_patterns = [ 'torch/lib/**', 'venv/**', '**/*.pyi', + 'tools/test/test_selective_build.py', ] command = [ 'python3', @@ -145,6 +146,10 @@ include_patterns = [ exclude_patterns = [ # (linbinyu) copied from internal repo 'tools/code_analyzer/gen_operators_yaml.py', + 'tools/gen_vulkan_spv.py', + 'tools/test/gen_operators_yaml_test.py', + 'tools/test/gen_oplist_test.py', + 'tools/test/test_selective_build.py', ] command = [ 'python3', @@ -334,6 +339,7 @@ exclude_patterns = [ command = [ 'python3', 'tools/linter/adapters/grep_linter.py', + # @lint-ignore TXT2 '--pattern= ', '--linter-name=TABS', '--error-name=saw some tabs', @@ -565,6 +571,9 @@ include_patterns = [ 'torch/_decomp/**/*.py', 'test/onnx/**/*.py', ] +exclude_patterns = [ + 'tools/gen_vulkan_spv.py', +] command = [ 'python3', 'tools/linter/adapters/black_linter.py', diff --git a/buckbuild.bzl b/buckbuild.bzl index 42abc497af77..ecacc34759f6 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -3,8 +3,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") -load("//tools/build_defs:fb_python_binary.bzl", "fb_python_binary") -load("//tools/build_defs:fb_python_library.bzl", "fb_python_library") load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") @@ -416,7 +414,7 @@ def gen_aten_files( name = name, default_outs = ["."], outs = get_aten_generated_files(backends), - cmd = "$(exe {}:gen_aten_bin) ".format(ROOT) + " ".join([ + cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([ "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), "--install_dir $OUT", ] + extra_params), @@ -442,7 +440,7 @@ def gen_aten_unboxing_files( name = genrule_name, default_outs = ["."], outs = get_unboxing_generated_files(), - cmd = "$(exe {}:gen_unboxing_bin) ".format(ROOT) + " ".join([ + cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([ "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), "--install_dir $OUT", ] + extra_params), @@ -515,7 +513,7 @@ def pt_operator_query_codegen( # @lint-ignore BUCKLINT fb_native.genrule( name = oplist_dir_name, - cmd = ("$(exe {}:gen_oplist) ".format(ROOT) + + cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) + "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " + ("" if enforce_traced_op_list else "--allow_include_all_overloads ") + "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])), @@ -620,7 +618,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple outs = get_generate_code_bin_outs(), default_outs = ["."], bash = "mkdir -p tools && " + - "$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join( + "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join( # Mobile build only needs libtorch - skip python bindings for now, except # for ovrsource, which needs Python bindings. (["--subset libtorch"] if not is_arvr_mode() else []) + [ @@ -630,7 +628,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple ] + extra_params, ), cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + - "$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join( + "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join( # Mobile build only needs libtorch - skip python bindings for now, except # for ovrsource, which needs Python bindings. (["--subset libtorch"] if not is_arvr_mode() else []) + [ @@ -950,7 +948,7 @@ def define_buck_targets( "torch/csrc/api/include/torch/version.h.in", "version.txt", ], - cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([ + cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([ "--template-path", "torch/csrc/api/include/torch/version.h.in", "--version-path", @@ -995,28 +993,13 @@ def define_buck_targets( ], ) - fb_python_library( - name = "substitutelib", - srcs = ["tools/substitute.py"], - base_module = "", - ) - - fb_python_binary( - name = "substitute", - main_module = "tools.substitute", - visibility = ["PUBLIC"], - deps = [ - ":substitutelib", - ], - ) - # @lint-ignore BUCKLINT fb_native.genrule( name = "generate_aten_config", srcs = [ "aten/src/ATen/Config.h.in", ], - cmd = "$(exe :substitute) " + " ".join([ + cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([ "--install_dir", "$OUT", "--input-file", @@ -1072,79 +1055,6 @@ def define_buck_targets( default_outs = ["."], ) - fb_python_binary( - name = "gen_aten_bin", - main_module = "torchgen.gen", - visibility = [ - "PUBLIC", - ], - deps = [ - ROOT_PATH + "torchgen:torchgen", - ], - ) - - fb_python_binary( - name = "gen_unboxing_bin", - main_module = "tools.jit.gen_unboxing", - visibility = [ - "PUBLIC", - ], - deps = [ - ROOT_PATH + "tools/jit:jit", - ], - ) - - fb_python_library( - name = "gen_oplist_lib", - srcs = subdir_glob([ - ("tools/code_analyzer", "gen_oplist.py"), - ("tools/code_analyzer", "gen_op_registration_allowlist.py"), - ]), - base_module = "", - tests = [ - ":gen_oplist_test", - ], - deps = [ - third_party("pyyaml"), - ROOT_PATH + "tools/lite_interpreter:gen_selected_mobile_ops_header", - ROOT_PATH + "torchgen:torchgen", - ], - ) - - fb_python_library( - name = "gen_operators_yaml_lib", - srcs = subdir_glob([ - ("tools/code_analyzer", "gen_operators_yaml.py"), - ("tools/code_analyzer", "gen_op_registration_allowlist.py"), - ]), - base_module = "", - tests = [ - ":gen_operators_yaml_test", - ], - deps = [ - third_party("pyyaml"), - ROOT_PATH + "torchgen:torchgen", - ], - ) - - fb_python_binary( - name = "gen_oplist", - main_module = "gen_oplist", - visibility = ["PUBLIC"], - deps = [ - ":gen_oplist_lib", - ], - ) - - fb_python_binary( - name = "gen_operators_yaml", - main_module = "gen_operators_yaml", - visibility = ["PUBLIC"], - deps = [ - ":gen_operators_yaml_lib", - ], - ) - gen_aten_files( name = "gen_aten", extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS), diff --git a/cmake/VulkanCodegen.cmake b/cmake/VulkanCodegen.cmake index c39b54df3af3..075f2b36ad2a 100644 --- a/cmake/VulkanCodegen.cmake +++ b/cmake/VulkanCodegen.cmake @@ -62,7 +62,7 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME) execute_process( COMMAND "${PYTHON_EXECUTABLE}" - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py + ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl --output-path ${VULKAN_GEN_OUTPUT_PATH} --glslc-path=${GLSLC_PATH} diff --git a/pt_ops.bzl b/pt_ops.bzl index 2dd4ce3e2ab2..a8089a9ca410 100644 --- a/pt_ops.bzl +++ b/pt_ops.bzl @@ -39,7 +39,7 @@ def pt_operator_library( name = name, out = "model_operators.yaml", cmd = ( - "$(exe {root}:gen_operators_yaml) " + + "$(exe {exe}) " + "{optionally_root_ops} " + "{optionally_training_root_ops} " + "--rule_name {rule_name} " + @@ -52,7 +52,7 @@ def pt_operator_library( "{optionally_model_traced_backends} " + "{optionally_include_all_operators}" ).format( - root = "//" if IS_OSS else "//xplat/caffe2", + exe = "//tools:gen_operators_yaml" if IS_OSS else "//xplat/caffe2/tools:gen_operators_yaml", rule_name = name, model_name = model_name, dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl new file mode 100644 index 000000000000..959a73d55e2e --- /dev/null +++ b/tools/BUCK.bzl @@ -0,0 +1,263 @@ +# @lint-ignore-every FBCODEBZLADDLOADS +load("//tools/build_defs:glob_defs.bzl", "subdir_glob") + +# shared by internal and OSS BUCK +def define_tools_targets( + python_binary, + python_library, + python_test, + third_party, + torchgen_deps, + contacts = []): + python_library( + name = "substitutelib", + srcs = ["substitute.py"], + base_module = "", + ) + + python_binary( + name = "substitute", + main_module = "substitute", + visibility = ["PUBLIC"], + deps = [ + ":substitutelib", + ], + ) + + python_library( + name = "jit", + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob([ + "jit/*.py", + "jit/templates/*", + ]), + base_module = "tools", + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + ], + ) + + python_binary( + name = "gen_unboxing_bin", + main_module = "tools.jit.gen_unboxing", + visibility = [ + "PUBLIC", + ], + deps = [ + ":jit", + ], + ) + + python_library( + name = "gen_selected_mobile_ops_header", + srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"], + base_module = "tools", + visibility = ["PUBLIC"], + ) + + python_library( + name = "gen_oplist_lib", + srcs = subdir_glob([ + ("code_analyzer", "gen_oplist.py"), + ("code_analyzer", "gen_op_registration_allowlist.py"), + ]), + base_module = "", + tests = [ + ":gen_oplist_test", + ], + deps = [ + ":gen_selected_mobile_ops_header", + torchgen_deps, + third_party("pyyaml"), + ], + ) + + python_binary( + name = "gen_oplist", + main_module = "gen_oplist", + visibility = ["PUBLIC"], + deps = [ + ":gen_oplist_lib", + ], + ) + + python_library( + name = "gen_operators_yaml_lib", + srcs = subdir_glob([ + ("code_analyzer", "gen_operators_yaml.py"), + ("code_analyzer", "gen_op_registration_allowlist.py"), + ]), + base_module = "", + tests = [ + ":gen_operators_yaml_test", + ], + deps = [ + third_party("pyyaml"), + torchgen_deps, + ], + ) + + python_binary( + name = "gen_operators_yaml", + main_module = "gen_operators_yaml", + visibility = ["PUBLIC"], + deps = [ + ":gen_operators_yaml_lib", + ], + ) + + python_library( + name = "autograd", + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob( + ["autograd/*.py"], + ), + base_module = "tools", + resources = [ + "autograd/deprecated.yaml", + "autograd/derivatives.yaml", + "autograd/templates/ADInplaceOrViewType.cpp", + "autograd/templates/Functions.cpp", + "autograd/templates/Functions.h", + "autograd/templates/TraceType.cpp", + "autograd/templates/VariableType.cpp", + "autograd/templates/VariableType.h", + "autograd/templates/annotated_fn_args.py.in", + "autograd/templates/python_enum_tag.cpp", + "autograd/templates/python_fft_functions.cpp", + "autograd/templates/python_functions.cpp", + "autograd/templates/python_functions.h", + "autograd/templates/python_linalg_functions.cpp", + "autograd/templates/python_nn_functions.cpp", + "autograd/templates/python_return_types.cpp", + "autograd/templates/python_sparse_functions.cpp", + "autograd/templates/python_special_functions.cpp", + "autograd/templates/python_torch_functions.cpp", + "autograd/templates/python_variable_methods.cpp", + "autograd/templates/variable_factories.h", + ], + visibility = ["PUBLIC"], + deps = [ + third_party("pyyaml"), + torchgen_deps, + ], + ) + + python_library( + name = "generate_code", + srcs = [ + "setup_helpers/generate_code.py", + ], + base_module = "tools", + deps = [ + ":autograd", + ":jit", + torchgen_deps, + ], + ) + + python_binary( + name = "generate_code_bin", + main_module = "tools.setup_helpers.generate_code", + # Windows does not support inplace: + # https://github.com/facebook/buck/issues/2161. + # + # Note that //arvr/mode/embedded/win/clang-aarch64-release sets + # its target platform to + # ovr_config//platform/embedded:clang-aarch64-linux-release, hence + # that is why we are selecting that OS to trigger this behavior. + package_style = select({ + "DEFAULT": "inplace", + "ovr_config//os:linux-arm64": "standalone", + }), + visibility = ["PUBLIC"], + # Because Windows does not support inplace packaging, we need to + # ensure it is unzipped before executing it, otherwise it will not + # be able to find any resources using path manipulation. + # + # See note above about why the OS is Linux here and not Windows. + zip_safe = select({ + "DEFAULT": True, + "ovr_config//os:linux-arm64": False, + }), + deps = [ + ":generate_code", + ], + ) + + python_library( + name = "gen-version-header-lib", + srcs = [ + "setup_helpers/gen_version_header.py", + ], + base_module = "", + deps = [], + ) + + python_binary( + name = "gen-version-header", + main_module = "setup_helpers.gen_version_header", + visibility = ["PUBLIC"], + deps = [ + ":gen-version-header-lib", + ], + ) + + python_library( + name = "gen_aten_vulkan_spv_lib", + srcs = [ + "gen_vulkan_spv.py", + ], + base_module = "", + deps = [ + torchgen_deps, + ], + ) + + python_binary( + name = "gen_aten_vulkan_spv_bin", + main_module = "gen_vulkan_spv", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_aten_vulkan_spv_lib", + ], + ) + + python_test( + name = "selective_build_test", + srcs = [ + "test/test_selective_build.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + ], + ) + + python_test( + name = "gen_oplist_test", + srcs = [ + "test/gen_oplist_test.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + ":gen_oplist_lib", + ], + ) + + python_test( + name = "gen_operators_yaml_test", + srcs = [ + "test/gen_operators_yaml_test.py", + ], + visibility = ["PUBLIC"], + contacts = contacts, + deps = [ + ":gen_operators_yaml_lib", + ], + ) diff --git a/tools/BUCK.oss b/tools/BUCK.oss new file mode 100644 index 000000000000..97f67945120e --- /dev/null +++ b/tools/BUCK.oss @@ -0,0 +1,10 @@ +load("//:buckbuild.bzl", "third_party") +load(":BUCK.bzl", "define_tools_targets") + +define_tools_targets( + python_binary = python_binary, + python_library = python_library, + python_test = python_test, + third_party = third_party, + torchgen_deps = "//torchgen:torchgen", +) diff --git a/tools/autograd/BUCK.oss b/tools/autograd/BUCK.oss deleted file mode 100644 index 04403f4a5269..000000000000 --- a/tools/autograd/BUCK.oss +++ /dev/null @@ -1,35 +0,0 @@ -python_library( - name = "autograd", - srcs = glob( - ["*.py"], - ), - base_module = "tools.autograd", - resources = [ - "deprecated.yaml", - "derivatives.yaml", - "templates/ADInplaceOrViewType.cpp", - "templates/Functions.cpp", - "templates/Functions.h", - "templates/TraceType.cpp", - "templates/VariableType.cpp", - "templates/VariableType.h", - "templates/annotated_fn_args.py.in", - "templates/python_fft_functions.cpp", - "templates/python_functions.cpp", - "templates/python_functions.h", - "templates/python_linalg_functions.cpp", - "templates/python_nn_functions.cpp", - "templates/python_return_types.cpp", - "templates/python_sparse_functions.cpp", - "templates/python_special_functions.cpp", - "templates/python_torch_functions.cpp", - "templates/python_variable_methods.cpp", - "templates/variable_factories.h", - "templates/python_enum_tag.cpp", - ], - visibility = ["PUBLIC"], - deps = [ - "//third_party:pyyaml", - "//torchgen:torchgen", - ], -) diff --git a/tools/build_defs/fb_python_binary.bzl b/tools/build_defs/fb_python_binary.bzl deleted file mode 100644 index 5e69f32881b0..000000000000 --- a/tools/build_defs/fb_python_binary.bzl +++ /dev/null @@ -1,9 +0,0 @@ -# Only used for PyTorch open source BUCK build -# @lint-ignore-every BUCKRESTRICTEDSYNTAX -# @lint-ignore-every FBCODEBZLADDLOADS - -def fb_python_binary(**kwgs): - if read_config("pt", "is_oss", "0") == "0": - fail("This file is for open source pytorch build. Do not use it in fbsource!") - - python_binary(**kwgs) diff --git a/tools/build_defs/fb_python_library.bzl b/tools/build_defs/fb_python_library.bzl deleted file mode 100644 index e0ab86f77b7f..000000000000 --- a/tools/build_defs/fb_python_library.bzl +++ /dev/null @@ -1,9 +0,0 @@ -# Only used for PyTorch open source BUCK build -# @lint-ignore-every BUCKRESTRICTEDSYNTAX -# @lint-ignore-every FBCODEBZLADDLOADS - -def fb_python_library(**kwgs): - if read_config("pt", "is_oss", "0") == "0": - fail("This file is for open source pytorch build. Do not use it in fbsource!") - - python_library(**kwgs) diff --git a/aten/src/ATen/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py similarity index 100% rename from aten/src/ATen/gen_vulkan_spv.py rename to tools/gen_vulkan_spv.py diff --git a/tools/jit/BUCK.oss b/tools/jit/BUCK.oss deleted file mode 100644 index 8c0105f1cf8e..000000000000 --- a/tools/jit/BUCK.oss +++ /dev/null @@ -1,12 +0,0 @@ -python_library( - name = "jit", - srcs = glob([ - "*.py", - "templates/*", - ]), - base_module = "tools.jit", - visibility = ["PUBLIC"], - deps = [ - "//torchgen:torchgen", - ], -) diff --git a/tools/lite_interpreter/BUCK.oss b/tools/lite_interpreter/BUCK.oss deleted file mode 100644 index 10415c26aee7..000000000000 --- a/tools/lite_interpreter/BUCK.oss +++ /dev/null @@ -1,6 +0,0 @@ -python_library( - name = "gen_selected_mobile_ops_header", - srcs = ["gen_selected_mobile_ops_header.py"], - base_module = "tools.lite_interpreter", - visibility = ["PUBLIC"], -) diff --git a/tools/setup_helpers/BUCK.oss b/tools/setup_helpers/BUCK.oss deleted file mode 100644 index afcd31fb3a03..000000000000 --- a/tools/setup_helpers/BUCK.oss +++ /dev/null @@ -1,41 +0,0 @@ -python_library( - name = "generate_code", - srcs = [ - "generate_code.py", - ], - base_module = "tools.setup_helpers", - deps = [ - "//tools/autograd:autograd", - "//tools/jit:jit", - "//torchgen:torchgen", - ], -) - -python_binary( - name = "generate_code_bin", - main_module = "tools.setup_helpers.generate_code", - visibility = ["PUBLIC"], - # package_style = "inplace", - zip_safe = False, - deps = [ - ":generate_code", - ], -) - -python_library( - name = "gen-version-header-lib", - srcs = [ - "gen_version_header.py", - ], - base_module = "tools.setup_helpers", - deps = [], -) - -python_binary( - name = "gen-version-header", - main_module = "tools.setup_helpers.gen_version_header", - visibility = ["PUBLIC"], - deps = [ - ":gen-version-header-lib", - ], -) diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py new file mode 100644 index 000000000000..87455d3a13ff --- /dev/null +++ b/tools/test/gen_operators_yaml_test.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import unittest + +from gen_operators_yaml import make_filter_from_options, verify_all_specified_present + + +class GenOperatorsYAMLTest(unittest.TestCase): + def setUp(self): + pass + + def test_filter_creation(self): + filter_func = make_filter_from_options( + model_name="abc", + model_versions=["100", "101"], + model_assets=None, + model_backends=None, + ) + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 102, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + }, + { + "model": { + "name": "abcd", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + + filtered_configs = list(filter(filter_func, config)) + assert ( + len(filtered_configs) == 2 + ), "Expected 2 elements in filtered_configs, but got {}".format( + len(filtered_configs) + ) + + def test_verification_success(self): + filter_func = make_filter_from_options( + model_name="abc", + model_versions=["100", "101"], + model_assets=["asset-1", "asset-2"], + model_backends=None, + ) + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + filtered_configs = list(filter(filter_func, config)) + try: + verify_all_specified_present( + model_assets=["asset-1", "asset-2"], + model_versions=["100", "101"], + selected_models_yaml=filtered_configs, + rule_name="test", + model_name="abc", + new_style_rule=True, + ) + except Exception: + self.fail( + "expected verify_all_specified_present to succeed instead it raised an exception" + ) + + def test_verification_fail(self): + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + + good_assets = ["asset-1", "asset-2"] + good_versions = ["100", "101"] + good_name = "abc" + + # Test bad asset + filter_func_bad_asset = make_filter_from_options( + model_name=good_name, + model_versions=good_versions, + model_assets=["asset-1", "asset-2", "asset-3"], + model_backends=None, + ) + filtered_configs_asset = list(filter(filter_func_bad_asset, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=["asset-1", "asset-2", "asset-3"], + model_versions=good_versions, + selected_models_yaml=filtered_configs_asset, + rule_name="test", + model_name=good_name, + new_style_rule=True, + ) + + # Test bad version + filter_func_bad_version = make_filter_from_options( + model_name=good_name, + model_versions=["100", "101", "102"], + model_assets=good_assets, + model_backends=None, + ) + filtered_configs_version = list(filter(filter_func_bad_version, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=good_assets, + model_versions=["100", "101", "102"], + selected_models_yaml=filtered_configs_version, + rule_name="test", + model_name=good_name, + new_style_rule=True, + ) + + # Test bad name + filter_func_bad_name = make_filter_from_options( + model_name="abcd", + model_versions=good_versions, + model_assets=good_assets, + model_backends=None, + ) + filtered_configs_name = list(filter(filter_func_bad_name, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=good_assets, + model_versions=good_versions, + selected_models_yaml=filtered_configs_name, + rule_name="test", + model_name="abcd", + new_style_rule=True, + ) diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py new file mode 100644 index 000000000000..d58e2ccc9067 --- /dev/null +++ b/tools/test/gen_oplist_test.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import unittest +from unittest.mock import MagicMock + +from gen_oplist import throw_if_any_op_includes_overloads + + +class GenOplistTest(unittest.TestCase): + def setUp(self): + pass + + def test_throw_if_any_op_includes_overloads(self): + selective_builder = MagicMock() + selective_builder.operators = MagicMock() + selective_builder.operators.items.return_value = [ + ("op1", MagicMock(include_all_overloads=True)), + ("op2", MagicMock(include_all_overloads=False)), + ("op3", MagicMock(include_all_overloads=True)), + ] + + self.assertRaises( + Exception, throw_if_any_op_includes_overloads, selective_builder + ) + + selective_builder.operators.items.return_value = [ + ("op1", MagicMock(include_all_overloads=False)), + ("op2", MagicMock(include_all_overloads=False)), + ("op3", MagicMock(include_all_overloads=False)), + ] + + # Here we do not expect it to throw an exception since none of the ops + # include all overloads. + throw_if_any_op_includes_overloads(selective_builder) diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py new file mode 100644 index 000000000000..50a3ba56eb79 --- /dev/null +++ b/tools/test/test_selective_build.py @@ -0,0 +1,281 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from torchgen.selective_build.operator import * +from torchgen.selective_build.selector import ( + combine_selective_builders, + SelectiveBuilder, +) + + +class TestSelectiveBuild(unittest.TestCase): + def test_selective_build_operator(self): + op = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + self.assertTrue(op.is_root_operator) + self.assertFalse(op.is_used_for_training) + self.assertFalse(op.include_all_overloads) + + def test_selector_factory(self): + yaml_config_v1 = """ +debug_info: + - model1@v100 + - model2@v51 +operators: + aten::add: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: Yes + aten::add.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No + aten::mul.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No +""" + + yaml_config_v2 = """ +debug_info: + - model1@v100 + - model2@v51 +operators: + aten::sub: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: No + debug_info: + - model1@v100 + aten::sub.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No +""" + + yaml_config_all = "include_all_operators: Yes" + + yaml_config_invalid = "invalid:" + + selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1) + + self.assertTrue(selector1.is_operator_selected("aten::add")) + self.assertTrue(selector1.is_operator_selected("aten::add.int")) + # Overload name is not used for checking in v1. + self.assertTrue(selector1.is_operator_selected("aten::add.float")) + + def gen(): + return SelectiveBuilder.from_yaml_str(yaml_config_invalid) + + self.assertRaises(Exception, gen) + + selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all) + + self.assertTrue(selector_all.is_operator_selected("aten::add")) + self.assertTrue(selector_all.is_operator_selected("aten::sub")) + self.assertTrue(selector_all.is_operator_selected("aten::sub.int")) + self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32")) + + selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2) + + self.assertFalse(selector2.is_operator_selected("aten::add")) + self.assertTrue(selector2.is_operator_selected("aten::sub")) + self.assertTrue(selector2.is_operator_selected("aten::sub.int")) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + False, + False, + ) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float")) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add")) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int")) + self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub")) + + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + True, + False, + ) + + self.assertTrue(selector_legacy_v1.is_root_operator("aten::add")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add.float") + ) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + False, + True, + ) + + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) + self.assertTrue( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float")) + self.assertTrue( + selector_legacy_v1.is_operator_selected_for_training("aten::add.float") + ) + + def test_operator_combine(self): + op1 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op2 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op3 = SelectiveBuildOperator( + "aten::add", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op4 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=False, + _debug_info=None, + ) + + op5 = combine_operators(op1, op2) + + self.assertTrue(op5.is_root_operator) + self.assertFalse(op5.is_used_for_training) + + op6 = combine_operators(op1, op4) + + self.assertTrue(op6.is_root_operator) + self.assertTrue(op6.is_used_for_training) + + def gen_new_op(): + return combine_operators(op1, op3) + + self.assertRaises(Exception, gen_new_op) + + def test_training_op_fetch(self): + yaml_config = """ +operators: + aten::add.int: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: No + aten::add: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: Yes +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + self.assertTrue(selector.is_operator_selected_for_training("aten::add.int")) + self.assertTrue(selector.is_operator_selected_for_training("aten::add")) + + def test_kernel_dtypes(self): + yaml_config = """ +kernel_metadata: + add_kernel: + - int8 + - int32 + sub_kernel: + - int16 + - int32 + add/sub_kernel: + - float + - complex +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) + + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) + + def test_merge_kernel_dtypes(self): + yaml_config1 = """ +kernel_metadata: + add_kernel: + - int8 + add/sub_kernel: + - float + - complex + - none + mul_kernel: + - int8 +""" + + yaml_config2 = """ +kernel_metadata: + add_kernel: + - int32 + sub_kernel: + - int16 + - int32 + add/sub_kernel: + - float + - complex +""" + + selector1 = SelectiveBuilder.from_yaml_str(yaml_config1) + selector2 = SelectiveBuilder.from_yaml_str(yaml_config2) + + selector = combine_selective_builders(selector1, selector2) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) + + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) + + self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32")) + + def test_all_kernel_dtypes_selected(self): + yaml_config = """ +include_all_non_op_selectives: True +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))