mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: In SelectiveBuildOperator, we can specify argument `include_all_overloads`. If True, all overloaded operators (for example, `aten::to.dtype_layout`, `aten::to.prim_Device` are considered as overloaded operators of `aten::to`) will be built and linked to the final binary. This can significantly increase the final binary size, which could be a deal breaker for on-device deployment. In this diff, we make backward-compatible changes to add new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set the `include_all_overloads` flag to False for static root ops and closure ops, and rely on the code analyzer to decide which overloaded operators are actually used. Test Plan: - unit test ``` buck test //xplat/caffe2/tools:gen_operators_yaml_test ``` - See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes. - See detailed document: https://fburl.com/gdoc/mc93h6kb Reviewed By: larryliu0820 Differential Revision: D48772302 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396 Approved by: https://github.com/larryliu0820
247 lines
7.7 KiB
Python
247 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2004-present Facebook. All Rights Reserved.
|
|
|
|
import argparse
|
|
import json
|
|
import unittest
|
|
from collections import defaultdict
|
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
from gen_operators_yaml import (
|
|
fill_output,
|
|
get_parser_options,
|
|
make_filter_from_options,
|
|
verify_all_specified_present,
|
|
)
|
|
|
|
|
|
def _mock_options():
    """Return a fresh ``argparse.Namespace`` of canned command-line options.

    The namespace names two comma-separated root ops, leaves every
    model-selection field unset, and turns on both
    ``not_include_all_overloads_*`` flags.
    """
    return argparse.Namespace(
        root_ops="aten::add,aten::cat",
        training_root_ops=[],
        output_path="/tmp",
        dep_graph_yaml_path="dummy_pytorch_op_deps.yaml",
        model_name="test_model",
        model_versions=None,
        model_assets=None,
        model_backends=None,
        models_yaml_path=None,
        include_all_operators=False,
        rule_name="test_rule",
        not_include_all_overloads_static_root_ops=True,
        not_include_all_overloads_closure_ops=True,
    )
|
|
|
|
|
|
def _mock_load_op_dep_graph():
    """Return a small canned op dependency graph as a plain ``dict``.

    Each of the two root ops maps to a set containing itself and
    ``aten::as_strided_``.
    """
    graph = defaultdict(set)
    for root_op in ("aten::add", "aten::cat"):
        graph[root_op].update((root_op, "aten::as_strided_"))
    # Hand back a regular dict so lookups of unknown ops do not create entries.
    return dict(graph)
|
|
|
|
|
|
class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for the helpers imported from ``gen_operators_yaml``.

    Covers ``make_filter_from_options``, ``verify_all_specified_present`` and
    ``fill_output`` (the latter with its option parsing and dependency-graph
    loading patched out).
    """

    def setUp(self):
        # No shared fixtures: every test builds its own inputs.
        pass

    def test_filter_creation(self):
        """The filter built for model "abc", versions 100/101 keeps two configs."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        filtered_configs = list(filter(filter_func, config))
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # still runs when Python is started with -O.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )

    def test_verification_success(self):
        """Verification passes when every requested asset/version is present."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except Exception as exc:
            # Report the underlying exception so a failure can be diagnosed.
            self.fail(
                "expected verify_all_specified_present to succeed instead it "
                f"raised an exception: {exc!r}"
            )

    def test_verification_fail(self):
        """Verification raises ``RuntimeError`` for an unmatched asset, version or name."""
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"

        # Test bad asset: "asset-3" does not appear in any config entry.
        filter_func_bad_asset = make_filter_from_options(
            model_name=good_name,
            model_versions=good_versions,
            model_assets=["asset-1", "asset-2", "asset-3"],
            model_backends=None,
        )
        filtered_configs_asset = list(filter(filter_func_bad_asset, config))
        with self.assertRaises(RuntimeError):
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2", "asset-3"],
                model_versions=good_versions,
                selected_models_yaml=filtered_configs_asset,
                rule_name="test",
                model_name=good_name,
                new_style_rule=True,
            )

        # Test bad version: "102" does not appear in any config entry.
        filter_func_bad_version = make_filter_from_options(
            model_name=good_name,
            model_versions=["100", "101", "102"],
            model_assets=good_assets,
            model_backends=None,
        )
        filtered_configs_version = list(filter(filter_func_bad_version, config))
        with self.assertRaises(RuntimeError):
            verify_all_specified_present(
                model_assets=good_assets,
                model_versions=["100", "101", "102"],
                selected_models_yaml=filtered_configs_version,
                rule_name="test",
                model_name=good_name,
                new_style_rule=True,
            )

        # Test bad name: no config entry is named "abcd".
        filter_func_bad_name = make_filter_from_options(
            model_name="abcd",
            model_versions=good_versions,
            model_assets=good_assets,
            model_backends=None,
        )
        filtered_configs_name = list(filter(filter_func_bad_name, config))
        with self.assertRaises(RuntimeError):
            verify_all_specified_present(
                model_assets=good_assets,
                model_versions=good_versions,
                selected_models_yaml=filtered_configs_name,
                rule_name="test",
                model_name="abcd",
                new_style_rule=True,
            )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
    ):
        """Every operator emitted by ``fill_output`` has ``include_all_overloads`` False.

        Stacked ``patch`` decorators inject their mocks bottom-up, so the mock
        for ``load_op_dep_graph`` (the decorator nearest the function) arrives
        first and the mock for ``parse_options`` second.
        """
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        # NOTE(review): presumably get_parser_options goes through the patched
        # parse_options and so yields the _mock_options() namespace; confirm.
        options = get_parser_options(parser)

        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        # Seed the output with one JSON-encoded debug_info entry before filling.
        output = {"debug_info": [json.dumps(model_dict)]}

        fill_output(output, options)

        for op_val in output["operators"].values():
            self.assertFalse(op_val["include_all_overloads"])
|