[gen_operators_yaml] add arguments to control include_all_overloads (#108396)

Summary:
In SelectiveBuildOperator, we can specify the argument `include_all_overloads`. If True, all overloaded operators (for example, `aten::to.dtype_layout` and `aten::to.prim_Device` are considered overloads of `aten::to`) will be built and linked into the final binary. This can significantly increase the final binary size, which could be a deal breaker for on-device deployment.

In this diff, we make backward-compatible changes to add the new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set the `include_all_overloads` flag to False for static root ops and closure ops, respectively, and rely on the code analyzer to decide which overloaded operators are actually used.

Test Plan:
- unit test
```
buck test //xplat/caffe2/tools:gen_operators_yaml_test
```
- See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes.
- See detailed document: https://fburl.com/gdoc/mc93h6kb

Reviewed By: larryliu0820

Differential Revision: D48772302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396
Approved by: https://github.com/larryliu0820
This commit is contained in:
Zhicheng Yan
2023-09-02 17:37:36 +00:00
committed by PyTorch MergeBot
parent b9dfdc091b
commit 01b662bafe
3 changed files with 101 additions and 8 deletions

View File

@ -1,5 +1,5 @@
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:expect.bzl", "expect")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
@ -72,6 +72,13 @@ def pt_operator_library(
# was hand-crafted, which is not a supported workflow for traced ops.
yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
not_include_all_overloads_static_root_ops = kwargs.pop(
"not_include_all_overloads_static_root_ops",
False,
)
not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
fb_xplat_genrule(
name = name,
out = "model_operators.yaml",
@ -87,7 +94,9 @@ def pt_operator_library(
"{optionally_model_versions} " +
"{optionally_model_assets} " +
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
"{optionally_include_all_operators}" +
"{not_include_all_overloads_static_root_ops}" +
"{not_include_all_overloads_closure_ops}"
).format(
exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
rule_name = name,
@ -100,6 +109,8 @@ def pt_operator_library(
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
),
labels = labels + [
"pt_operator_library",

View File

@ -348,7 +348,7 @@ def fill_output(output: Dict[str, object], options: object):
{
"is_root_operator": True,
"is_used_for_training": False,
"include_all_overloads": True,
"include_all_overloads": not options.not_include_all_overloads_static_root_ops,
"debug_info": [options.model_name],
},
)
@ -362,7 +362,7 @@ def fill_output(output: Dict[str, object], options: object):
{
"is_root_operator": False,
"is_used_for_training": False,
"include_all_overloads": True,
"include_all_overloads": not options.not_include_all_overloads_closure_ops,
"debug_info": [options.model_name],
},
)
@ -489,7 +489,7 @@ def fill_output(output: Dict[str, object], options: object):
output["kernel_metadata"] = kernel_metadata
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--root-ops",
"--root_ops",
@ -567,8 +567,32 @@ def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
help="The name of pt_operator_library rule resulting in this generation",
required=True,
)
options = parser.parse_args()
return options
parser.add_argument(
"--not-include-all-overloads-static-root-ops",
"--not_include_all_overloads_static_root_ops",
action="store_true",
default=False,
help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine",
required=False,
)
parser.add_argument(
"--not-include-all-overloads-closure-ops",
"--not_include_all_overloads_closure_ops",
action="store_true",
default=False,
help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine",
required=False,
)
return parser
def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Parse the process command line with an already-configured parser.

    Args:
        parser: Argument parser whose arguments have already been registered.

    Returns:
        The namespace produced by ``parser.parse_args()``; because no explicit
        argument list is passed, argparse reads the arguments from ``sys.argv``.
    """
    parsed_options = parser.parse_args()
    return parsed_options
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Register this tool's arguments on ``parser`` and parse the command line.

    Args:
        parser: Parser to populate via ``add_arguments_parser``.

    Returns:
        The namespace returned by ``parse_options`` for the populated parser.
    """
    # Registration and parsing are separate steps; this wrapper simply chains
    # them so existing callers keep a single entry point.
    return parse_options(add_arguments_parser(parser))
def main(argv) -> None:

View File

@ -1,9 +1,45 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import argparse
import json
import unittest
from collections import defaultdict
from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
from unittest.mock import Mock, patch
from gen_operators_yaml import (
fill_output,
get_parser_options,
make_filter_from_options,
verify_all_specified_present,
)
def _mock_options():
    """Build a fresh ``argparse.Namespace`` standing in for parsed CLI options.

    Both ``not_include_all_overloads_*`` switches are set to True; every other
    field carries a fixed placeholder value.
    """
    return argparse.Namespace(
        root_ops="aten::add,aten::cat",
        training_root_ops=[],
        output_path="/tmp",
        dep_graph_yaml_path="dummy_pytorch_op_deps.yaml",
        model_name="test_model",
        model_versions=None,
        model_assets=None,
        model_backends=None,
        models_yaml_path=None,
        include_all_operators=False,
        rule_name="test_rule",
        not_include_all_overloads_static_root_ops=True,
        not_include_all_overloads_closure_ops=True,
    )
def _mock_load_op_dep_graph():
    """Return a small fake operator dependency graph as a plain ``dict``.

    Each of the two root ops maps to a set holding the op itself plus
    ``aten::as_strided_``.
    """
    root_ops = ("aten::add", "aten::cat")
    return {root_op: {root_op, "aten::as_strided_"} for root_op in root_ops}
class GenOperatorsYAMLTest(unittest.TestCase):
@ -186,3 +222,25 @@ class GenOperatorsYAMLTest(unittest.TestCase):
model_name="abcd",
new_style_rule=True,
)
@patch("gen_operators_yaml.parse_options", return_value=_mock_options())
@patch(
    "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
)
def test_fill_output_with_arguments_not_include_all_overloads(
    self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
):
    """Check that both ``not_include_all_overloads_*`` options reach the output.

    ``parse_options`` is patched to return the namespace from
    ``_mock_options()`` (both switches True) and ``load_op_dep_graph`` is
    patched to return the fake graph, so every operator entry written by
    ``fill_output`` is expected to have ``include_all_overloads`` set to False.
    """
    # Stacked @patch decorators are applied bottom-up: the decorator closest
    # to the function supplies the FIRST mock argument. The parameters are
    # therefore (load_op_dep_graph mock, parse_options mock), in that order.
    parser = argparse.ArgumentParser(description="Generate used operators YAML")
    options = get_parser_options(parser)
    # get_parser_options must go through the patched parse_options; otherwise
    # it would try to parse the test runner's real command line.
    mock_parse_options.assert_called_once()

    model_dict = {
        "model_name": options.model_name,
        "asset_info": {},
        "is_new_style_rule": False,
    }
    output = {"debug_info": [json.dumps(model_dict)]}

    fill_output(output, options)

    operators = output["operators"]
    # Guard against a vacuous pass: with no operators the loop below would
    # assert nothing at all.
    self.assertTrue(operators, "fill_output produced no operators")
    for op_name, op_val in operators.items():
        self.assertFalse(op_val["include_all_overloads"], msg=op_name)