Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20)
move pt_operator_library to shared BUCK file (#80170)

Summary: Move pt_operator_library to pt_ops.bzl and make it shared with the OSS BUCK build. This will replace D36912042. I will update all load statements in future diffs.

Test Plan: Sandcastle, OSS CI

Differential Revision: D37390060

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170
Approved by: https://github.com/JacobSzwejbka
Commit e98e7fe428 (parent bda04e9f5e), committed by PyTorch MergeBot.
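The practical effect is that the OSS BUCK build now loads the same selective-build macro definition the internal build uses, from one shared file. A minimal sketch of the shared usage (the target name is illustrative; the real call sites are in the BUCK.oss hunks below):

    # Previously the OSS build loaded pt_operator_library from pt_defs.oss.bzl;
    # after this diff it loads the shared definition from pt_ops.bzl.
    load(":pt_ops.bzl", "pt_operator_library")

    # Illustrative target: an all-operators library, mirroring the
    # torch_mobile_ops_full_dev target defined in BUCK.oss below.
    pt_operator_library(
        name = "example_ops",
        check_decl = False,
        include_all_operators = True,
    )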
.github/workflows/_buck-build-test.yml (vendored): 18 changed lines
@@ -62,29 +62,17 @@ jobs:
          command: |
            sh scripts/buck_setup.sh

      - name: Build glog
        run: |
          buck build third_party:glog

      - name: Build C10
        run: |
          buck build c10:c10

      - name: Build cpuinfo
        run: |
          buck build third_party:cpuinfo

      - name: Build pthreadpool
        run: |
          buck build third_party:pthreadpool

      - name: Build XNNPACK
        run: |
          buck build third_party:XNNPACK

      - name: Build QNNPACK
        run: |
-          buck build aten/src/ATen/native/quantized/cpu/qnnpack/... --keep-going
+          buck build aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack

      - name: Build aten_cpu
        run: |
@@ -94,9 +82,9 @@ jobs:
        run: |
          buck build :torch_mobile_core

-      - name: Build torch_mobile_all_ops
+      - name: Build pt_ops_full
        run: |
-          buck build :torch_mobile_all_ops
+          buck build :pt_ops_full

      - name: Build mobile benchmark
        run: |
BUCK.oss: 18 changed lines
@@ -1,15 +1,17 @@
 load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
 load(
-    ":pt_defs.oss.bzl",
+    ":pt_ops.bzl",
     "pt_operator_library",
-    "get_pt_ops_deps",
 )
 load(":buckbuild.bzl",
     "define_buck_targets",
+    "get_pt_operator_registry_dict",
 )

 # define shared buck targets
 define_buck_targets()

 # define OSS only targets
 cxx_library(
     name = "pthreadpool",
     srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'],
@@ -76,21 +78,17 @@ cxx_library(

 pt_operator_library(
     name = "torch_mobile_ops_full_dev",
     check_decl = False,
     include_all_operators = True,
 )

 cxx_library(
-    name = "torch_mobile_all_ops",
+    name = "pt_ops_full",
     visibility = ["PUBLIC"],
-    deps = get_pt_ops_deps(
+    **get_pt_operator_registry_dict(
         name = "pt_ops_full",
         train = False,
         deps = [
             ":torch_mobile_ops_full_dev",
         ],
         enable_flatbuffer = False,
     ),
 )

 cxx_binary(
@@ -118,7 +116,7 @@ cxx_binary(
     ],
     deps = [
         ":torch_mobile_core",
-        ":torch_mobile_all_ops",
+        ":pt_ops_full",
         "//c10:c10",
     ],
 )
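The cxx_library above no longer goes through the get_pt_ops_deps wrapper: it takes its attribute dict from the shared get_pt_operator_registry_dict and expands it with Starlark's **kwargs splat. A minimal sketch of that mechanism, with hypothetical helper and rule names (not from the diff):

    # A helper computes the rule's attributes once, as a dict ...
    def _registry_attrs(train, deps):
        return {
            "train": train,
            "deps": deps,
        }

    # ... and the call site splats the dict into keyword arguments, so this
    # is equivalent to: my_rule(name = "pt_ops_full", train = False, deps = [...]).
    my_rule(
        name = "pt_ops_full",
        **_registry_attrs(train = False, deps = [":torch_mobile_ops_full_dev"])
    )

This pattern is what lets the internal build and BUCK.oss share one attribute-computing function while each wraps it in its own library rule.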
buckbuild.bzl: 10 changed lines

@@ -25,6 +25,10 @@ load(
     "jit_core_sources",
     "libtorch_profiler_sources",
 )
+load(
+    ":pt_ops.bzl",
+    "USED_PT_BACKENDS",
+)
 load(
     ":pt_template_srcs.bzl",
     "METAL_MASKRCNN_SOURCE_LIST",
@@ -235,12 +239,6 @@ def get_pt_preprocessor_flags():
         PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS")
     return PT_PREPROCESSOR_FLAGS

-USED_PT_BACKENDS = [
-    "CPU",
-    "QuantizedCPU",
-    "SparseCPU", # brings ~20 kb size regression
-]
-
 # This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
 PT_BACKEND_HEADERS = [
     "CPU",
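With USED_PT_BACKENDS moved out of buckbuild.bzl into the shared pt_ops.bzl, both build files now read a single definition of the backend set. A small sketch of a consumer, with a hypothetical helper:

    load(":pt_ops.bzl", "USED_PT_BACKENDS")

    # Hypothetical helper: any membership check against the backend set now
    # sees the same list in the internal and the OSS build.
    def _supports_backend(backend):
        return backend in USED_PT_BACKENDS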
pt_defs.oss.bzl: 177 lines (file deleted)
@@ -1,177 +0,0 @@
-load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
-load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
-load(
-    ":buckbuild.bzl",
-    "get_pt_operator_registry_dict",
-)
-
-PT_BASE_OPS = [
-    "aten::_coalesced_",
-    "aten::_copy_from",
-    "aten::_empty_affine_quantized",
-    "aten::_empty_per_channel_affine_quantized",
-    "aten::_indices",
-    "aten::_nnz",
-    "aten::_values",
-    "aten::add",
-    "aten::add_",
-    "aten::arange",
-    "aten::as_strided",
-    "aten::as_strided_",
-    "aten::cat",
-    "aten::clone",
-    "aten::coalesce",
-    "aten::contiguous",
-    "aten::copy_",
-    "aten::copy_sparse_to_sparse_",
-    "aten::dense_dim",
-    "aten::dequantize",
-    "aten::div",
-    "aten::div_",
-    "aten::empty",
-    "aten::empty_like",
-    "aten::empty_strided",
-    "aten::empty.memory_format",
-    "aten::eq",
-    "aten::equal",
-    "aten::expand",
-    "aten::fill_",
-    "aten::is_coalesced",
-    "aten::is_complex",
-    "aten::is_floating_point",
-    "aten::is_leaf",
-    "aten::is_nonzero",
-    "aten::item",
-    "aten::max",
-    "aten::min",
-    "aten::mul",
-    "aten::mul_",
-    "aten::narrow",
-    "aten::ne",
-    "aten::permute",
-    "aten::q_per_channel_axis",
-    "aten::q_per_channel_scales",
-    "aten::q_per_channel_zero_points",
-    "aten::q_scale",
-    "aten::q_zero_point",
-    "aten::qscheme",
-    "aten::quantize_per_tensor",
-    "aten::reshape",
-    "aten::_reshape_alias",
-    "aten::resize_",
-    "aten::resize_as_",
-    "aten::scalar_tensor",
-    "aten::select",
-    "aten::set_",
-    "aten::size",
-    "aten::slice",
-    "aten::sparse_dim",
-    "aten::sparse_resize_and_clear_",
-    "aten::squeeze",
-    "aten::squeeze_",
-    "aten::stride",
-    "aten::sub",
-    "aten::sub_",
-    "aten::sum",
-    "aten::t",
-    "aten::to",
-    "aten::_to_copy",
-    "aten::unsqueeze",
-    "aten::view",
-    "aten::zero_",
-    "aten::zeros",
-    "aten::zeros_like",
-]
-
-######### selective build #########
-
-def pt_operator_registry(
-        name,
-        deps = [],
-        train = False,
-        labels = [],
-        env = [],
-        template_select = True,
-        enforce_traced_op_list = False,
-        pt_allow_forced_schema_registration = True,
-        enable_flatbuffer = False,
-        **kwargs):
-    args = get_pt_operator_registry_dict(
-        name,
-        deps,
-        train,
-        labels,
-        env,
-        template_select,
-        enforce_traced_op_list,
-        pt_allow_forced_schema_registration,
-        enable_flatbuffer = True,
-        **kwargs
-    )
-
-    fb_xplat_cxx_library(
-        name = name,
-        **args
-    )
-
-def get_pt_ops_deps(name, deps, train = False, enforce_traced_op_list = False, enable_flatbuffer = False, **kwargs):
-    pt_operator_registry(
-        name,
-        deps,
-        train = train,
-        enforce_traced_op_list = enforce_traced_op_list,
-        enable_flatbuffer = enable_flatbuffer,
-        **kwargs
-    )
-    return deps + [":" + name]
-
-def pt_operator_library(
-        name,
-        ops = [],
-        exported_deps = [],
-        check_decl = True,
-        train = False,
-        model = None,
-        include_all_operators = False,
-        **kwargs):
-    model_name = name
-
-    ops = [op.strip() for op in ops]
-
-    # If ops are specified, then we are in static selective build mode, so we append
-    # base ops to this list to avoid additional special case logic in subsequent code.
-    if len(ops) > 0:
-        ops.extend(PT_BASE_OPS)
-
-    visibility = kwargs.pop("visibility", ["PUBLIC"])
-
-    fb_xplat_genrule(
-        name = name,
-        out = "model_operators.yaml",
-        cmd = (
-            "$(exe :gen_operators_yaml) " +
-            "{optionally_root_ops} " +
-            "{optionally_training_root_ops} " +
-            "--rule_name {rule_name} " +
-            "--output_path \"${{OUT}}\" " +
-            "--model_name {model_name} " +
-            "--dep_graph_yaml_path pytorch_op_deps.yaml " +
-            "--models_yaml_path all_mobile_model_configs.yaml " +
-            #"{optionally_model_versions} " +
-            #"{optionally_model_assets} " +
-            #"{optionally_model_traced_backends} " +
-            "{optionally_include_all_operators}"
-        ).format(
-            rule_name = name,
-            model_name = model_name,
-            optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
-            optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
-            #optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
-            #optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
-            #optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
-            optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
-        ),
-        labels = ["pt_operator_library"],  # for pt_operator_query_codegen query
-        visibility = visibility,
-        **kwargs
-    )
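For reference, the deleted get_pt_ops_deps was the indirection the pre-diff BUCK.oss relied on: it defined the operator-registry target as a side effect and returned the caller's deps with that target appended. An illustrative pre-diff call, matching the old BUCK.oss hunk earlier:

    load(":pt_defs.oss.bzl", "get_pt_ops_deps")  # pre-diff OSS load

    # Defines :pt_ops_full via pt_operator_registry as a side effect, then
    # returns [":torch_mobile_ops_full_dev", ":pt_ops_full"] for use as deps.
    deps = get_pt_ops_deps(
        name = "pt_ops_full",
        deps = [":torch_mobile_ops_full_dev"],
    )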
pt_ops.bzl: 111 changed lines
@@ -1,3 +1,114 @@
+load("//tools/build_defs:expect.bzl", "expect")
+load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
+load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
+
+# @lint-ignore BUCKRESTRICTEDSYNTAX
+IS_OSS = read_config("pt", "is_oss", "0") == "1"  # True for OSS BUCK build, and False for internal BUCK build
+
+USED_PT_BACKENDS = [
+    "CPU",
+    "QuantizedCPU",
+    "SparseCPU", # brings ~20 kb size regression
+]
+
+def pt_operator_library(
+        name,
+        ops = [],
+        exported_deps = [],
+        check_decl = True,
+        train = False,
+        model = None,
+        include_all_operators = False,
+        **kwargs):
+    (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
+        name,
+        model,
+    )
+
+    ops = [op.strip() for op in ops]
+
+    # If ops are specified, then we are in static selective build mode, so we append
+    # base ops to this list to avoid additional special case logic in subsequent code.
+    if len(ops) > 0:
+        ops.extend(PT_BASE_OPS)
+
+    labels = kwargs.pop("labels", [])
+    visibility = kwargs.pop("visibility", ["PUBLIC"])
+
+    fb_xplat_genrule(
+        name = name,
+        out = "model_operators.yaml",
+        cmd = (
+            "$(exe {root}:gen_operators_yaml) " +
+            "{optionally_root_ops} " +
+            "{optionally_training_root_ops} " +
+            "--rule_name {rule_name} " +
+            "--output_path \"${{OUT}}\" " +
+            "--model_name {model_name} " +
+            "--dep_graph_yaml_path {dep_graph_yaml} " +
+            "--models_yaml_path {models_yaml} " +
+            "{optionally_model_versions} " +
+            "{optionally_model_assets} " +
+            "{optionally_model_traced_backends} " +
+            "{optionally_include_all_operators}"
+        ).format(
+            root = "//" if IS_OSS else "//xplat/caffe2",
+            rule_name = name,
+            model_name = model_name,
+            dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
+            models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
+            optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
+            optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
+            optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
+            optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
+            optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
+            optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
+        ),
+        labels = labels + [
+            "pt_operator_library",
+            "supermodule:android/default/pytorch",
+            "supermodule:ios/default/public.pytorch",
+        ] + (["pt_train_operator_library"] if train else []),
+        visibility = visibility,
+        **kwargs
+    )
+
+def validate_and_extract_model_information(name, model):
+    model_name = name
+    model_versions = None
+    model_assets = None
+    model_traced_backends = None
+
+    if model != None:
+        model_name = model.get("name")
+        expect(model_name != None, "Expected Model Name to be present")
+        model_versions = model.get("versions")
+        expect(is_list(model_versions), "Expected model versions to be a list of string")
+        for ver in model_versions or []:
+            expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
+        model_assets = model.get("assets")
+        expect(
+            model_assets == None or is_list(model_assets),
+            "Expected model assets to be a list of string if specified",
+        )
+        for asset_name in model_assets or []:
+            expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
+        model_traced_backends = model.get("traced_backends")
+        expect(
+            model_traced_backends == None or is_list(model_traced_backends),
+            "Expected model traced backends to be a list of string if specified",
+        )
+
+        if model_traced_backends != None:
+            for backend in model_traced_backends:
+                expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
+                expect(
+                    backend in USED_PT_BACKENDS,
+                    "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
+                )
+
+    return (model_name, model_versions, model_assets, model_traced_backends)
+
 # This file keeps a list of PyTorch operators used by any targets in
 # @fbsource//xplat/...
 # The purpose of the list is to avoid generating large number of unused
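To illustrate what validate_and_extract_model_information accepts, here is a hypothetical call site (all names illustrative): the model's "name" is required, "versions" and "assets" are validated as lists of strings, and each traced backend must appear in USED_PT_BACKENDS:

    pt_operator_library(
        name = "my_model_ops",             # illustrative target
        ops = ["aten::add", "aten::mul"],  # root ops for static selective build
        model = {
            "name": "my_model",            # required when a model dict is passed
            "versions": ["100", "101"],    # validated as a list of strings
            "assets": None,                # optional
            "traced_backends": ["CPU"],    # each entry must be in USED_PT_BACKENDS
        },
    )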