[gen_operators_yaml] add arguments to control include_all_overloads (#108396)

Summary:
In SelectiveBuildOperator, we can specify argument `include_all_overloads`. If True, all overloaded operators (for example, `aten::to.dtype_layout`, `aten::to.prim_Device"` are considered as overloaded operators of `aten::to`), will be built and linked to the final binary. This can significantly increases the final binary size, which could be a deal breaker for on-device deployment.

In this diff, we make back-compatible changes to add new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set `include_all_overloads` flag to False for static root ops and closure ops, and rely on code analyzer to decide the actual used overloaded operator.

Test Plan:
- unit test
```
buck test //xplat/caffe2/tools:gen_operators_yaml_test
```
- See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes.
- See detailed document: https://fburl.com/gdoc/mc93h6kb

Reviewed By: larryliu0820

Differential Revision: D48772302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396
Approved by: https://github.com/larryliu0820
This commit is contained in:
Zhicheng Yan
2023-09-02 17:37:36 +00:00
committed by PyTorch MergeBot
parent b9dfdc091b
commit 01b662bafe
3 changed files with 101 additions and 8 deletions

View File

@ -1,5 +1,5 @@
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:expect.bzl", "expect")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
@ -72,6 +72,13 @@ def pt_operator_library(
# was hand-crafted which is not a support workflow for traced ops.
yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
not_include_all_overloads_static_root_ops = kwargs.pop(
"not_include_all_overloads_static_root_ops",
False,
)
not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
@ -87,7 +94,9 @@ def pt_operator_library(
"{optionally_model_versions} " +
"{optionally_model_assets} " +
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
"{optionally_include_all_operators}" +
"{not_include_all_overloads_static_root_ops}" +
"{not_include_all_overloads_closure_ops}"
).format(
exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
rule_name = name,
@ -100,6 +109,8 @@ def pt_operator_library(
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
),
labels = labels + [
"pt_operator_library",