mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[4] move pt_operator_library to shared BUCK file (#80170)
Summary: Move pt_operator_library to pt_ops.bzl and make it shared with OSS BUCK build This will replace D36912042. I will update all load statements in future diffs. Test Plan: sandcastle, OSS CI Differential Revision: D37390060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170 Approved by: https://github.com/JacobSzwejbka
This commit is contained in:
committed by
PyTorch MergeBot
parent
bda04e9f5e
commit
e98e7fe428
111
pt_ops.bzl
111
pt_ops.bzl
@ -1,3 +1,114 @@
|
||||
load("//tools/build_defs:expect.bzl", "expect")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
|
||||
|
||||
# @lint-ignore BUCKRESTRICTEDSYNTAX
# Distinguishes the open-source BUCK build from Meta's internal one.
# `read_config` is a Buck builtin reading the `[pt] is_oss` config value;
# it defaults to "0" (internal) when the config key is unset.
IS_OSS = read_config("pt", "is_oss", "0") == "1"  # True for OSS BUCK build, and False for internal BUCK build
|
||||
|
||||
# PyTorch backends that a model's "traced_backends" entry may name;
# validate_and_extract_model_information rejects anything not in this list.
USED_PT_BACKENDS = [
    "CPU",
    "QuantizedCPU",
    "SparseCPU",  # brings ~20 kb size regression
]
|
||||
|
||||
def pt_operator_library(
        name,
        ops = [],
        exported_deps = [],
        check_decl = True,
        train = False,
        model = None,
        include_all_operators = False,
        **kwargs):
    """Declare the set of PyTorch operators a target needs (selective build).

    Emits a `fb_xplat_genrule` named `name` that runs `gen_operators_yaml`
    to produce a `model_operators.yaml` describing the root operators,
    model versions/assets, and traced backends for this target.

    Args:
        name: Name of the generated genrule target.
        ops: Root operator names for static selective build; each entry is
            stripped of surrounding whitespace.
        exported_deps: NOTE(review): not referenced in this body — presumably
            consumed by a caller or a later revision; confirm before removing.
        check_decl: NOTE(review): not referenced in this body — confirm intent.
        train: If True, the ops are also passed as training root ops and the
            target is labeled "pt_train_operator_library".
        model: Optional dict validated by
            validate_and_extract_model_information (keys: "name", "versions",
            "assets", "traced_backends").
        include_all_operators: If True, passes --include_all_operators to the
            generator instead of a selective op list.
        **kwargs: Extra genrule arguments; "labels" and "visibility" are
            popped and handled specially below.
    """
    (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
        name,
        model,
    )

    ops = [op.strip() for op in ops]

    # If ops are specified, then we are in static selective build mode, so we append
    # base ops to this list to avoid additional special case logic in subsequent code.
    if len(ops) > 0:
        ops.extend(PT_BASE_OPS)

    # Pop these so they are not forwarded twice via **kwargs below.
    labels = kwargs.pop("labels", [])
    visibility = kwargs.pop("visibility", ["PUBLIC"])

    fb_xplat_genrule(
        name = name,
        out = "model_operators.yaml",
        # Build the gen_operators_yaml command line. `$(exe ...)` and
        # `$(location ...)` are Buck string macros resolved at build time;
        # `${{OUT}}` escapes to a literal ${OUT} after .format().
        cmd = (
            "$(exe {root}:gen_operators_yaml) " +
            "{optionally_root_ops} " +
            "{optionally_training_root_ops} " +
            "--rule_name {rule_name} " +
            "--output_path \"${{OUT}}\" " +
            "--model_name {model_name} " +
            "--dep_graph_yaml_path {dep_graph_yaml} " +
            "--models_yaml_path {models_yaml} " +
            "{optionally_model_versions} " +
            "{optionally_model_assets} " +
            "{optionally_model_traced_backends} " +
            "{optionally_include_all_operators}"
        ).format(
            # OSS builds address the generator from the repo root; internal
            # builds use the xplat cell. The dep-graph/model YAMLs only exist
            # internally, so OSS passes the literal string "none".
            root = "//" if IS_OSS else "//xplat/caffe2",
            rule_name = name,
            model_name = model_name,
            dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
            models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ",
            # Each "optionally_*" collapses to "" when its data is absent,
            # so the flag is omitted from the command entirely.
            optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
            optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
            optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
            optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
            optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
            optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
        ),
        labels = labels + [
            "pt_operator_library",
            "supermodule:android/default/pytorch",
            "supermodule:ios/default/public.pytorch",
        ] + (["pt_train_operator_library"] if train else []),
        visibility = visibility,
        **kwargs
    )
|
||||
|
||||
def validate_and_extract_model_information(name, model):
    """Validate the optional `model` dict and extract its fields.

    Returns a tuple `(model_name, model_versions, model_assets,
    model_traced_backends)`. When `model` is None, `model_name` falls back
    to `name` and every other field is None. Otherwise "name" is required,
    "versions" must be a list of strings, "assets" and "traced_backends"
    are optional lists of strings, and every traced backend must appear in
    USED_PT_BACKENDS. Validation failures are reported via `expect`.
    """
    # No model dict: nothing to validate, fall back to the target name.
    if model == None:
        return (name, None, None, None)

    model_name = model.get("name")
    expect(model_name != None, "Expected Model Name to be present")

    # "versions" is mandatory and must be a list of version strings.
    model_versions = model.get("versions")
    expect(is_list(model_versions), "Expected model versions to be a list of string")
    for ver in model_versions or []:
        expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))

    # "assets" is optional; when present it must be a list of strings.
    model_assets = model.get("assets")
    expect(
        model_assets == None or is_list(model_assets),
        "Expected model assets to be a list of string if specified",
    )
    for asset_name in model_assets or []:
        expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))

    # "traced_backends" is optional; entries must be strings drawn from
    # the known-backend allowlist.
    model_traced_backends = model.get("traced_backends")
    expect(
        model_traced_backends == None or is_list(model_traced_backends),
        "Expected model traced backends to be a list of string if specified",
    )
    for backend in model_traced_backends or []:
        expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
        expect(
            backend in USED_PT_BACKENDS,
            "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
        )

    return (model_name, model_versions, model_assets, model_traced_backends)
|
||||
|
||||
# This file keeps a list of PyTorch operators used by any targets in
|
||||
# @fbsource//xplat/...
|
||||
# The purpose of the list is to avoid generating large number of unused
|
||||
|
Reference in New Issue
Block a user