mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[4] move pt_operator_library to shared BUCK file (#80170)
Summary: Move pt_operator_library to pt_ops.bzl and make it shared with OSS BUCK build This will replace D36912042. I will update all load statements in future diffs. Test Plan: sandcaslte, OSS CI Differential Revision: D37390060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170 Approved by: https://github.com/JacobSzwejbka
This commit is contained in:
committed by
PyTorch MergeBot
parent
bda04e9f5e
commit
e98e7fe428
18
.github/workflows/_buck-build-test.yml
vendored
18
.github/workflows/_buck-build-test.yml
vendored
@ -62,29 +62,17 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
sh scripts/buck_setup.sh
|
sh scripts/buck_setup.sh
|
||||||
|
|
||||||
- name: Build glog
|
|
||||||
run: |
|
|
||||||
buck build third_party:glog
|
|
||||||
|
|
||||||
- name: Build C10
|
- name: Build C10
|
||||||
run: |
|
run: |
|
||||||
buck build c10:c10
|
buck build c10:c10
|
||||||
|
|
||||||
- name: Build cpuinfo
|
|
||||||
run: |
|
|
||||||
buck build third_party:cpuinfo
|
|
||||||
|
|
||||||
- name: Build pthreadpool
|
|
||||||
run: |
|
|
||||||
buck build third_party:pthreadpool
|
|
||||||
|
|
||||||
- name: Build XNNPACK
|
- name: Build XNNPACK
|
||||||
run: |
|
run: |
|
||||||
buck build third_party:XNNPACK
|
buck build third_party:XNNPACK
|
||||||
|
|
||||||
- name: Build QNNPACK
|
- name: Build QNNPACK
|
||||||
run: |
|
run: |
|
||||||
buck build aten/src/ATen/native/quantized/cpu/qnnpack/... --keep-going
|
buck build aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack
|
||||||
|
|
||||||
- name: Build aten_cpu
|
- name: Build aten_cpu
|
||||||
run: |
|
run: |
|
||||||
@ -94,9 +82,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
buck build :torch_mobile_core
|
buck build :torch_mobile_core
|
||||||
|
|
||||||
- name: Build torch_mobile_all_ops
|
- name: Build pt_ops_full
|
||||||
run: |
|
run: |
|
||||||
buck build :torch_mobile_all_ops
|
buck build :pt_ops_full
|
||||||
|
|
||||||
- name: Build mobile benchmark
|
- name: Build mobile benchmark
|
||||||
run: |
|
run: |
|
||||||
|
18
BUCK.oss
18
BUCK.oss
@ -1,15 +1,17 @@
|
|||||||
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
||||||
load(
|
load(
|
||||||
":pt_defs.oss.bzl",
|
":pt_ops.bzl",
|
||||||
"pt_operator_library",
|
"pt_operator_library",
|
||||||
"get_pt_ops_deps",
|
|
||||||
)
|
)
|
||||||
load(":buckbuild.bzl",
|
load(":buckbuild.bzl",
|
||||||
"define_buck_targets",
|
"define_buck_targets",
|
||||||
|
"get_pt_operator_registry_dict",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# define shared buck targets
|
||||||
define_buck_targets()
|
define_buck_targets()
|
||||||
|
|
||||||
|
# define OSS only targets
|
||||||
cxx_library(
|
cxx_library(
|
||||||
name = "pthreadpool",
|
name = "pthreadpool",
|
||||||
srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'],
|
srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'],
|
||||||
@ -76,21 +78,17 @@ cxx_library(
|
|||||||
|
|
||||||
pt_operator_library(
|
pt_operator_library(
|
||||||
name = "torch_mobile_ops_full_dev",
|
name = "torch_mobile_ops_full_dev",
|
||||||
check_decl = False,
|
|
||||||
include_all_operators = True,
|
include_all_operators = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cxx_library(
|
cxx_library(
|
||||||
name = "torch_mobile_all_ops",
|
name = "pt_ops_full",
|
||||||
visibility = ["PUBLIC"],
|
**get_pt_operator_registry_dict(
|
||||||
deps = get_pt_ops_deps(
|
|
||||||
name = "pt_ops_full",
|
name = "pt_ops_full",
|
||||||
train = False,
|
|
||||||
deps = [
|
deps = [
|
||||||
":torch_mobile_ops_full_dev",
|
":torch_mobile_ops_full_dev",
|
||||||
],
|
],
|
||||||
enable_flatbuffer = False,
|
)
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cxx_binary(
|
cxx_binary(
|
||||||
@ -118,7 +116,7 @@ cxx_binary(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":torch_mobile_core",
|
":torch_mobile_core",
|
||||||
":torch_mobile_all_ops",
|
":pt_ops_full",
|
||||||
"//c10:c10",
|
"//c10:c10",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -25,6 +25,10 @@ load(
|
|||||||
"jit_core_sources",
|
"jit_core_sources",
|
||||||
"libtorch_profiler_sources",
|
"libtorch_profiler_sources",
|
||||||
)
|
)
|
||||||
|
load(
|
||||||
|
":pt_ops.bzl",
|
||||||
|
"USED_PT_BACKENDS",
|
||||||
|
)
|
||||||
load(
|
load(
|
||||||
":pt_template_srcs.bzl",
|
":pt_template_srcs.bzl",
|
||||||
"METAL_MASKRCNN_SOURCE_LIST",
|
"METAL_MASKRCNN_SOURCE_LIST",
|
||||||
@ -235,12 +239,6 @@ def get_pt_preprocessor_flags():
|
|||||||
PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS")
|
PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS")
|
||||||
return PT_PREPROCESSOR_FLAGS
|
return PT_PREPROCESSOR_FLAGS
|
||||||
|
|
||||||
USED_PT_BACKENDS = [
|
|
||||||
"CPU",
|
|
||||||
"QuantizedCPU",
|
|
||||||
"SparseCPU", # brings ~20 kb size regression
|
|
||||||
]
|
|
||||||
|
|
||||||
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
|
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
|
||||||
PT_BACKEND_HEADERS = [
|
PT_BACKEND_HEADERS = [
|
||||||
"CPU",
|
"CPU",
|
||||||
|
177
pt_defs.oss.bzl
177
pt_defs.oss.bzl
@ -1,177 +0,0 @@
|
|||||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
|
||||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
|
||||||
load(
|
|
||||||
":buckbuild.bzl",
|
|
||||||
"get_pt_operator_registry_dict",
|
|
||||||
)
|
|
||||||
|
|
||||||
PT_BASE_OPS = [
|
|
||||||
"aten::_coalesced_",
|
|
||||||
"aten::_copy_from",
|
|
||||||
"aten::_empty_affine_quantized",
|
|
||||||
"aten::_empty_per_channel_affine_quantized",
|
|
||||||
"aten::_indices",
|
|
||||||
"aten::_nnz",
|
|
||||||
"aten::_values",
|
|
||||||
"aten::add",
|
|
||||||
"aten::add_",
|
|
||||||
"aten::arange",
|
|
||||||
"aten::as_strided",
|
|
||||||
"aten::as_strided_",
|
|
||||||
"aten::cat",
|
|
||||||
"aten::clone",
|
|
||||||
"aten::coalesce",
|
|
||||||
"aten::contiguous",
|
|
||||||
"aten::copy_",
|
|
||||||
"aten::copy_sparse_to_sparse_",
|
|
||||||
"aten::dense_dim",
|
|
||||||
"aten::dequantize",
|
|
||||||
"aten::div",
|
|
||||||
"aten::div_",
|
|
||||||
"aten::empty",
|
|
||||||
"aten::empty_like",
|
|
||||||
"aten::empty_strided",
|
|
||||||
"aten::empty.memory_format",
|
|
||||||
"aten::eq",
|
|
||||||
"aten::equal",
|
|
||||||
"aten::expand",
|
|
||||||
"aten::fill_",
|
|
||||||
"aten::is_coalesced",
|
|
||||||
"aten::is_complex",
|
|
||||||
"aten::is_floating_point",
|
|
||||||
"aten::is_leaf",
|
|
||||||
"aten::is_nonzero",
|
|
||||||
"aten::item",
|
|
||||||
"aten::max",
|
|
||||||
"aten::min",
|
|
||||||
"aten::mul",
|
|
||||||
"aten::mul_",
|
|
||||||
"aten::narrow",
|
|
||||||
"aten::ne",
|
|
||||||
"aten::permute",
|
|
||||||
"aten::q_per_channel_axis",
|
|
||||||
"aten::q_per_channel_scales",
|
|
||||||
"aten::q_per_channel_zero_points",
|
|
||||||
"aten::q_scale",
|
|
||||||
"aten::q_zero_point",
|
|
||||||
"aten::qscheme",
|
|
||||||
"aten::quantize_per_tensor",
|
|
||||||
"aten::reshape",
|
|
||||||
"aten::_reshape_alias",
|
|
||||||
"aten::resize_",
|
|
||||||
"aten::resize_as_",
|
|
||||||
"aten::scalar_tensor",
|
|
||||||
"aten::select",
|
|
||||||
"aten::set_",
|
|
||||||
"aten::size",
|
|
||||||
"aten::slice",
|
|
||||||
"aten::sparse_dim",
|
|
||||||
"aten::sparse_resize_and_clear_",
|
|
||||||
"aten::squeeze",
|
|
||||||
"aten::squeeze_",
|
|
||||||
"aten::stride",
|
|
||||||
"aten::sub",
|
|
||||||
"aten::sub_",
|
|
||||||
"aten::sum",
|
|
||||||
"aten::t",
|
|
||||||
"aten::to",
|
|
||||||
"aten::_to_copy",
|
|
||||||
"aten::unsqueeze",
|
|
||||||
"aten::view",
|
|
||||||
"aten::zero_",
|
|
||||||
"aten::zeros",
|
|
||||||
"aten::zeros_like",
|
|
||||||
]
|
|
||||||
|
|
||||||
######### selective build #########
|
|
||||||
|
|
||||||
def pt_operator_registry(
|
|
||||||
name,
|
|
||||||
deps = [],
|
|
||||||
train = False,
|
|
||||||
labels = [],
|
|
||||||
env = [],
|
|
||||||
template_select = True,
|
|
||||||
enforce_traced_op_list = False,
|
|
||||||
pt_allow_forced_schema_registration = True,
|
|
||||||
enable_flatbuffer = False,
|
|
||||||
**kwargs):
|
|
||||||
args = get_pt_operator_registry_dict(
|
|
||||||
name,
|
|
||||||
deps,
|
|
||||||
train,
|
|
||||||
labels,
|
|
||||||
env,
|
|
||||||
template_select,
|
|
||||||
enforce_traced_op_list,
|
|
||||||
pt_allow_forced_schema_registration,
|
|
||||||
enable_flatbuffer = True,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_xplat_cxx_library(
|
|
||||||
name = name,
|
|
||||||
**args
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_pt_ops_deps(name, deps, train = False, enforce_traced_op_list = False, enable_flatbuffer = False, **kwargs):
|
|
||||||
pt_operator_registry(
|
|
||||||
name,
|
|
||||||
deps,
|
|
||||||
train = train,
|
|
||||||
enforce_traced_op_list = enforce_traced_op_list,
|
|
||||||
enable_flatbuffer = enable_flatbuffer,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
return deps + [":" + name]
|
|
||||||
|
|
||||||
def pt_operator_library(
|
|
||||||
name,
|
|
||||||
ops = [],
|
|
||||||
exported_deps = [],
|
|
||||||
check_decl = True,
|
|
||||||
train = False,
|
|
||||||
model = None,
|
|
||||||
include_all_operators = False,
|
|
||||||
**kwargs):
|
|
||||||
model_name = name
|
|
||||||
|
|
||||||
ops = [op.strip() for op in ops]
|
|
||||||
|
|
||||||
# If ops are specified, then we are in static selective build mode, so we append
|
|
||||||
# base ops to this list to avoid additional special case logic in subsequent code.
|
|
||||||
if len(ops) > 0:
|
|
||||||
ops.extend(PT_BASE_OPS)
|
|
||||||
|
|
||||||
visibility = kwargs.pop("visibility", ["PUBLIC"])
|
|
||||||
|
|
||||||
fb_xplat_genrule(
|
|
||||||
name = name,
|
|
||||||
out = "model_operators.yaml",
|
|
||||||
cmd = (
|
|
||||||
"$(exe :gen_operators_yaml) " +
|
|
||||||
"{optionally_root_ops} " +
|
|
||||||
"{optionally_training_root_ops} " +
|
|
||||||
"--rule_name {rule_name} " +
|
|
||||||
"--output_path \"${{OUT}}\" " +
|
|
||||||
"--model_name {model_name} " +
|
|
||||||
"--dep_graph_yaml_path pytorch_op_deps.yaml " +
|
|
||||||
"--models_yaml_path all_mobile_model_configs.yaml " +
|
|
||||||
#"{optionally_model_versions} " +
|
|
||||||
#"{optionally_model_assets} " +
|
|
||||||
#"{optionally_model_traced_backends} " +
|
|
||||||
"{optionally_include_all_operators}"
|
|
||||||
).format(
|
|
||||||
rule_name = name,
|
|
||||||
model_name = model_name,
|
|
||||||
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
|
|
||||||
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
|
|
||||||
#optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
|
|
||||||
#optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
|
|
||||||
#optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
|
|
||||||
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
|
|
||||||
),
|
|
||||||
labels = ["pt_operator_library"], # for pt_operator_query_codegen query
|
|
||||||
visibility = visibility,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
111
pt_ops.bzl
111
pt_ops.bzl
@ -1,3 +1,114 @@
|
|||||||
|
load("//tools/build_defs:expect.bzl", "expect")
|
||||||
|
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||||
|
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
|
||||||
|
|
||||||
|
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||||
|
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
|
||||||
|
|
||||||
|
USED_PT_BACKENDS = [
|
||||||
|
"CPU",
|
||||||
|
"QuantizedCPU",
|
||||||
|
"SparseCPU", # brings ~20 kb size regression
|
||||||
|
]
|
||||||
|
|
||||||
|
def pt_operator_library(
|
||||||
|
name,
|
||||||
|
ops = [],
|
||||||
|
exported_deps = [],
|
||||||
|
check_decl = True,
|
||||||
|
train = False,
|
||||||
|
model = None,
|
||||||
|
include_all_operators = False,
|
||||||
|
**kwargs):
|
||||||
|
(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
|
||||||
|
name,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
|
||||||
|
ops = [op.strip() for op in ops]
|
||||||
|
|
||||||
|
# If ops are specified, then we are in static selective build mode, so we append
|
||||||
|
# base ops to this list to avoid additional special case logic in subsequent code.
|
||||||
|
if len(ops) > 0:
|
||||||
|
ops.extend(PT_BASE_OPS)
|
||||||
|
|
||||||
|
labels = kwargs.pop("labels", [])
|
||||||
|
visibility = kwargs.pop("visibility", ["PUBLIC"])
|
||||||
|
|
||||||
|
fb_xplat_genrule(
|
||||||
|
name = name,
|
||||||
|
out = "model_operators.yaml",
|
||||||
|
cmd = (
|
||||||
|
"$(exe {root}:gen_operators_yaml) " +
|
||||||
|
"{optionally_root_ops} " +
|
||||||
|
"{optionally_training_root_ops} " +
|
||||||
|
"--rule_name {rule_name} " +
|
||||||
|
"--output_path \"${{OUT}}\" " +
|
||||||
|
"--model_name {model_name} " +
|
||||||
|
"--dep_graph_yaml_path {dep_graph_yaml} " +
|
||||||
|
"--models_yaml_path {models_yaml} " +
|
||||||
|
"{optionally_model_versions} " +
|
||||||
|
"{optionally_model_assets} " +
|
||||||
|
"{optionally_model_traced_backends} " +
|
||||||
|
"{optionally_include_all_operators}"
|
||||||
|
).format(
|
||||||
|
root = "//" if IS_OSS else "//xplat/caffe2",
|
||||||
|
rule_name = name,
|
||||||
|
model_name = model_name,
|
||||||
|
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
|
||||||
|
models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
|
||||||
|
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
|
||||||
|
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
|
||||||
|
optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
|
||||||
|
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
|
||||||
|
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
|
||||||
|
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
|
||||||
|
),
|
||||||
|
labels = labels + [
|
||||||
|
"pt_operator_library",
|
||||||
|
"supermodule:android/default/pytorch",
|
||||||
|
"supermodule:ios/default/public.pytorch",
|
||||||
|
] + (["pt_train_operator_library"] if train else []),
|
||||||
|
visibility = visibility,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_and_extract_model_information(name, model):
|
||||||
|
model_name = name
|
||||||
|
model_versions = None
|
||||||
|
model_assets = None
|
||||||
|
model_traced_backends = None
|
||||||
|
|
||||||
|
if model != None:
|
||||||
|
model_name = model.get("name")
|
||||||
|
expect(model_name != None, "Expected Model Name to be present")
|
||||||
|
model_versions = model.get("versions")
|
||||||
|
expect(is_list(model_versions), "Expected model versions to be a list of string")
|
||||||
|
for ver in model_versions or []:
|
||||||
|
expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
|
||||||
|
model_assets = model.get("assets")
|
||||||
|
expect(
|
||||||
|
model_assets == None or is_list(model_assets),
|
||||||
|
"Expected model assets to be a list of string if specified",
|
||||||
|
)
|
||||||
|
for asset_name in model_assets or []:
|
||||||
|
expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
|
||||||
|
model_traced_backends = model.get("traced_backends")
|
||||||
|
expect(
|
||||||
|
model_traced_backends == None or is_list(model_traced_backends),
|
||||||
|
"Expected model traced backends to be a list of string if specified",
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_traced_backends != None:
|
||||||
|
for backend in model_traced_backends:
|
||||||
|
expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
|
||||||
|
expect(
|
||||||
|
backend in USED_PT_BACKENDS,
|
||||||
|
"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (model_name, model_versions, model_assets, model_traced_backends)
|
||||||
|
|
||||||
# This file keeps a list of PyTorch operators used by any targets in
|
# This file keeps a list of PyTorch operators used by any targets in
|
||||||
# @fbsource//xplat/...
|
# @fbsource//xplat/...
|
||||||
# The purpose of the list is to avoid generating large number of unused
|
# The purpose of the list is to avoid generating large number of unused
|
||||||
|
Reference in New Issue
Block a user