Files
pytorch/pt_ops.bzl
Linbin Yu a563a4880f [Edge] Add an option to avoid adding base ops to static op library (#84360)
Summary: We use a static op library in a test for PyTorch C++ usages, but don't want to introduce all base ops. Because the goal is to check if a given model can run on the exact op collection (i.e., fbios ops, fb4a ops), and these base ops are not present in real apps. So add an option to disable this feature.

Test Plan: Build. Expect no change to existing targets.

Differential Revision: D39164021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84360
Approved by: https://github.com/kit1980
2022-09-01 22:32:55 +00:00

620 lines
17 KiB
Python

load("//tools/build_defs:expect.bzl", "expect")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
# @lint-ignore BUCKRESTRICTEDSYNTAX
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
USED_PT_BACKENDS = [
"CPU",
"QuantizedCPU",
"SparseCPU", # brings ~20 kb size regression
]
def pt_operator_library(
name,
ops = [],
exported_deps = [],
check_decl = True,
train = False,
model = None,
include_all_operators = False,
include_base_operators = True,
**kwargs):
(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
name,
model,
)
ops = [op.strip() for op in ops]
# If ops are specified, then we are in static selective build mode, so we append
# base ops to this list to avoid additional special case logic in subsequent code,
# unless include_base_operators is explicitly set to False (the default is True)
if len(ops) > 0 and include_base_operators:
ops.extend(PT_BASE_OPS)
labels = kwargs.pop("labels", [])
visibility = kwargs.pop("visibility", ["PUBLIC"])
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
cmd = (
"$(exe {exe}) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
"--output_path \"${{OUT}}\" " +
"--model_name {model_name} " +
"--dep_graph_yaml_path {dep_graph_yaml} " +
"--models_yaml_path {models_yaml} " +
"{optionally_model_versions} " +
"{optionally_model_assets} " +
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
rule_name = name,
model_name = model_name,
dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
models_yaml = "none" if IS_OSS else "$(location fbsource//xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
),
labels = labels + [
"pt_operator_library",
"supermodule:android/default/pytorch",
"supermodule:ios/default/public.pytorch",
] + (["pt_train_operator_library"] if train else []),
visibility = visibility,
**kwargs
)
def validate_and_extract_model_information(name, model):
model_name = name
model_versions = None
model_assets = None
model_traced_backends = None
if model != None:
model_name = model.get("name")
expect(model_name != None, "Expected Model Name to be present")
model_versions = model.get("versions")
expect(is_list(model_versions), "Expected model versions to be a list of string")
for ver in model_versions or []:
expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
model_assets = model.get("assets")
expect(
model_assets == None or is_list(model_assets),
"Expected model assets to be a list of string if specified",
)
for asset_name in model_assets or []:
expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
model_traced_backends = model.get("traced_backends")
expect(
model_traced_backends == None or is_list(model_traced_backends),
"Expected model traced backends to be a list of string if specified",
)
if model_traced_backends != None:
for backend in model_traced_backends:
expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
expect(
backend in USED_PT_BACKENDS,
"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
)
return (model_name, model_versions, model_assets, model_traced_backends)
# This file keeps a list of PyTorch operators used by any targets in
# @fbsource//xplat/...
# The purpose of the list is to avoid generating large number of unused
# operator registration code / BUCK rules at build time.
# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv
PT_OPS_PRIM = [
"aten::str",
"aten::list",
"aten::__range_length",
"aten::__derive_index",
"prim::TupleUnpack",
"prim::unchecked_cast",
"aten::IntImplicit",
"aten::FloatImplicit",
"aten::ScalarImplicit",
"aten::Bool.Tensor",
"aten::Bool.int",
"aten::Bool.float",
"aten::Int.Tensor",
"aten::Int.Scalar",
"aten::Int.int",
"aten::Int.bool",
"aten::Int.str",
"aten::Float.Tensor",
"aten::Float.Scalar",
"aten::Float.int",
"aten::Float.bool",
"aten::Float.str",
"aten::format",
"prim::NumToTensor.Scalar",
"prim::RaiseException",
"aten::Size",
"aten::size",
"prim::EnumName",
"prim::EnumValue.int",
"prim::EnumValue.float",
"prim::EnumValue.str",
"prim::TupleIndex",
"aten::ne.int_list",
"prim::unchecked_unwrap_optional",
"prim::device",
"prim::dtype",
"aten::__not__",
"aten::__is__",
"aten::__isnot__",
"aten::element_size",
"aten::numel",
"aten::dim",
"aten::get_device",
"aten::storage_offset",
"aten::is_contiguous",
"aten::select.t",
"aten::__getitem__.t",
"aten::append.t",
"aten::reverse.t",
"aten::extend.t",
"aten::copy.t",
"aten::_set_item.t",
"aten::clear.t",
"aten::Delete.t",
"aten::insert.t",
"aten::pop.t",
"aten::add.t",
"aten::add_.t",
"aten::slice.t",
"aten::list.t",
"aten::mul.left_t",
"aten::mul.right_",
"aten::mul_.t",
"aten::len.t",
"aten::eq.int_list",
"prim::Uninitialized",
"prim::Print",
"aten::eq.enum",
"aten::ne.enum",
"aten::dequantize.tensor",
"aten::dequantize.any",
"aten::add.str",
"aten::eq.int",
"aten::eq.float",
"aten::eq.int_float",
"aten::eq.float_int",
"aten::eq",
"aten::eq.str",
"aten::ne.int",
"aten::ne.float",
"aten::ne.int_float",
"aten::ne.float_int",
"aten::ne",
"aten::ne.str",
"aten::lt.int",
"aten::lt.float",
"aten::lt.int_float",
"aten::lt.float_int",
"aten::lt",
"aten::lt.str",
"aten::gt.int",
"aten::gt.float",
"aten::gt.int_float",
"aten::gt.float_int",
"aten::gt",
"aten::gt.str",
"aten::le.int",
"aten::le.float",
"aten::le.int_float",
"aten::le.float_int",
"aten::le",
"aten::le.str",
"aten::ge.int",
"aten::ge.float",
"aten::ge.int_float",
"aten::ge.float_int",
"aten::ge",
"aten::ge.str",
"aten::add.int",
"aten::add.float",
"aten::add.int_float",
"aten::add.float_int",
"aten::add",
"aten::sub.int",
"aten::sub.float",
"aten::sub.int_float",
"aten::sub.float_int",
"aten::sub",
"aten::mul.int",
"aten::mul.float",
"aten::mul.int_float",
"aten::mul.float_int",
"aten::mul",
"aten::__and__.bool",
"aten::__or__.bool",
"aten::__xor__.bool",
"aten::floor.int",
"aten::floor.float",
"aten::floor.Scalar",
"aten::ceil.int",
"aten::ceil.float",
"aten::ceil.Scalar",
"aten::neg.int",
"aten::neg.float",
"aten::neg.Scalar",
"aten::exp.int",
"aten::exp.float",
"aten::exp.Scalar",
"aten::remainder.int",
"aten::remainder.float",
"aten::remainder.int_float",
"aten::remainder.float_int",
"aten::remainder",
"aten::div.int",
"aten::div.float",
"aten::div",
"aten::floordiv.int",
"aten::floordiv.float",
"aten::floordiv.int_float",
"aten::floordiv.float_int",
"aten::floordiv",
"aten::pow.int",
"aten::pow.float",
"aten::pow.int_float",
"aten::pow.float_int",
"aten::pow.Scalar_Scalar",
"aten::pow.int_to_int",
"prim::min.int",
"prim::min.float",
"prim::min.int_float",
"prim::min.float_int",
"prim::min",
"prim::max.int",
"prim::max.float",
"prim::max.int_float",
"prim::max.float_int",
"prim::max",
"prim::type",
"aten::len.Tensor",
"aten::ord",
"aten::lower",
"aten::__contains__.str_list",
"aten::len.str",
"aten::__getitem__.str",
"aten::copy_.Tensor",
"aten::copy_.int",
"aten::copy_.float",
"aten::backward",
"aten::index.Tensor_hacked_twin",
"aten::_index_put_impl_.hacked_twin",
"aten::index_put_.hacked_twin",
"aten::index_put.hacked_twin",
"aten::to.prim_Device",
"aten::to.prim_dtype",
"prim::is_cuda",
"prim::data",
"prim::min.int_list",
"prim::max.int_list",
"prim::min.self_int",
"prim::max.self_int",
"prim::min.float_list",
"prim::max.float_list",
"prim::min.self_float",
"prim::max.self_float",
"prim::min.bool_list",
"prim::max.bool_list",
"prim::min.self_bool",
"prim::max.self_bool",
"aten::len.Dict_str",
"aten::keys.str",
"aten::values.str",
"aten::__getitem__.Dict_str",
"aten::get.str",
"aten::get.default_str",
"aten::setdefault.str",
"aten::Delete.Dict_str",
"aten::pop.Dict_str",
"aten::pop.Dict_default_str",
"aten::popitem.str",
"aten::clear.str",
"aten::update.str",
"aten::items.str",
"aten::copy.Dict_str",
"aten::__contains__.str",
"aten::_set_item.str",
"aten::dict.str",
"aten::len.Dict_int",
"aten::keys.int",
"aten::values.int",
"aten::__getitem__.Dict_int",
"aten::get.int",
"aten::get.default_int",
"aten::setdefault.int",
"aten::Delete.Dict_int",
"aten::pop.Dict_int",
"aten::pop.Dict_default_int",
"aten::popitem.int",
"aten::clear.int",
"aten::update.int",
"aten::items.int",
"aten::copy.Dict_int",
"aten::__contains__.int",
"aten::_set_item.int",
"aten::dict.int",
"aten::len.Dict_bool",
"aten::keys.bool",
"aten::values.bool",
"aten::__getitem__.Dict_bool",
"aten::get.bool",
"aten::get.default_bool",
"aten::setdefault.bool",
"aten::Delete.Dict_bool",
"aten::pop.Dict_bool",
"aten::pop.Dict_default_bool",
"aten::popitem.bool",
"aten::clear.bool",
"aten::update.bool",
"aten::items.bool",
"aten::copy.Dict_bool",
"aten::__contains__.bool",
"aten::_set_item.bool",
"aten::dict.bool",
"aten::len.Dict_float",
"aten::keys.float",
"aten::values.float",
"aten::__getitem__.Dict_float",
"aten::get.float",
"aten::get.default_float",
"aten::setdefault.float",
"aten::Delete.Dict_float",
"aten::pop.Dict_float",
"aten::pop.Dict_default_float",
"aten::popitem.float",
"aten::clear.float",
"aten::update.float",
"aten::items.float",
"aten::copy.Dict_float",
"aten::__contains__.float",
"aten::_set_item.float",
"aten::dict.float",
"aten::len.Dict_Tensor",
"aten::keys.Tensor",
"aten::values.Tensor",
"aten::__getitem__.Dict_Tensor",
"aten::get.Tensor",
"aten::get.default_Tensor",
"aten::setdefault.Tensor",
"aten::Delete.Dict_Tensor",
"aten::pop.Dict_Tensor",
"aten::pop.Dict_default_Tensor",
"aten::popitem.Tensor",
"aten::clear.Tensor",
"aten::update.Tensor",
"aten::items.Tensor",
"aten::copy.Dict_Tensor",
"aten::__contains__.Tensor",
"aten::_set_item.Tensor",
"aten::dict.Tensor",
"aten::__round_to_zero_floordiv.int",
"aten::mathremainder.int",
"aten::mathremainder.float",
"aten::mathremainder.int_float",
"aten::mathremainder.float_int",
"aten::mathremainder",
"aten::__and__.int",
"aten::__or__.int",
"aten::__xor__.int",
"aten::__lshift__.int",
"aten::__rshift__.int",
"aten::round.int",
"aten::round.float",
"aten::round.Scalar",
"aten::log.int",
"aten::log.float",
"aten::log.Scalar",
"aten::log.int_int",
"aten::log.float_float",
"aten::log.int_float",
"aten::log.float_int",
"aten::log.Scalar_Scalar",
"aten::log1p.int",
"aten::log1p.float",
"aten::log1p.Scalar",
"aten::log10.int",
"aten::log10.float",
"aten::log10.Scalar",
"aten::sqrt.int",
"aten::sqrt.float",
"aten::sqrt.Scalar",
"aten::acos.int",
"aten::acos.float",
"aten::acos.Scalar",
"aten::asin.int",
"aten::asin.float",
"aten::asin.Scalar",
"aten::atan.int",
"aten::atan.float",
"aten::atan.Scalar",
"aten::atan2.int",
"aten::atan2.float",
"aten::atan2.int_float",
"aten::atan2.float_int",
"aten::atan2.Scalar_Scalar",
"aten::cos.int",
"aten::cos.float",
"aten::cos.Scalar",
"aten::sin.int",
"aten::sin.float",
"aten::sin.Scalar",
"aten::tan.int",
"aten::tan.float",
"aten::tan.Scalar",
"aten::asinh.int",
"aten::asinh.float",
"aten::asinh.Scalar",
"aten::atanh.int",
"aten::atanh.float",
"aten::atanh.Scalar",
"aten::acosh.int",
"aten::acosh.float",
"aten::acosh.Scalar",
"aten::sinh.int",
"aten::sinh.float",
"aten::sinh.Scalar",
"aten::cosh.int",
"aten::cosh.float",
"aten::cosh.Scalar",
"aten::tanh.int",
"aten::tanh.float",
"aten::tanh.Scalar",
"aten::degrees.int",
"aten::degrees.float",
"aten::degrees.Scalar",
"aten::radians.int",
"aten::radians.float",
"aten::radians.Scalar",
"aten::fmod.int",
"aten::fmod.float",
"aten::fmod.int_float",
"aten::fmod.float_int",
"aten::fmod",
"aten::factorial.int",
"aten::isnan.float",
"aten::isfinite.float",
"aten::isinf.float",
"aten::gamma.int",
"aten::gamma.float",
"aten::gamma.Scalar",
"aten::erf.int",
"aten::erf.float",
"aten::erf.Scalar",
"aten::erfc.int",
"aten::erfc.float",
"aten::erfc.Scalar",
"aten::expm1.int",
"aten::expm1.float",
"aten::expm1.Scalar",
"aten::fabs.int",
"aten::fabs.float",
"aten::fabs.Scalar",
"aten::lgamma.int",
"aten::lgamma.float",
"aten::lgamma.Scalar",
"prim::abs.int",
"prim::abs.float",
"prim::abs.Scalar",
"aten::gcd.int",
"aten::copysign.int",
"aten::copysign.float",
"aten::copysign.int_float",
"aten::copysign.float_int",
"aten::copysign",
"aten::split",
"aten::tensor.float",
"aten::as_tensor.float",
"aten::tensor.int",
"aten::as_tensor.int",
"aten::tensor.bool",
"aten::as_tensor.bool",
"aten::_infer_size",
"aten::_no_grad_embedding_renorm_",
"aten::tensor",
"aten::as_tensor",
"aten::as_tensor.list",
"aten::_pack_sequence",
"aten::_get_tracing_state",
"aten::is_scripting",
"aten::_no_grad_uniform_",
"aten::_no_grad_normal_",
"aten::_no_grad_fill_",
"aten::_no_grad_zero_",
]
PT_BASE_OPS = [
"aten::_coalesced_",
"aten::_copy_from",
"aten::_empty_affine_quantized",
"aten::_empty_per_channel_affine_quantized",
"aten::_indices",
"aten::_nnz",
"aten::_values",
"aten::add",
"aten::add_",
"aten::arange",
"aten::as_strided",
"aten::as_strided_",
"aten::cat",
"aten::clone",
"aten::coalesce",
"aten::contiguous",
"aten::copy_",
"aten::copy_sparse_to_sparse_",
"aten::dense_dim",
"aten::dequantize",
"aten::div",
"aten::div_",
"aten::empty",
"aten::empty_like",
"aten::empty_strided",
"aten::eq",
"aten::equal",
"aten::expand",
"aten::fill_",
"aten::is_coalesced",
"aten::is_complex",
"aten::is_floating_point",
"aten::is_leaf",
"aten::is_nonzero",
"aten::item",
"aten::max",
"aten::min",
"aten::mul",
"aten::mul_",
"aten::narrow",
"aten::ne",
"aten::permute",
"aten::q_per_channel_axis",
"aten::q_per_channel_scales",
"aten::q_per_channel_zero_points",
"aten::q_scale",
"aten::q_zero_point",
"aten::qscheme",
"aten::quantize_per_tensor",
"aten::reshape",
"aten::_reshape_alias",
"aten::resize_",
"aten::resize_as_",
"aten::scalar_tensor",
"aten::select",
"aten::set_",
"aten::size",
"aten::slice",
"aten::sparse_dim",
"aten::sparse_resize_and_clear_",
"aten::squeeze",
"aten::squeeze_",
"aten::stride",
"aten::sub",
"aten::sub_",
"aten::sum",
"aten::t",
"aten::to",
"aten::_to_copy",
"aten::unsqueeze",
"aten::view",
"aten::zero_",
"aten::zeros",
"aten::zeros_like",
]