mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Consolidate all python targets in the tools folder (#80408)
Summary: All buck targets that points to caffe2/tools folder are now moved to tools/BUCK. This also eliminates all python library/binary import in pt_defs.bzl, which caused T124308913. Test Plan: CI Differential Revision: D37468313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80408 Approved by: https://github.com/seemethere, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
70e86b4562
commit
b62d39eda0
10
.github/workflows/_buck-build-test.yml
vendored
10
.github/workflows/_buck-build-test.yml
vendored
@ -62,7 +62,15 @@ jobs:
|
||||
command: |
|
||||
sh scripts/buck_setup.sh
|
||||
|
||||
- name: Build C10
|
||||
- name: Build tools
|
||||
run: |
|
||||
buck build tools: --keep-going
|
||||
|
||||
- name: Run tools tests
|
||||
run: |
|
||||
buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test
|
||||
|
||||
- name: Build c10
|
||||
run: |
|
||||
buck build c10:c10
|
||||
|
||||
|
@ -16,6 +16,7 @@ exclude_patterns = [
|
||||
'torch/lib/**',
|
||||
'venv/**',
|
||||
'**/*.pyi',
|
||||
'tools/test/test_selective_build.py',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
@ -145,6 +146,10 @@ include_patterns = [
|
||||
exclude_patterns = [
|
||||
# (linbinyu) copied from internal repo
|
||||
'tools/code_analyzer/gen_operators_yaml.py',
|
||||
'tools/gen_vulkan_spv.py',
|
||||
'tools/test/gen_operators_yaml_test.py',
|
||||
'tools/test/gen_oplist_test.py',
|
||||
'tools/test/test_selective_build.py',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
@ -334,6 +339,7 @@ exclude_patterns = [
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/grep_linter.py',
|
||||
# @lint-ignore TXT2
|
||||
'--pattern= ',
|
||||
'--linter-name=TABS',
|
||||
'--error-name=saw some tabs',
|
||||
@ -565,6 +571,9 @@ include_patterns = [
|
||||
'torch/_decomp/**/*.py',
|
||||
'test/onnx/**/*.py',
|
||||
]
|
||||
exclude_patterns = [
|
||||
'tools/gen_vulkan_spv.py',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/black_linter.py',
|
||||
|
104
buckbuild.bzl
104
buckbuild.bzl
@ -3,8 +3,6 @@
|
||||
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:fb_python_binary.bzl", "fb_python_binary")
|
||||
load("//tools/build_defs:fb_python_library.bzl", "fb_python_library")
|
||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
|
||||
@ -416,7 +414,7 @@ def gen_aten_files(
|
||||
name = name,
|
||||
default_outs = ["."],
|
||||
outs = get_aten_generated_files(backends),
|
||||
cmd = "$(exe {}:gen_aten_bin) ".format(ROOT) + " ".join([
|
||||
cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
|
||||
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
||||
"--install_dir $OUT",
|
||||
] + extra_params),
|
||||
@ -442,7 +440,7 @@ def gen_aten_unboxing_files(
|
||||
name = genrule_name,
|
||||
default_outs = ["."],
|
||||
outs = get_unboxing_generated_files(),
|
||||
cmd = "$(exe {}:gen_unboxing_bin) ".format(ROOT) + " ".join([
|
||||
cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
|
||||
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
||||
"--install_dir $OUT",
|
||||
] + extra_params),
|
||||
@ -515,7 +513,7 @@ def pt_operator_query_codegen(
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.genrule(
|
||||
name = oplist_dir_name,
|
||||
cmd = ("$(exe {}:gen_oplist) ".format(ROOT) +
|
||||
cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
|
||||
"--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
|
||||
("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
|
||||
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
|
||||
@ -620,7 +618,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
|
||||
outs = get_generate_code_bin_outs(),
|
||||
default_outs = ["."],
|
||||
bash = "mkdir -p tools && " +
|
||||
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||
# Mobile build only needs libtorch - skip python bindings for now, except
|
||||
# for ovrsource, which needs Python bindings.
|
||||
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
||||
@ -630,7 +628,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
|
||||
] + extra_params,
|
||||
),
|
||||
cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " +
|
||||
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||
# Mobile build only needs libtorch - skip python bindings for now, except
|
||||
# for ovrsource, which needs Python bindings.
|
||||
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
||||
@ -950,7 +948,7 @@ def define_buck_targets(
|
||||
"torch/csrc/api/include/torch/version.h.in",
|
||||
"version.txt",
|
||||
],
|
||||
cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([
|
||||
cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
|
||||
"--template-path",
|
||||
"torch/csrc/api/include/torch/version.h.in",
|
||||
"--version-path",
|
||||
@ -995,28 +993,13 @@ def define_buck_targets(
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_library(
|
||||
name = "substitutelib",
|
||||
srcs = ["tools/substitute.py"],
|
||||
base_module = "",
|
||||
)
|
||||
|
||||
fb_python_binary(
|
||||
name = "substitute",
|
||||
main_module = "tools.substitute",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":substitutelib",
|
||||
],
|
||||
)
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.genrule(
|
||||
name = "generate_aten_config",
|
||||
srcs = [
|
||||
"aten/src/ATen/Config.h.in",
|
||||
],
|
||||
cmd = "$(exe :substitute) " + " ".join([
|
||||
cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
|
||||
"--install_dir",
|
||||
"$OUT",
|
||||
"--input-file",
|
||||
@ -1072,79 +1055,6 @@ def define_buck_targets(
|
||||
default_outs = ["."],
|
||||
)
|
||||
|
||||
fb_python_binary(
|
||||
name = "gen_aten_bin",
|
||||
main_module = "torchgen.gen",
|
||||
visibility = [
|
||||
"PUBLIC",
|
||||
],
|
||||
deps = [
|
||||
ROOT_PATH + "torchgen:torchgen",
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_binary(
|
||||
name = "gen_unboxing_bin",
|
||||
main_module = "tools.jit.gen_unboxing",
|
||||
visibility = [
|
||||
"PUBLIC",
|
||||
],
|
||||
deps = [
|
||||
ROOT_PATH + "tools/jit:jit",
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_library(
|
||||
name = "gen_oplist_lib",
|
||||
srcs = subdir_glob([
|
||||
("tools/code_analyzer", "gen_oplist.py"),
|
||||
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
|
||||
]),
|
||||
base_module = "",
|
||||
tests = [
|
||||
":gen_oplist_test",
|
||||
],
|
||||
deps = [
|
||||
third_party("pyyaml"),
|
||||
ROOT_PATH + "tools/lite_interpreter:gen_selected_mobile_ops_header",
|
||||
ROOT_PATH + "torchgen:torchgen",
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_library(
|
||||
name = "gen_operators_yaml_lib",
|
||||
srcs = subdir_glob([
|
||||
("tools/code_analyzer", "gen_operators_yaml.py"),
|
||||
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
|
||||
]),
|
||||
base_module = "",
|
||||
tests = [
|
||||
":gen_operators_yaml_test",
|
||||
],
|
||||
deps = [
|
||||
third_party("pyyaml"),
|
||||
ROOT_PATH + "torchgen:torchgen",
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_binary(
|
||||
name = "gen_oplist",
|
||||
main_module = "gen_oplist",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen_oplist_lib",
|
||||
],
|
||||
)
|
||||
|
||||
fb_python_binary(
|
||||
name = "gen_operators_yaml",
|
||||
main_module = "gen_operators_yaml",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen_operators_yaml_lib",
|
||||
],
|
||||
)
|
||||
|
||||
gen_aten_files(
|
||||
name = "gen_aten",
|
||||
extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
|
||||
|
@ -62,7 +62,7 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME)
|
||||
execute_process(
|
||||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}"
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
|
||||
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
||||
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
||||
--glslc-path=${GLSLC_PATH}
|
||||
|
@ -39,7 +39,7 @@ def pt_operator_library(
|
||||
name = name,
|
||||
out = "model_operators.yaml",
|
||||
cmd = (
|
||||
"$(exe {root}:gen_operators_yaml) " +
|
||||
"$(exe {exe}) " +
|
||||
"{optionally_root_ops} " +
|
||||
"{optionally_training_root_ops} " +
|
||||
"--rule_name {rule_name} " +
|
||||
@ -52,7 +52,7 @@ def pt_operator_library(
|
||||
"{optionally_model_traced_backends} " +
|
||||
"{optionally_include_all_operators}"
|
||||
).format(
|
||||
root = "//" if IS_OSS else "//xplat/caffe2",
|
||||
exe = "//tools:gen_operators_yaml" if IS_OSS else "//xplat/caffe2/tools:gen_operators_yaml",
|
||||
rule_name = name,
|
||||
model_name = model_name,
|
||||
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
|
||||
|
263
tools/BUCK.bzl
Normal file
263
tools/BUCK.bzl
Normal file
@ -0,0 +1,263 @@
|
||||
# @lint-ignore-every FBCODEBZLADDLOADS
|
||||
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
||||
|
||||
# shared by internal and OSS BUCK
|
||||
def define_tools_targets(
|
||||
python_binary,
|
||||
python_library,
|
||||
python_test,
|
||||
third_party,
|
||||
torchgen_deps,
|
||||
contacts = []):
|
||||
python_library(
|
||||
name = "substitutelib",
|
||||
srcs = ["substitute.py"],
|
||||
base_module = "",
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "substitute",
|
||||
main_module = "substitute",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":substitutelib",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "jit",
|
||||
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||
srcs = glob([
|
||||
"jit/*.py",
|
||||
"jit/templates/*",
|
||||
]),
|
||||
base_module = "tools",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen_unboxing_bin",
|
||||
main_module = "tools.jit.gen_unboxing",
|
||||
visibility = [
|
||||
"PUBLIC",
|
||||
],
|
||||
deps = [
|
||||
":jit",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen_selected_mobile_ops_header",
|
||||
srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"],
|
||||
base_module = "tools",
|
||||
visibility = ["PUBLIC"],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen_oplist_lib",
|
||||
srcs = subdir_glob([
|
||||
("code_analyzer", "gen_oplist.py"),
|
||||
("code_analyzer", "gen_op_registration_allowlist.py"),
|
||||
]),
|
||||
base_module = "",
|
||||
tests = [
|
||||
":gen_oplist_test",
|
||||
],
|
||||
deps = [
|
||||
":gen_selected_mobile_ops_header",
|
||||
torchgen_deps,
|
||||
third_party("pyyaml"),
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen_oplist",
|
||||
main_module = "gen_oplist",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen_oplist_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen_operators_yaml_lib",
|
||||
srcs = subdir_glob([
|
||||
("code_analyzer", "gen_operators_yaml.py"),
|
||||
("code_analyzer", "gen_op_registration_allowlist.py"),
|
||||
]),
|
||||
base_module = "",
|
||||
tests = [
|
||||
":gen_operators_yaml_test",
|
||||
],
|
||||
deps = [
|
||||
third_party("pyyaml"),
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen_operators_yaml",
|
||||
main_module = "gen_operators_yaml",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen_operators_yaml_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "autograd",
|
||||
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||
srcs = glob(
|
||||
["autograd/*.py"],
|
||||
),
|
||||
base_module = "tools",
|
||||
resources = [
|
||||
"autograd/deprecated.yaml",
|
||||
"autograd/derivatives.yaml",
|
||||
"autograd/templates/ADInplaceOrViewType.cpp",
|
||||
"autograd/templates/Functions.cpp",
|
||||
"autograd/templates/Functions.h",
|
||||
"autograd/templates/TraceType.cpp",
|
||||
"autograd/templates/VariableType.cpp",
|
||||
"autograd/templates/VariableType.h",
|
||||
"autograd/templates/annotated_fn_args.py.in",
|
||||
"autograd/templates/python_enum_tag.cpp",
|
||||
"autograd/templates/python_fft_functions.cpp",
|
||||
"autograd/templates/python_functions.cpp",
|
||||
"autograd/templates/python_functions.h",
|
||||
"autograd/templates/python_linalg_functions.cpp",
|
||||
"autograd/templates/python_nn_functions.cpp",
|
||||
"autograd/templates/python_return_types.cpp",
|
||||
"autograd/templates/python_sparse_functions.cpp",
|
||||
"autograd/templates/python_special_functions.cpp",
|
||||
"autograd/templates/python_torch_functions.cpp",
|
||||
"autograd/templates/python_variable_methods.cpp",
|
||||
"autograd/templates/variable_factories.h",
|
||||
],
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
third_party("pyyaml"),
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "generate_code",
|
||||
srcs = [
|
||||
"setup_helpers/generate_code.py",
|
||||
],
|
||||
base_module = "tools",
|
||||
deps = [
|
||||
":autograd",
|
||||
":jit",
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "generate_code_bin",
|
||||
main_module = "tools.setup_helpers.generate_code",
|
||||
# Windows does not support inplace:
|
||||
# https://github.com/facebook/buck/issues/2161.
|
||||
#
|
||||
# Note that //arvr/mode/embedded/win/clang-aarch64-release sets
|
||||
# its target platform to
|
||||
# ovr_config//platform/embedded:clang-aarch64-linux-release, hence
|
||||
# that is why we are selecting that OS to trigger this behavior.
|
||||
package_style = select({
|
||||
"DEFAULT": "inplace",
|
||||
"ovr_config//os:linux-arm64": "standalone",
|
||||
}),
|
||||
visibility = ["PUBLIC"],
|
||||
# Because Windows does not support inplace packaging, we need to
|
||||
# ensure it is unzipped before executing it, otherwise it will not
|
||||
# be able to find any resources using path manipulation.
|
||||
#
|
||||
# See note above about why the OS is Linux here and not Windows.
|
||||
zip_safe = select({
|
||||
"DEFAULT": True,
|
||||
"ovr_config//os:linux-arm64": False,
|
||||
}),
|
||||
deps = [
|
||||
":generate_code",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen-version-header-lib",
|
||||
srcs = [
|
||||
"setup_helpers/gen_version_header.py",
|
||||
],
|
||||
base_module = "",
|
||||
deps = [],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen-version-header",
|
||||
main_module = "setup_helpers.gen_version_header",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen-version-header-lib",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen_aten_vulkan_spv_lib",
|
||||
srcs = [
|
||||
"gen_vulkan_spv.py",
|
||||
],
|
||||
base_module = "",
|
||||
deps = [
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen_aten_vulkan_spv_bin",
|
||||
main_module = "gen_vulkan_spv",
|
||||
visibility = [
|
||||
"PUBLIC",
|
||||
],
|
||||
deps = [
|
||||
":gen_aten_vulkan_spv_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python_test(
|
||||
name = "selective_build_test",
|
||||
srcs = [
|
||||
"test/test_selective_build.py",
|
||||
],
|
||||
contacts = contacts,
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
torchgen_deps,
|
||||
],
|
||||
)
|
||||
|
||||
python_test(
|
||||
name = "gen_oplist_test",
|
||||
srcs = [
|
||||
"test/gen_oplist_test.py",
|
||||
],
|
||||
contacts = contacts,
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen_oplist_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python_test(
|
||||
name = "gen_operators_yaml_test",
|
||||
srcs = [
|
||||
"test/gen_operators_yaml_test.py",
|
||||
],
|
||||
visibility = ["PUBLIC"],
|
||||
contacts = contacts,
|
||||
deps = [
|
||||
":gen_operators_yaml_lib",
|
||||
],
|
||||
)
|
10
tools/BUCK.oss
Normal file
10
tools/BUCK.oss
Normal file
@ -0,0 +1,10 @@
|
||||
load("//:buckbuild.bzl", "third_party")
|
||||
load(":BUCK.bzl", "define_tools_targets")
|
||||
|
||||
define_tools_targets(
|
||||
python_binary = python_binary,
|
||||
python_library = python_library,
|
||||
python_test = python_test,
|
||||
third_party = third_party,
|
||||
torchgen_deps = "//torchgen:torchgen",
|
||||
)
|
@ -1,35 +0,0 @@
|
||||
python_library(
|
||||
name = "autograd",
|
||||
srcs = glob(
|
||||
["*.py"],
|
||||
),
|
||||
base_module = "tools.autograd",
|
||||
resources = [
|
||||
"deprecated.yaml",
|
||||
"derivatives.yaml",
|
||||
"templates/ADInplaceOrViewType.cpp",
|
||||
"templates/Functions.cpp",
|
||||
"templates/Functions.h",
|
||||
"templates/TraceType.cpp",
|
||||
"templates/VariableType.cpp",
|
||||
"templates/VariableType.h",
|
||||
"templates/annotated_fn_args.py.in",
|
||||
"templates/python_fft_functions.cpp",
|
||||
"templates/python_functions.cpp",
|
||||
"templates/python_functions.h",
|
||||
"templates/python_linalg_functions.cpp",
|
||||
"templates/python_nn_functions.cpp",
|
||||
"templates/python_return_types.cpp",
|
||||
"templates/python_sparse_functions.cpp",
|
||||
"templates/python_special_functions.cpp",
|
||||
"templates/python_torch_functions.cpp",
|
||||
"templates/python_variable_methods.cpp",
|
||||
"templates/variable_factories.h",
|
||||
"templates/python_enum_tag.cpp",
|
||||
],
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
"//third_party:pyyaml",
|
||||
"//torchgen:torchgen",
|
||||
],
|
||||
)
|
@ -1,9 +0,0 @@
|
||||
# Only used for PyTorch open source BUCK build
|
||||
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
|
||||
# @lint-ignore-every FBCODEBZLADDLOADS
|
||||
|
||||
def fb_python_binary(**kwgs):
|
||||
if read_config("pt", "is_oss", "0") == "0":
|
||||
fail("This file is for open source pytorch build. Do not use it in fbsource!")
|
||||
|
||||
python_binary(**kwgs)
|
@ -1,9 +0,0 @@
|
||||
# Only used for PyTorch open source BUCK build
|
||||
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
|
||||
# @lint-ignore-every FBCODEBZLADDLOADS
|
||||
|
||||
def fb_python_library(**kwgs):
|
||||
if read_config("pt", "is_oss", "0") == "0":
|
||||
fail("This file is for open source pytorch build. Do not use it in fbsource!")
|
||||
|
||||
python_library(**kwgs)
|
@ -1,12 +0,0 @@
|
||||
python_library(
|
||||
name = "jit",
|
||||
srcs = glob([
|
||||
"*.py",
|
||||
"templates/*",
|
||||
]),
|
||||
base_module = "tools.jit",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
"//torchgen:torchgen",
|
||||
],
|
||||
)
|
@ -1,6 +0,0 @@
|
||||
python_library(
|
||||
name = "gen_selected_mobile_ops_header",
|
||||
srcs = ["gen_selected_mobile_ops_header.py"],
|
||||
base_module = "tools.lite_interpreter",
|
||||
visibility = ["PUBLIC"],
|
||||
)
|
@ -1,41 +0,0 @@
|
||||
python_library(
|
||||
name = "generate_code",
|
||||
srcs = [
|
||||
"generate_code.py",
|
||||
],
|
||||
base_module = "tools.setup_helpers",
|
||||
deps = [
|
||||
"//tools/autograd:autograd",
|
||||
"//tools/jit:jit",
|
||||
"//torchgen:torchgen",
|
||||
],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "generate_code_bin",
|
||||
main_module = "tools.setup_helpers.generate_code",
|
||||
visibility = ["PUBLIC"],
|
||||
# package_style = "inplace",
|
||||
zip_safe = False,
|
||||
deps = [
|
||||
":generate_code",
|
||||
],
|
||||
)
|
||||
|
||||
python_library(
|
||||
name = "gen-version-header-lib",
|
||||
srcs = [
|
||||
"gen_version_header.py",
|
||||
],
|
||||
base_module = "tools.setup_helpers",
|
||||
deps = [],
|
||||
)
|
||||
|
||||
python_binary(
|
||||
name = "gen-version-header",
|
||||
main_module = "tools.setup_helpers.gen_version_header",
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
":gen-version-header-lib",
|
||||
],
|
||||
)
|
190
tools/test/gen_operators_yaml_test.py
Normal file
190
tools/test/gen_operators_yaml_test.py
Normal file
@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
|
||||
|
||||
|
||||
class GenOperatorsYAMLTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_filter_creation(self):
|
||||
filter_func = make_filter_from_options(
|
||||
model_name="abc",
|
||||
model_versions=["100", "101"],
|
||||
model_assets=None,
|
||||
model_backends=None,
|
||||
)
|
||||
config = [
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 100,
|
||||
"asset": "asset-1",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
"traced_operators": [],
|
||||
},
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 102,
|
||||
"asset": "asset-1",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
},
|
||||
{
|
||||
"model": {
|
||||
"name": "abcd",
|
||||
"version": 100,
|
||||
"asset": "asset-1",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
"traced_operators": [],
|
||||
},
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 101,
|
||||
"asset": "asset-2",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
},
|
||||
]
|
||||
|
||||
filtered_configs = list(filter(filter_func, config))
|
||||
assert (
|
||||
len(filtered_configs) == 2
|
||||
), "Expected 2 elements in filtered_configs, but got {}".format(
|
||||
len(filtered_configs)
|
||||
)
|
||||
|
||||
def test_verification_success(self):
|
||||
filter_func = make_filter_from_options(
|
||||
model_name="abc",
|
||||
model_versions=["100", "101"],
|
||||
model_assets=["asset-1", "asset-2"],
|
||||
model_backends=None,
|
||||
)
|
||||
config = [
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 100,
|
||||
"asset": "asset-1",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
"traced_operators": [],
|
||||
},
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 101,
|
||||
"asset": "asset-2",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
},
|
||||
]
|
||||
filtered_configs = list(filter(filter_func, config))
|
||||
try:
|
||||
verify_all_specified_present(
|
||||
model_assets=["asset-1", "asset-2"],
|
||||
model_versions=["100", "101"],
|
||||
selected_models_yaml=filtered_configs,
|
||||
rule_name="test",
|
||||
model_name="abc",
|
||||
new_style_rule=True,
|
||||
)
|
||||
except Exception:
|
||||
self.fail(
|
||||
"expected verify_all_specified_present to succeed instead it raised an exception"
|
||||
)
|
||||
|
||||
def test_verification_fail(self):
|
||||
config = [
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 100,
|
||||
"asset": "asset-1",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
"traced_operators": [],
|
||||
},
|
||||
{
|
||||
"model": {
|
||||
"name": "abc",
|
||||
"version": 101,
|
||||
"asset": "asset-2",
|
||||
"backend": "CPU",
|
||||
},
|
||||
"root_operators": [],
|
||||
},
|
||||
]
|
||||
|
||||
good_assets = ["asset-1", "asset-2"]
|
||||
good_versions = ["100", "101"]
|
||||
good_name = "abc"
|
||||
|
||||
# Test bad asset
|
||||
filter_func_bad_asset = make_filter_from_options(
|
||||
model_name=good_name,
|
||||
model_versions=good_versions,
|
||||
model_assets=["asset-1", "asset-2", "asset-3"],
|
||||
model_backends=None,
|
||||
)
|
||||
filtered_configs_asset = list(filter(filter_func_bad_asset, config))
|
||||
with self.assertRaises(RuntimeError):
|
||||
verify_all_specified_present(
|
||||
model_assets=["asset-1", "asset-2", "asset-3"],
|
||||
model_versions=good_versions,
|
||||
selected_models_yaml=filtered_configs_asset,
|
||||
rule_name="test",
|
||||
model_name=good_name,
|
||||
new_style_rule=True,
|
||||
)
|
||||
|
||||
# Test bad version
|
||||
filter_func_bad_version = make_filter_from_options(
|
||||
model_name=good_name,
|
||||
model_versions=["100", "101", "102"],
|
||||
model_assets=good_assets,
|
||||
model_backends=None,
|
||||
)
|
||||
filtered_configs_version = list(filter(filter_func_bad_version, config))
|
||||
with self.assertRaises(RuntimeError):
|
||||
verify_all_specified_present(
|
||||
model_assets=good_assets,
|
||||
model_versions=["100", "101", "102"],
|
||||
selected_models_yaml=filtered_configs_version,
|
||||
rule_name="test",
|
||||
model_name=good_name,
|
||||
new_style_rule=True,
|
||||
)
|
||||
|
||||
# Test bad name
|
||||
filter_func_bad_name = make_filter_from_options(
|
||||
model_name="abcd",
|
||||
model_versions=good_versions,
|
||||
model_assets=good_assets,
|
||||
model_backends=None,
|
||||
)
|
||||
filtered_configs_name = list(filter(filter_func_bad_name, config))
|
||||
with self.assertRaises(RuntimeError):
|
||||
verify_all_specified_present(
|
||||
model_assets=good_assets,
|
||||
model_versions=good_versions,
|
||||
selected_models_yaml=filtered_configs_name,
|
||||
rule_name="test",
|
||||
model_name="abcd",
|
||||
new_style_rule=True,
|
||||
)
|
35
tools/test/gen_oplist_test.py
Normal file
35
tools/test/gen_oplist_test.py
Normal file
@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from gen_oplist import throw_if_any_op_includes_overloads
|
||||
|
||||
|
||||
class GenOplistTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_throw_if_any_op_includes_overloads(self):
|
||||
selective_builder = MagicMock()
|
||||
selective_builder.operators = MagicMock()
|
||||
selective_builder.operators.items.return_value = [
|
||||
("op1", MagicMock(include_all_overloads=True)),
|
||||
("op2", MagicMock(include_all_overloads=False)),
|
||||
("op3", MagicMock(include_all_overloads=True)),
|
||||
]
|
||||
|
||||
self.assertRaises(
|
||||
Exception, throw_if_any_op_includes_overloads, selective_builder
|
||||
)
|
||||
|
||||
selective_builder.operators.items.return_value = [
|
||||
("op1", MagicMock(include_all_overloads=False)),
|
||||
("op2", MagicMock(include_all_overloads=False)),
|
||||
("op3", MagicMock(include_all_overloads=False)),
|
||||
]
|
||||
|
||||
# Here we do not expect it to throw an exception since none of the ops
|
||||
# include all overloads.
|
||||
throw_if_any_op_includes_overloads(selective_builder)
|
281
tools/test/test_selective_build.py
Normal file
281
tools/test/test_selective_build.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from torchgen.selective_build.operator import *
|
||||
from torchgen.selective_build.selector import (
|
||||
combine_selective_builders,
|
||||
SelectiveBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestSelectiveBuild(unittest.TestCase):
|
||||
def test_selective_build_operator(self):
|
||||
op = SelectiveBuildOperator(
|
||||
"aten::add.int",
|
||||
is_root_operator=True,
|
||||
is_used_for_training=False,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
self.assertTrue(op.is_root_operator)
|
||||
self.assertFalse(op.is_used_for_training)
|
||||
self.assertFalse(op.include_all_overloads)
|
||||
|
||||
def test_selector_factory(self):
|
||||
yaml_config_v1 = """
|
||||
debug_info:
|
||||
- model1@v100
|
||||
- model2@v51
|
||||
operators:
|
||||
aten::add:
|
||||
is_used_for_training: No
|
||||
is_root_operator: Yes
|
||||
include_all_overloads: Yes
|
||||
aten::add.int:
|
||||
is_used_for_training: Yes
|
||||
is_root_operator: No
|
||||
include_all_overloads: No
|
||||
aten::mul.int:
|
||||
is_used_for_training: Yes
|
||||
is_root_operator: No
|
||||
include_all_overloads: No
|
||||
"""
|
||||
|
||||
yaml_config_v2 = """
|
||||
debug_info:
|
||||
- model1@v100
|
||||
- model2@v51
|
||||
operators:
|
||||
aten::sub:
|
||||
is_used_for_training: No
|
||||
is_root_operator: Yes
|
||||
include_all_overloads: No
|
||||
debug_info:
|
||||
- model1@v100
|
||||
aten::sub.int:
|
||||
is_used_for_training: Yes
|
||||
is_root_operator: No
|
||||
include_all_overloads: No
|
||||
"""
|
||||
|
||||
yaml_config_all = "include_all_operators: Yes"
|
||||
|
||||
yaml_config_invalid = "invalid:"
|
||||
|
||||
selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
|
||||
|
||||
self.assertTrue(selector1.is_operator_selected("aten::add"))
|
||||
self.assertTrue(selector1.is_operator_selected("aten::add.int"))
|
||||
# Overload name is not used for checking in v1.
|
||||
self.assertTrue(selector1.is_operator_selected("aten::add.float"))
|
||||
|
||||
def gen():
|
||||
return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
|
||||
|
||||
self.assertRaises(Exception, gen)
|
||||
|
||||
selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
|
||||
|
||||
self.assertTrue(selector_all.is_operator_selected("aten::add"))
|
||||
self.assertTrue(selector_all.is_operator_selected("aten::sub"))
|
||||
self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
|
||||
self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||
|
||||
selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
|
||||
|
||||
self.assertFalse(selector2.is_operator_selected("aten::add"))
|
||||
self.assertTrue(selector2.is_operator_selected("aten::sub"))
|
||||
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
False,
|
||||
False,
|
||||
)
|
||||
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
|
||||
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
|
||||
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
|
||||
self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
|
||||
|
||||
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
|
||||
self.assertFalse(
|
||||
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||
)
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
True,
|
||||
False,
|
||||
)
|
||||
|
||||
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
|
||||
self.assertFalse(
|
||||
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||
)
|
||||
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
|
||||
self.assertFalse(
|
||||
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
|
||||
)
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
False,
|
||||
True,
|
||||
)
|
||||
|
||||
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
|
||||
self.assertTrue(
|
||||
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||
)
|
||||
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
|
||||
self.assertTrue(
|
||||
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
|
||||
)
|
||||
|
||||
def test_operator_combine(self):
|
||||
op1 = SelectiveBuildOperator(
|
||||
"aten::add.int",
|
||||
is_root_operator=True,
|
||||
is_used_for_training=False,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
op2 = SelectiveBuildOperator(
|
||||
"aten::add.int",
|
||||
is_root_operator=False,
|
||||
is_used_for_training=False,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
op3 = SelectiveBuildOperator(
|
||||
"aten::add",
|
||||
is_root_operator=True,
|
||||
is_used_for_training=False,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
op4 = SelectiveBuildOperator(
|
||||
"aten::add.int",
|
||||
is_root_operator=True,
|
||||
is_used_for_training=True,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
|
||||
op5 = combine_operators(op1, op2)
|
||||
|
||||
self.assertTrue(op5.is_root_operator)
|
||||
self.assertFalse(op5.is_used_for_training)
|
||||
|
||||
op6 = combine_operators(op1, op4)
|
||||
|
||||
self.assertTrue(op6.is_root_operator)
|
||||
self.assertTrue(op6.is_used_for_training)
|
||||
|
||||
def gen_new_op():
|
||||
return combine_operators(op1, op3)
|
||||
|
||||
self.assertRaises(Exception, gen_new_op)
|
||||
|
||||
def test_training_op_fetch(self):
|
||||
yaml_config = """
|
||||
operators:
|
||||
aten::add.int:
|
||||
is_used_for_training: No
|
||||
is_root_operator: Yes
|
||||
include_all_overloads: No
|
||||
aten::add:
|
||||
is_used_for_training: Yes
|
||||
is_root_operator: No
|
||||
include_all_overloads: Yes
|
||||
"""
|
||||
|
||||
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
|
||||
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
|
||||
|
||||
def test_kernel_dtypes(self):
|
||||
yaml_config = """
|
||||
kernel_metadata:
|
||||
add_kernel:
|
||||
- int8
|
||||
- int32
|
||||
sub_kernel:
|
||||
- int16
|
||||
- int32
|
||||
add/sub_kernel:
|
||||
- float
|
||||
- complex
|
||||
"""
|
||||
|
||||
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
|
||||
|
||||
def test_merge_kernel_dtypes(self):
|
||||
yaml_config1 = """
|
||||
kernel_metadata:
|
||||
add_kernel:
|
||||
- int8
|
||||
add/sub_kernel:
|
||||
- float
|
||||
- complex
|
||||
- none
|
||||
mul_kernel:
|
||||
- int8
|
||||
"""
|
||||
|
||||
yaml_config2 = """
|
||||
kernel_metadata:
|
||||
add_kernel:
|
||||
- int32
|
||||
sub_kernel:
|
||||
- int16
|
||||
- int32
|
||||
add/sub_kernel:
|
||||
- float
|
||||
- complex
|
||||
"""
|
||||
|
||||
selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
|
||||
selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
|
||||
|
||||
selector = combine_selective_builders(selector1, selector2)
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
|
||||
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
|
||||
|
||||
def test_all_kernel_dtypes_selected(self):
|
||||
yaml_config = """
|
||||
include_all_non_op_selectives: True
|
||||
"""
|
||||
|
||||
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
Reference in New Issue
Block a user