Consolidate all python targets in the tools folder (#80408)

Summary:
All buck targets that points to caffe2/tools folder are now moved to tools/BUCK.
This also eliminates all python library/binary import in pt_defs.bzl, which caused T124308913.

Test Plan: CI

Differential Revision: D37468313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80408
Approved by: https://github.com/seemethere, https://github.com/malfet
This commit is contained in:
Linbin Yu
2022-06-29 23:27:47 +00:00
committed by PyTorch MergeBot
parent 70e86b4562
commit b62d39eda0
17 changed files with 807 additions and 213 deletions

View File

@ -62,7 +62,15 @@ jobs:
command: |
sh scripts/buck_setup.sh
- name: Build C10
- name: Build tools
run: |
buck build tools: --keep-going
- name: Run tools tests
run: |
buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test
- name: Build c10
run: |
buck build c10:c10

View File

@ -16,6 +16,7 @@ exclude_patterns = [
'torch/lib/**',
'venv/**',
'**/*.pyi',
'tools/test/test_selective_build.py',
]
command = [
'python3',
@ -145,6 +146,10 @@ include_patterns = [
exclude_patterns = [
# (linbinyu) copied from internal repo
'tools/code_analyzer/gen_operators_yaml.py',
'tools/gen_vulkan_spv.py',
'tools/test/gen_operators_yaml_test.py',
'tools/test/gen_oplist_test.py',
'tools/test/test_selective_build.py',
]
command = [
'python3',
@ -334,6 +339,7 @@ exclude_patterns = [
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
# @lint-ignore TXT2
'--pattern= ',
'--linter-name=TABS',
'--error-name=saw some tabs',
@ -565,6 +571,9 @@ include_patterns = [
'torch/_decomp/**/*.py',
'test/onnx/**/*.py',
]
exclude_patterns = [
'tools/gen_vulkan_spv.py',
]
command = [
'python3',
'tools/linter/adapters/black_linter.py',

View File

@ -3,8 +3,6 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_python_binary.bzl", "fb_python_binary")
load("//tools/build_defs:fb_python_library.bzl", "fb_python_library")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
@ -416,7 +414,7 @@ def gen_aten_files(
name = name,
default_outs = ["."],
outs = get_aten_generated_files(backends),
cmd = "$(exe {}:gen_aten_bin) ".format(ROOT) + " ".join([
cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
"--install_dir $OUT",
] + extra_params),
@ -442,7 +440,7 @@ def gen_aten_unboxing_files(
name = genrule_name,
default_outs = ["."],
outs = get_unboxing_generated_files(),
cmd = "$(exe {}:gen_unboxing_bin) ".format(ROOT) + " ".join([
cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
"--install_dir $OUT",
] + extra_params),
@ -515,7 +513,7 @@ def pt_operator_query_codegen(
# @lint-ignore BUCKLINT
fb_native.genrule(
name = oplist_dir_name,
cmd = ("$(exe {}:gen_oplist) ".format(ROOT) +
cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
"--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
@ -620,7 +618,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
outs = get_generate_code_bin_outs(),
default_outs = ["."],
bash = "mkdir -p tools && " +
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
# Mobile build only needs libtorch - skip python bindings for now, except
# for ovrsource, which needs Python bindings.
(["--subset libtorch"] if not is_arvr_mode() else []) + [
@ -630,7 +628,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
] + extra_params,
),
cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " +
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
# Mobile build only needs libtorch - skip python bindings for now, except
# for ovrsource, which needs Python bindings.
(["--subset libtorch"] if not is_arvr_mode() else []) + [
@ -950,7 +948,7 @@ def define_buck_targets(
"torch/csrc/api/include/torch/version.h.in",
"version.txt",
],
cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([
cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
"--template-path",
"torch/csrc/api/include/torch/version.h.in",
"--version-path",
@ -995,28 +993,13 @@ def define_buck_targets(
],
)
fb_python_library(
name = "substitutelib",
srcs = ["tools/substitute.py"],
base_module = "",
)
fb_python_binary(
name = "substitute",
main_module = "tools.substitute",
visibility = ["PUBLIC"],
deps = [
":substitutelib",
],
)
# @lint-ignore BUCKLINT
fb_native.genrule(
name = "generate_aten_config",
srcs = [
"aten/src/ATen/Config.h.in",
],
cmd = "$(exe :substitute) " + " ".join([
cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
"--install_dir",
"$OUT",
"--input-file",
@ -1072,79 +1055,6 @@ def define_buck_targets(
default_outs = ["."],
)
fb_python_binary(
name = "gen_aten_bin",
main_module = "torchgen.gen",
visibility = [
"PUBLIC",
],
deps = [
ROOT_PATH + "torchgen:torchgen",
],
)
fb_python_binary(
name = "gen_unboxing_bin",
main_module = "tools.jit.gen_unboxing",
visibility = [
"PUBLIC",
],
deps = [
ROOT_PATH + "tools/jit:jit",
],
)
fb_python_library(
name = "gen_oplist_lib",
srcs = subdir_glob([
("tools/code_analyzer", "gen_oplist.py"),
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
]),
base_module = "",
tests = [
":gen_oplist_test",
],
deps = [
third_party("pyyaml"),
ROOT_PATH + "tools/lite_interpreter:gen_selected_mobile_ops_header",
ROOT_PATH + "torchgen:torchgen",
],
)
fb_python_library(
name = "gen_operators_yaml_lib",
srcs = subdir_glob([
("tools/code_analyzer", "gen_operators_yaml.py"),
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
]),
base_module = "",
tests = [
":gen_operators_yaml_test",
],
deps = [
third_party("pyyaml"),
ROOT_PATH + "torchgen:torchgen",
],
)
fb_python_binary(
name = "gen_oplist",
main_module = "gen_oplist",
visibility = ["PUBLIC"],
deps = [
":gen_oplist_lib",
],
)
fb_python_binary(
name = "gen_operators_yaml",
main_module = "gen_operators_yaml",
visibility = ["PUBLIC"],
deps = [
":gen_operators_yaml_lib",
],
)
gen_aten_files(
name = "gen_aten",
extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),

View File

@ -62,7 +62,7 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME)
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
--output-path ${VULKAN_GEN_OUTPUT_PATH}
--glslc-path=${GLSLC_PATH}

View File

@ -39,7 +39,7 @@ def pt_operator_library(
name = name,
out = "model_operators.yaml",
cmd = (
"$(exe {root}:gen_operators_yaml) " +
"$(exe {exe}) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
@ -52,7 +52,7 @@ def pt_operator_library(
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
root = "//" if IS_OSS else "//xplat/caffe2",
exe = "//tools:gen_operators_yaml" if IS_OSS else "//xplat/caffe2/tools:gen_operators_yaml",
rule_name = name,
model_name = model_name,
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",

263
tools/BUCK.bzl Normal file
View File

@ -0,0 +1,263 @@
# @lint-ignore-every FBCODEBZLADDLOADS
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
# shared by internal and OSS BUCK
def define_tools_targets(
python_binary,
python_library,
python_test,
third_party,
torchgen_deps,
contacts = []):
python_library(
name = "substitutelib",
srcs = ["substitute.py"],
base_module = "",
)
python_binary(
name = "substitute",
main_module = "substitute",
visibility = ["PUBLIC"],
deps = [
":substitutelib",
],
)
python_library(
name = "jit",
# @lint-ignore BUCKRESTRICTEDSYNTAX
srcs = glob([
"jit/*.py",
"jit/templates/*",
]),
base_module = "tools",
visibility = ["PUBLIC"],
deps = [
torchgen_deps,
],
)
python_binary(
name = "gen_unboxing_bin",
main_module = "tools.jit.gen_unboxing",
visibility = [
"PUBLIC",
],
deps = [
":jit",
],
)
python_library(
name = "gen_selected_mobile_ops_header",
srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"],
base_module = "tools",
visibility = ["PUBLIC"],
)
python_library(
name = "gen_oplist_lib",
srcs = subdir_glob([
("code_analyzer", "gen_oplist.py"),
("code_analyzer", "gen_op_registration_allowlist.py"),
]),
base_module = "",
tests = [
":gen_oplist_test",
],
deps = [
":gen_selected_mobile_ops_header",
torchgen_deps,
third_party("pyyaml"),
],
)
python_binary(
name = "gen_oplist",
main_module = "gen_oplist",
visibility = ["PUBLIC"],
deps = [
":gen_oplist_lib",
],
)
python_library(
name = "gen_operators_yaml_lib",
srcs = subdir_glob([
("code_analyzer", "gen_operators_yaml.py"),
("code_analyzer", "gen_op_registration_allowlist.py"),
]),
base_module = "",
tests = [
":gen_operators_yaml_test",
],
deps = [
third_party("pyyaml"),
torchgen_deps,
],
)
python_binary(
name = "gen_operators_yaml",
main_module = "gen_operators_yaml",
visibility = ["PUBLIC"],
deps = [
":gen_operators_yaml_lib",
],
)
python_library(
name = "autograd",
# @lint-ignore BUCKRESTRICTEDSYNTAX
srcs = glob(
["autograd/*.py"],
),
base_module = "tools",
resources = [
"autograd/deprecated.yaml",
"autograd/derivatives.yaml",
"autograd/templates/ADInplaceOrViewType.cpp",
"autograd/templates/Functions.cpp",
"autograd/templates/Functions.h",
"autograd/templates/TraceType.cpp",
"autograd/templates/VariableType.cpp",
"autograd/templates/VariableType.h",
"autograd/templates/annotated_fn_args.py.in",
"autograd/templates/python_enum_tag.cpp",
"autograd/templates/python_fft_functions.cpp",
"autograd/templates/python_functions.cpp",
"autograd/templates/python_functions.h",
"autograd/templates/python_linalg_functions.cpp",
"autograd/templates/python_nn_functions.cpp",
"autograd/templates/python_return_types.cpp",
"autograd/templates/python_sparse_functions.cpp",
"autograd/templates/python_special_functions.cpp",
"autograd/templates/python_torch_functions.cpp",
"autograd/templates/python_variable_methods.cpp",
"autograd/templates/variable_factories.h",
],
visibility = ["PUBLIC"],
deps = [
third_party("pyyaml"),
torchgen_deps,
],
)
python_library(
name = "generate_code",
srcs = [
"setup_helpers/generate_code.py",
],
base_module = "tools",
deps = [
":autograd",
":jit",
torchgen_deps,
],
)
python_binary(
name = "generate_code_bin",
main_module = "tools.setup_helpers.generate_code",
# Windows does not support inplace:
# https://github.com/facebook/buck/issues/2161.
#
# Note that //arvr/mode/embedded/win/clang-aarch64-release sets
# its target platform to
# ovr_config//platform/embedded:clang-aarch64-linux-release, hence
# that is why we are selecting that OS to trigger this behavior.
package_style = select({
"DEFAULT": "inplace",
"ovr_config//os:linux-arm64": "standalone",
}),
visibility = ["PUBLIC"],
# Because Windows does not support inplace packaging, we need to
# ensure it is unzipped before executing it, otherwise it will not
# be able to find any resources using path manipulation.
#
# See note above about why the OS is Linux here and not Windows.
zip_safe = select({
"DEFAULT": True,
"ovr_config//os:linux-arm64": False,
}),
deps = [
":generate_code",
],
)
python_library(
name = "gen-version-header-lib",
srcs = [
"setup_helpers/gen_version_header.py",
],
base_module = "",
deps = [],
)
python_binary(
name = "gen-version-header",
main_module = "setup_helpers.gen_version_header",
visibility = ["PUBLIC"],
deps = [
":gen-version-header-lib",
],
)
python_library(
name = "gen_aten_vulkan_spv_lib",
srcs = [
"gen_vulkan_spv.py",
],
base_module = "",
deps = [
torchgen_deps,
],
)
python_binary(
name = "gen_aten_vulkan_spv_bin",
main_module = "gen_vulkan_spv",
visibility = [
"PUBLIC",
],
deps = [
":gen_aten_vulkan_spv_lib",
],
)
python_test(
name = "selective_build_test",
srcs = [
"test/test_selective_build.py",
],
contacts = contacts,
visibility = ["PUBLIC"],
deps = [
torchgen_deps,
],
)
python_test(
name = "gen_oplist_test",
srcs = [
"test/gen_oplist_test.py",
],
contacts = contacts,
visibility = ["PUBLIC"],
deps = [
":gen_oplist_lib",
],
)
python_test(
name = "gen_operators_yaml_test",
srcs = [
"test/gen_operators_yaml_test.py",
],
visibility = ["PUBLIC"],
contacts = contacts,
deps = [
":gen_operators_yaml_lib",
],
)

10
tools/BUCK.oss Normal file
View File

@ -0,0 +1,10 @@
load("//:buckbuild.bzl", "third_party")
load(":BUCK.bzl", "define_tools_targets")
define_tools_targets(
python_binary = python_binary,
python_library = python_library,
python_test = python_test,
third_party = third_party,
torchgen_deps = "//torchgen:torchgen",
)

View File

@ -1,35 +0,0 @@
python_library(
name = "autograd",
srcs = glob(
["*.py"],
),
base_module = "tools.autograd",
resources = [
"deprecated.yaml",
"derivatives.yaml",
"templates/ADInplaceOrViewType.cpp",
"templates/Functions.cpp",
"templates/Functions.h",
"templates/TraceType.cpp",
"templates/VariableType.cpp",
"templates/VariableType.h",
"templates/annotated_fn_args.py.in",
"templates/python_fft_functions.cpp",
"templates/python_functions.cpp",
"templates/python_functions.h",
"templates/python_linalg_functions.cpp",
"templates/python_nn_functions.cpp",
"templates/python_return_types.cpp",
"templates/python_sparse_functions.cpp",
"templates/python_special_functions.cpp",
"templates/python_torch_functions.cpp",
"templates/python_variable_methods.cpp",
"templates/variable_factories.h",
"templates/python_enum_tag.cpp",
],
visibility = ["PUBLIC"],
deps = [
"//third_party:pyyaml",
"//torchgen:torchgen",
],
)

View File

@ -1,9 +0,0 @@
# Only used for PyTorch open source BUCK build
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
# @lint-ignore-every FBCODEBZLADDLOADS
def fb_python_binary(**kwgs):
if read_config("pt", "is_oss", "0") == "0":
fail("This file is for open source pytorch build. Do not use it in fbsource!")
python_binary(**kwgs)

View File

@ -1,9 +0,0 @@
# Only used for PyTorch open source BUCK build
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
# @lint-ignore-every FBCODEBZLADDLOADS
def fb_python_library(**kwgs):
if read_config("pt", "is_oss", "0") == "0":
fail("This file is for open source pytorch build. Do not use it in fbsource!")
python_library(**kwgs)

View File

@ -1,12 +0,0 @@
python_library(
name = "jit",
srcs = glob([
"*.py",
"templates/*",
]),
base_module = "tools.jit",
visibility = ["PUBLIC"],
deps = [
"//torchgen:torchgen",
],
)

View File

@ -1,6 +0,0 @@
python_library(
name = "gen_selected_mobile_ops_header",
srcs = ["gen_selected_mobile_ops_header.py"],
base_module = "tools.lite_interpreter",
visibility = ["PUBLIC"],
)

View File

@ -1,41 +0,0 @@
python_library(
name = "generate_code",
srcs = [
"generate_code.py",
],
base_module = "tools.setup_helpers",
deps = [
"//tools/autograd:autograd",
"//tools/jit:jit",
"//torchgen:torchgen",
],
)
python_binary(
name = "generate_code_bin",
main_module = "tools.setup_helpers.generate_code",
visibility = ["PUBLIC"],
# package_style = "inplace",
zip_safe = False,
deps = [
":generate_code",
],
)
python_library(
name = "gen-version-header-lib",
srcs = [
"gen_version_header.py",
],
base_module = "tools.setup_helpers",
deps = [],
)
python_binary(
name = "gen-version-header",
main_module = "tools.setup_helpers.gen_version_header",
visibility = ["PUBLIC"],
deps = [
":gen-version-header-lib",
],
)

View File

@ -0,0 +1,190 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import unittest
from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
class GenOperatorsYAMLTest(unittest.TestCase):
def setUp(self):
pass
def test_filter_creation(self):
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
model_assets=None,
model_backends=None,
)
config = [
{
"model": {
"name": "abc",
"version": 100,
"asset": "asset-1",
"backend": "CPU",
},
"root_operators": [],
"traced_operators": [],
},
{
"model": {
"name": "abc",
"version": 102,
"asset": "asset-1",
"backend": "CPU",
},
"root_operators": [],
},
{
"model": {
"name": "abcd",
"version": 100,
"asset": "asset-1",
"backend": "CPU",
},
"root_operators": [],
"traced_operators": [],
},
{
"model": {
"name": "abc",
"version": 101,
"asset": "asset-2",
"backend": "CPU",
},
"root_operators": [],
},
]
filtered_configs = list(filter(filter_func, config))
assert (
len(filtered_configs) == 2
), "Expected 2 elements in filtered_configs, but got {}".format(
len(filtered_configs)
)
def test_verification_success(self):
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
model_assets=["asset-1", "asset-2"],
model_backends=None,
)
config = [
{
"model": {
"name": "abc",
"version": 100,
"asset": "asset-1",
"backend": "CPU",
},
"root_operators": [],
"traced_operators": [],
},
{
"model": {
"name": "abc",
"version": 101,
"asset": "asset-2",
"backend": "CPU",
},
"root_operators": [],
},
]
filtered_configs = list(filter(filter_func, config))
try:
verify_all_specified_present(
model_assets=["asset-1", "asset-2"],
model_versions=["100", "101"],
selected_models_yaml=filtered_configs,
rule_name="test",
model_name="abc",
new_style_rule=True,
)
except Exception:
self.fail(
"expected verify_all_specified_present to succeed instead it raised an exception"
)
def test_verification_fail(self):
config = [
{
"model": {
"name": "abc",
"version": 100,
"asset": "asset-1",
"backend": "CPU",
},
"root_operators": [],
"traced_operators": [],
},
{
"model": {
"name": "abc",
"version": 101,
"asset": "asset-2",
"backend": "CPU",
},
"root_operators": [],
},
]
good_assets = ["asset-1", "asset-2"]
good_versions = ["100", "101"]
good_name = "abc"
# Test bad asset
filter_func_bad_asset = make_filter_from_options(
model_name=good_name,
model_versions=good_versions,
model_assets=["asset-1", "asset-2", "asset-3"],
model_backends=None,
)
filtered_configs_asset = list(filter(filter_func_bad_asset, config))
with self.assertRaises(RuntimeError):
verify_all_specified_present(
model_assets=["asset-1", "asset-2", "asset-3"],
model_versions=good_versions,
selected_models_yaml=filtered_configs_asset,
rule_name="test",
model_name=good_name,
new_style_rule=True,
)
# Test bad version
filter_func_bad_version = make_filter_from_options(
model_name=good_name,
model_versions=["100", "101", "102"],
model_assets=good_assets,
model_backends=None,
)
filtered_configs_version = list(filter(filter_func_bad_version, config))
with self.assertRaises(RuntimeError):
verify_all_specified_present(
model_assets=good_assets,
model_versions=["100", "101", "102"],
selected_models_yaml=filtered_configs_version,
rule_name="test",
model_name=good_name,
new_style_rule=True,
)
# Test bad name
filter_func_bad_name = make_filter_from_options(
model_name="abcd",
model_versions=good_versions,
model_assets=good_assets,
model_backends=None,
)
filtered_configs_name = list(filter(filter_func_bad_name, config))
with self.assertRaises(RuntimeError):
verify_all_specified_present(
model_assets=good_assets,
model_versions=good_versions,
selected_models_yaml=filtered_configs_name,
rule_name="test",
model_name="abcd",
new_style_rule=True,
)

View File

@ -0,0 +1,35 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import unittest
from unittest.mock import MagicMock
from gen_oplist import throw_if_any_op_includes_overloads
class GenOplistTest(unittest.TestCase):
def setUp(self):
pass
def test_throw_if_any_op_includes_overloads(self):
selective_builder = MagicMock()
selective_builder.operators = MagicMock()
selective_builder.operators.items.return_value = [
("op1", MagicMock(include_all_overloads=True)),
("op2", MagicMock(include_all_overloads=False)),
("op3", MagicMock(include_all_overloads=True)),
]
self.assertRaises(
Exception, throw_if_any_op_includes_overloads, selective_builder
)
selective_builder.operators.items.return_value = [
("op1", MagicMock(include_all_overloads=False)),
("op2", MagicMock(include_all_overloads=False)),
("op3", MagicMock(include_all_overloads=False)),
]
# Here we do not expect it to throw an exception since none of the ops
# include all overloads.
throw_if_any_op_includes_overloads(selective_builder)

View File

@ -0,0 +1,281 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
from torchgen.selective_build.operator import *
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
)
class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self):
op = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
self.assertTrue(op.is_root_operator)
self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads)
def test_selector_factory(self):
yaml_config_v1 = """
debug_info:
- model1@v100
- model2@v51
operators:
aten::add:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: Yes
aten::add.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
aten::mul.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
"""
yaml_config_v2 = """
debug_info:
- model1@v100
- model2@v51
operators:
aten::sub:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: No
debug_info:
- model1@v100
aten::sub.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
"""
yaml_config_all = "include_all_operators: Yes"
yaml_config_invalid = "invalid:"
selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
self.assertTrue(selector1.is_operator_selected("aten::add"))
self.assertTrue(selector1.is_operator_selected("aten::add.int"))
# Overload name is not used for checking in v1.
self.assertTrue(selector1.is_operator_selected("aten::add.float"))
def gen():
return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
self.assertRaises(Exception, gen)
selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
self.assertTrue(selector_all.is_operator_selected("aten::add"))
self.assertTrue(selector_all.is_operator_selected("aten::sub"))
self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
self.assertFalse(selector2.is_operator_selected("aten::add"))
self.assertTrue(selector2.is_operator_selected("aten::sub"))
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
False,
False,
)
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
True,
False,
)
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
False,
True,
)
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
self.assertTrue(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
self.assertTrue(
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
def test_operator_combine(self):
op1 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op2 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=False,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op3 = SelectiveBuildOperator(
"aten::add",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op4 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=True,
include_all_overloads=False,
_debug_info=None,
)
op5 = combine_operators(op1, op2)
self.assertTrue(op5.is_root_operator)
self.assertFalse(op5.is_used_for_training)
op6 = combine_operators(op1, op4)
self.assertTrue(op6.is_root_operator)
self.assertTrue(op6.is_used_for_training)
def gen_new_op():
return combine_operators(op1, op3)
self.assertRaises(Exception, gen_new_op)
def test_training_op_fetch(self):
yaml_config = """
operators:
aten::add.int:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: No
aten::add:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: Yes
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
def test_kernel_dtypes(self):
yaml_config = """
kernel_metadata:
add_kernel:
- int8
- int32
sub_kernel:
- int16
- int32
add/sub_kernel:
- float
- complex
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
def test_merge_kernel_dtypes(self):
yaml_config1 = """
kernel_metadata:
add_kernel:
- int8
add/sub_kernel:
- float
- complex
- none
mul_kernel:
- int8
"""
yaml_config2 = """
kernel_metadata:
add_kernel:
- int32
sub_kernel:
- int16
- int32
add/sub_kernel:
- float
- complex
"""
selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
selector = combine_selective_builders(selector1, selector2)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
def test_all_kernel_dtypes_selected(self):
yaml_config = """
include_all_non_op_selectives: True
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))