[4] move pt_operator_library to shared BUCK file (#80170)

Summary:
Move pt_operator_library to pt_ops.bzl and make it shared with OSS BUCK build

This will replace D36912042. I will update all load statements in future diffs.

Test Plan: sandcaslte, OSS CI

Differential Revision: D37390060

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170
Approved by: https://github.com/JacobSzwejbka
This commit is contained in:
Linbin Yu
2022-06-24 21:51:20 +00:00
committed by PyTorch MergeBot
parent bda04e9f5e
commit e98e7fe428
5 changed files with 126 additions and 208 deletions

View File

@ -1,3 +1,114 @@
load("//tools/build_defs:expect.bzl", "expect")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
# @lint-ignore BUCKRESTRICTEDSYNTAX
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
USED_PT_BACKENDS = [
"CPU",
"QuantizedCPU",
"SparseCPU", # brings ~20 kb size regression
]
def pt_operator_library(
name,
ops = [],
exported_deps = [],
check_decl = True,
train = False,
model = None,
include_all_operators = False,
**kwargs):
(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
name,
model,
)
ops = [op.strip() for op in ops]
# If ops are specified, then we are in static selective build mode, so we append
# base ops to this list to avoid additional special case logic in subsequent code.
if len(ops) > 0:
ops.extend(PT_BASE_OPS)
labels = kwargs.pop("labels", [])
visibility = kwargs.pop("visibility", ["PUBLIC"])
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
cmd = (
"$(exe {root}:gen_operators_yaml) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
"--output_path \"${{OUT}}\" " +
"--model_name {model_name} " +
"--dep_graph_yaml_path {dep_graph_yaml} " +
"--models_yaml_path {models_yaml} " +
"{optionally_model_versions} " +
"{optionally_model_assets} " +
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
root = "//" if IS_OSS else "//xplat/caffe2",
rule_name = name,
model_name = model_name,
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
),
labels = labels + [
"pt_operator_library",
"supermodule:android/default/pytorch",
"supermodule:ios/default/public.pytorch",
] + (["pt_train_operator_library"] if train else []),
visibility = visibility,
**kwargs
)
def validate_and_extract_model_information(name, model):
model_name = name
model_versions = None
model_assets = None
model_traced_backends = None
if model != None:
model_name = model.get("name")
expect(model_name != None, "Expected Model Name to be present")
model_versions = model.get("versions")
expect(is_list(model_versions), "Expected model versions to be a list of string")
for ver in model_versions or []:
expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
model_assets = model.get("assets")
expect(
model_assets == None or is_list(model_assets),
"Expected model assets to be a list of string if specified",
)
for asset_name in model_assets or []:
expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
model_traced_backends = model.get("traced_backends")
expect(
model_traced_backends == None or is_list(model_traced_backends),
"Expected model traced backends to be a list of string if specified",
)
if model_traced_backends != None:
for backend in model_traced_backends:
expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
expect(
backend in USED_PT_BACKENDS,
"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
)
return (model_name, model_versions, model_assets, model_traced_backends)
# This file keeps a list of PyTorch operators used by any targets in
# @fbsource//xplat/...
# The purpose of the list is to avoid generating large number of unused