Files
pytorch/tools/test/gen_operators_yaml_test.py
Zhicheng Yan 01b662bafe [gen_operators_yaml] add arguments to control include_all_overloads (#108396)
Summary:
In SelectiveBuildOperator, we can specify the argument `include_all_overloads`. If True, all overloaded operators are built and linked into the final binary (for example, `aten::to.dtype_layout` and `aten::to.prim_Device` are both considered overloads of `aten::to`). This can significantly increase the final binary size, which could be a deal breaker for on-device deployment.

In this diff, we make backward-compatible changes that add the new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set the `include_all_overloads` flag to False for static root ops and closure ops, respectively, and rely on the code analyzer to determine which overloaded operators are actually used.

Test Plan:
- unit test
```
buck test //xplat/caffe2/tools:gen_operators_yaml_test
```
- See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes.
- See detailed document: https://fburl.com/gdoc/mc93h6kb

Reviewed By: larryliu0820

Differential Revision: D48772302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396
Approved by: https://github.com/larryliu0820
2023-09-02 17:37:36 +00:00

247 lines
7.7 KiB
Python

#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import argparse
import json
import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
from gen_operators_yaml import (
fill_output,
get_parser_options,
make_filter_from_options,
verify_all_specified_present,
)
def _mock_options():
    """Return a fake ``argparse.Namespace`` standing in for parsed CLI options.

    The attribute names are presumably the ones ``gen_operators_yaml`` reads
    from its parsed options (verify against its parser if fields are added).
    Both ``not_include_all_overloads_*`` flags are enabled so the
    ``fill_output`` test can check that ``include_all_overloads`` comes out
    False for every emitted operator.
    """
    return argparse.Namespace(
        root_ops="aten::add,aten::cat",
        training_root_ops=[],
        output_path="/tmp",
        dep_graph_yaml_path="dummy_pytorch_op_deps.yaml",
        model_name="test_model",
        model_versions=None,
        model_assets=None,
        model_backends=None,
        models_yaml_path=None,
        include_all_operators=False,
        rule_name="test_rule",
        not_include_all_overloads_static_root_ops=True,
        not_include_all_overloads_closure_ops=True,
    )
def _mock_load_op_dep_graph():
    """Return a tiny fake operator dependency graph.

    Maps each root op to the set containing itself plus
    ``aten::as_strided_``, returned as a plain ``dict`` (not a defaultdict).
    """
    graph = defaultdict(set)
    for root_op in ("aten::add", "aten::cat"):
        graph[root_op] = {root_op, "aten::as_strided_"}
    return dict(graph)
class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for the model filter, model verification and ``fill_output``
    helpers imported from ``gen_operators_yaml``."""

    @staticmethod
    def _two_asset_config():
        """Return a fresh model config with "abc" at v100/asset-1 and v101/asset-2."""
        return [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

    def test_filter_creation(self):
        """The filter keeps only entries matching the model name and versions."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        # assertEqual (unlike a bare assert) is not stripped under `python -O`.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )
        # The wrong-version ("102") and wrong-name ("abcd") entries must be
        # the ones dropped; filter() preserves input order.
        self.assertEqual(
            [entry["model"]["version"] for entry in filtered_configs], [100, 101]
        )

    def test_verification_success(self):
        """Verification passes when every requested asset and version is present."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        filtered_configs = list(filter(filter_func, self._two_asset_config()))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except Exception as e:
            # Surface the underlying error instead of discarding it.
            self.fail(
                "expected verify_all_specified_present to succeed instead it "
                f"raised an exception: {e!r}"
            )

    def test_verification_fail(self):
        """Verification raises RuntimeError when a requested asset, version or
        model name has no matching config entry."""
        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"
        # label -> (model_name, model_versions, model_assets); each case
        # perturbs exactly one field relative to the config.
        cases = {
            "bad asset": (good_name, good_versions, ["asset-1", "asset-2", "asset-3"]),
            "bad version": (good_name, ["100", "101", "102"], good_assets),
            "bad name": ("abcd", good_versions, good_assets),
        }
        for label, (name, versions, assets) in cases.items():
            # subTest reports every failing case rather than stopping at the first.
            with self.subTest(label):
                filter_func = make_filter_from_options(
                    model_name=name,
                    model_versions=versions,
                    model_assets=assets,
                    model_backends=None,
                )
                filtered_configs = list(
                    filter(filter_func, self._two_asset_config())
                )
                with self.assertRaises(RuntimeError):
                    verify_all_specified_present(
                        model_assets=assets,
                        model_versions=versions,
                        selected_models_yaml=filtered_configs,
                        rule_name="test",
                        model_name=name,
                        new_style_rule=True,
                    )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
    ):
        """With both not-include-all-overloads flags set, no emitted operator
        has ``include_all_overloads`` enabled."""
        # Stacked @patch decorators apply bottom-up: the first mock argument is
        # the load_op_dep_graph patch, the second the parse_options patch.
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        options = get_parser_options(parser)
        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        output = {"debug_info": [json.dumps(model_dict)]}
        fill_output(output, options)
        operators = output["operators"]
        # Guard against a vacuous pass: the loop below checks nothing if no
        # operators were emitted.
        self.assertTrue(operators, "fill_output produced no operators")
        for op_val in operators.values():
            self.assertFalse(op_val["include_all_overloads"])