[4] move pt_operator_library to shared BUCK file (#80170)

Summary:
Move pt_operator_library to pt_ops.bzl and make it shared with OSS BUCK build

This will replace D36912042. I will update all load statements in future diffs.

Test Plan: sandcaslte, OSS CI

Differential Revision: D37390060

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170
Approved by: https://github.com/JacobSzwejbka
This commit is contained in:
Linbin Yu
2022-06-24 21:51:20 +00:00
committed by PyTorch MergeBot
parent bda04e9f5e
commit e98e7fe428
5 changed files with 126 additions and 208 deletions

View File

@ -62,29 +62,17 @@ jobs:
command: | command: |
sh scripts/buck_setup.sh sh scripts/buck_setup.sh
- name: Build glog
run: |
buck build third_party:glog
- name: Build C10 - name: Build C10
run: | run: |
buck build c10:c10 buck build c10:c10
- name: Build cpuinfo
run: |
buck build third_party:cpuinfo
- name: Build pthreadpool
run: |
buck build third_party:pthreadpool
- name: Build XNNPACK - name: Build XNNPACK
run: | run: |
buck build third_party:XNNPACK buck build third_party:XNNPACK
- name: Build QNNPACK - name: Build QNNPACK
run: | run: |
buck build aten/src/ATen/native/quantized/cpu/qnnpack/... --keep-going buck build aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack
- name: Build aten_cpu - name: Build aten_cpu
run: | run: |
@ -94,9 +82,9 @@ jobs:
run: | run: |
buck build :torch_mobile_core buck build :torch_mobile_core
- name: Build torch_mobile_all_ops - name: Build pt_ops_full
run: | run: |
buck build :torch_mobile_all_ops buck build :pt_ops_full
- name: Build mobile benchmark - name: Build mobile benchmark
run: | run: |

View File

@ -1,15 +1,17 @@
load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
load( load(
":pt_defs.oss.bzl", ":pt_ops.bzl",
"pt_operator_library", "pt_operator_library",
"get_pt_ops_deps",
) )
load(":buckbuild.bzl", load(":buckbuild.bzl",
"define_buck_targets", "define_buck_targets",
"get_pt_operator_registry_dict",
) )
# define shared buck targets
define_buck_targets() define_buck_targets()
# define OSS only targets
cxx_library( cxx_library(
name = "pthreadpool", name = "pthreadpool",
srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'], srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'],
@ -76,21 +78,17 @@ cxx_library(
pt_operator_library( pt_operator_library(
name = "torch_mobile_ops_full_dev", name = "torch_mobile_ops_full_dev",
check_decl = False,
include_all_operators = True, include_all_operators = True,
) )
cxx_library( cxx_library(
name = "torch_mobile_all_ops", name = "pt_ops_full",
visibility = ["PUBLIC"], **get_pt_operator_registry_dict(
deps = get_pt_ops_deps(
name = "pt_ops_full", name = "pt_ops_full",
train = False,
deps = [ deps = [
":torch_mobile_ops_full_dev", ":torch_mobile_ops_full_dev",
], ],
enable_flatbuffer = False, )
),
) )
cxx_binary( cxx_binary(
@ -118,7 +116,7 @@ cxx_binary(
], ],
deps = [ deps = [
":torch_mobile_core", ":torch_mobile_core",
":torch_mobile_all_ops", ":pt_ops_full",
"//c10:c10", "//c10:c10",
], ],
) )

View File

@ -25,6 +25,10 @@ load(
"jit_core_sources", "jit_core_sources",
"libtorch_profiler_sources", "libtorch_profiler_sources",
) )
load(
":pt_ops.bzl",
"USED_PT_BACKENDS",
)
load( load(
":pt_template_srcs.bzl", ":pt_template_srcs.bzl",
"METAL_MASKRCNN_SOURCE_LIST", "METAL_MASKRCNN_SOURCE_LIST",
@ -235,12 +239,6 @@ def get_pt_preprocessor_flags():
PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS") PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS")
return PT_PREPROCESSOR_FLAGS return PT_PREPROCESSOR_FLAGS
USED_PT_BACKENDS = [
"CPU",
"QuantizedCPU",
"SparseCPU", # brings ~20 kb size regression
]
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892 # This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
PT_BACKEND_HEADERS = [ PT_BACKEND_HEADERS = [
"CPU", "CPU",

View File

@ -1,177 +0,0 @@
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load(
":buckbuild.bzl",
"get_pt_operator_registry_dict",
)
PT_BASE_OPS = [
"aten::_coalesced_",
"aten::_copy_from",
"aten::_empty_affine_quantized",
"aten::_empty_per_channel_affine_quantized",
"aten::_indices",
"aten::_nnz",
"aten::_values",
"aten::add",
"aten::add_",
"aten::arange",
"aten::as_strided",
"aten::as_strided_",
"aten::cat",
"aten::clone",
"aten::coalesce",
"aten::contiguous",
"aten::copy_",
"aten::copy_sparse_to_sparse_",
"aten::dense_dim",
"aten::dequantize",
"aten::div",
"aten::div_",
"aten::empty",
"aten::empty_like",
"aten::empty_strided",
"aten::empty.memory_format",
"aten::eq",
"aten::equal",
"aten::expand",
"aten::fill_",
"aten::is_coalesced",
"aten::is_complex",
"aten::is_floating_point",
"aten::is_leaf",
"aten::is_nonzero",
"aten::item",
"aten::max",
"aten::min",
"aten::mul",
"aten::mul_",
"aten::narrow",
"aten::ne",
"aten::permute",
"aten::q_per_channel_axis",
"aten::q_per_channel_scales",
"aten::q_per_channel_zero_points",
"aten::q_scale",
"aten::q_zero_point",
"aten::qscheme",
"aten::quantize_per_tensor",
"aten::reshape",
"aten::_reshape_alias",
"aten::resize_",
"aten::resize_as_",
"aten::scalar_tensor",
"aten::select",
"aten::set_",
"aten::size",
"aten::slice",
"aten::sparse_dim",
"aten::sparse_resize_and_clear_",
"aten::squeeze",
"aten::squeeze_",
"aten::stride",
"aten::sub",
"aten::sub_",
"aten::sum",
"aten::t",
"aten::to",
"aten::_to_copy",
"aten::unsqueeze",
"aten::view",
"aten::zero_",
"aten::zeros",
"aten::zeros_like",
]
######### selective build #########
def pt_operator_registry(
name,
deps = [],
train = False,
labels = [],
env = [],
template_select = True,
enforce_traced_op_list = False,
pt_allow_forced_schema_registration = True,
enable_flatbuffer = False,
**kwargs):
args = get_pt_operator_registry_dict(
name,
deps,
train,
labels,
env,
template_select,
enforce_traced_op_list,
pt_allow_forced_schema_registration,
enable_flatbuffer = True,
**kwargs
)
fb_xplat_cxx_library(
name = name,
**args
)
def get_pt_ops_deps(name, deps, train = False, enforce_traced_op_list = False, enable_flatbuffer = False, **kwargs):
pt_operator_registry(
name,
deps,
train = train,
enforce_traced_op_list = enforce_traced_op_list,
enable_flatbuffer = enable_flatbuffer,
**kwargs
)
return deps + [":" + name]
def pt_operator_library(
name,
ops = [],
exported_deps = [],
check_decl = True,
train = False,
model = None,
include_all_operators = False,
**kwargs):
model_name = name
ops = [op.strip() for op in ops]
# If ops are specified, then we are in static selective build mode, so we append
# base ops to this list to avoid additional special case logic in subsequent code.
if len(ops) > 0:
ops.extend(PT_BASE_OPS)
visibility = kwargs.pop("visibility", ["PUBLIC"])
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
cmd = (
"$(exe :gen_operators_yaml) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
"--output_path \"${{OUT}}\" " +
"--model_name {model_name} " +
"--dep_graph_yaml_path pytorch_op_deps.yaml " +
"--models_yaml_path all_mobile_model_configs.yaml " +
#"{optionally_model_versions} " +
#"{optionally_model_assets} " +
#"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
rule_name = name,
model_name = model_name,
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
#optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
#optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
#optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
),
labels = ["pt_operator_library"], # for pt_operator_query_codegen query
visibility = visibility,
**kwargs
)

View File

@ -1,3 +1,114 @@
load("//tools/build_defs:expect.bzl", "expect")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
# @lint-ignore BUCKRESTRICTEDSYNTAX
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
USED_PT_BACKENDS = [
"CPU",
"QuantizedCPU",
"SparseCPU", # brings ~20 kb size regression
]
def pt_operator_library(
name,
ops = [],
exported_deps = [],
check_decl = True,
train = False,
model = None,
include_all_operators = False,
**kwargs):
(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
name,
model,
)
ops = [op.strip() for op in ops]
# If ops are specified, then we are in static selective build mode, so we append
# base ops to this list to avoid additional special case logic in subsequent code.
if len(ops) > 0:
ops.extend(PT_BASE_OPS)
labels = kwargs.pop("labels", [])
visibility = kwargs.pop("visibility", ["PUBLIC"])
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
cmd = (
"$(exe {root}:gen_operators_yaml) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
"--output_path \"${{OUT}}\" " +
"--model_name {model_name} " +
"--dep_graph_yaml_path {dep_graph_yaml} " +
"--models_yaml_path {models_yaml} " +
"{optionally_model_versions} " +
"{optionally_model_assets} " +
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
root = "//" if IS_OSS else "//xplat/caffe2",
rule_name = name,
model_name = model_name,
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
),
labels = labels + [
"pt_operator_library",
"supermodule:android/default/pytorch",
"supermodule:ios/default/public.pytorch",
] + (["pt_train_operator_library"] if train else []),
visibility = visibility,
**kwargs
)
def validate_and_extract_model_information(name, model):
model_name = name
model_versions = None
model_assets = None
model_traced_backends = None
if model != None:
model_name = model.get("name")
expect(model_name != None, "Expected Model Name to be present")
model_versions = model.get("versions")
expect(is_list(model_versions), "Expected model versions to be a list of string")
for ver in model_versions or []:
expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
model_assets = model.get("assets")
expect(
model_assets == None or is_list(model_assets),
"Expected model assets to be a list of string if specified",
)
for asset_name in model_assets or []:
expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
model_traced_backends = model.get("traced_backends")
expect(
model_traced_backends == None or is_list(model_traced_backends),
"Expected model traced backends to be a list of string if specified",
)
if model_traced_backends != None:
for backend in model_traced_backends:
expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
expect(
backend in USED_PT_BACKENDS,
"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
)
return (model_name, model_versions, model_assets, model_traced_backends)
# This file keeps a list of PyTorch operators used by any targets in # This file keeps a list of PyTorch operators used by any targets in
# @fbsource//xplat/... # @fbsource//xplat/...
# The purpose of the list is to avoid generating large number of unused # The purpose of the list is to avoid generating large number of unused