mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[gen_operators_yaml] add arguments to control include_all_overloads (#108396)
Summary: In SelectiveBuildOperator, we can specify argument `include_all_overloads`. If True, all overloaded operators (for example, `aten::to.dtype_layout`, `aten::to.prim_Device"` are considered as overloaded operators of `aten::to`), will be built and linked to the final binary. This can significantly increases the final binary size, which could be a deal breaker for on-device deployment. In this diff, we make back-compatible changes to add new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set `include_all_overloads` flag to False for static root ops and closure ops, and rely on code analyzer to decide the actual used overloaded operator. Test Plan: - unit test ``` buck test //xplat/caffe2/tools:gen_operators_yaml_test ``` - See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes. - See detailed document: https://fburl.com/gdoc/mc93h6kb Reviewed By: larryliu0820 Differential Revision: D48772302 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396 Approved by: https://github.com/larryliu0820
This commit is contained in:
committed by
PyTorch MergeBot
parent
b9dfdc091b
commit
01b662bafe
15
pt_ops.bzl
15
pt_ops.bzl
@ -1,5 +1,5 @@
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:expect.bzl", "expect")
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
|
||||
|
||||
@ -72,6 +72,13 @@ def pt_operator_library(
|
||||
# was hand-crafted which is not a support workflow for traced ops.
|
||||
yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
|
||||
|
||||
not_include_all_overloads_static_root_ops = kwargs.pop(
|
||||
"not_include_all_overloads_static_root_ops",
|
||||
False,
|
||||
)
|
||||
|
||||
not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
|
||||
|
||||
fb_xplat_genrule(
|
||||
name = name,
|
||||
out = "model_operators.yaml",
|
||||
@ -87,7 +94,9 @@ def pt_operator_library(
|
||||
"{optionally_model_versions} " +
|
||||
"{optionally_model_assets} " +
|
||||
"{optionally_model_traced_backends} " +
|
||||
"{optionally_include_all_operators}"
|
||||
"{optionally_include_all_operators}" +
|
||||
"{not_include_all_overloads_static_root_ops}" +
|
||||
"{not_include_all_overloads_closure_ops}"
|
||||
).format(
|
||||
exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
|
||||
rule_name = name,
|
||||
@ -100,6 +109,8 @@ def pt_operator_library(
|
||||
optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
|
||||
optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
|
||||
optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
|
||||
not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
|
||||
not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
|
||||
),
|
||||
labels = labels + [
|
||||
"pt_operator_library",
|
||||
|
@ -348,7 +348,7 @@ def fill_output(output: Dict[str, object], options: object):
|
||||
{
|
||||
"is_root_operator": True,
|
||||
"is_used_for_training": False,
|
||||
"include_all_overloads": True,
|
||||
"include_all_overloads": not options.not_include_all_overloads_static_root_ops,
|
||||
"debug_info": [options.model_name],
|
||||
},
|
||||
)
|
||||
@ -362,7 +362,7 @@ def fill_output(output: Dict[str, object], options: object):
|
||||
{
|
||||
"is_root_operator": False,
|
||||
"is_used_for_training": False,
|
||||
"include_all_overloads": True,
|
||||
"include_all_overloads": not options.not_include_all_overloads_closure_ops,
|
||||
"debug_info": [options.model_name],
|
||||
},
|
||||
)
|
||||
@ -489,7 +489,7 @@ def fill_output(output: Dict[str, object], options: object):
|
||||
output["kernel_metadata"] = kernel_metadata
|
||||
|
||||
|
||||
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--root-ops",
|
||||
"--root_ops",
|
||||
@ -567,8 +567,32 @@ def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
help="The name of pt_operator_library rule resulting in this generation",
|
||||
required=True,
|
||||
)
|
||||
options = parser.parse_args()
|
||||
return options
|
||||
parser.add_argument(
|
||||
"--not-include-all-overloads-static-root-ops",
|
||||
"--not_include_all_overloads_static_root_ops",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--not-include-all-overloads-closure-ops",
|
||||
"--not_include_all_overloads_closure_ops",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine",
|
||||
required=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
parser = add_arguments_parser(parser)
|
||||
return parse_options(parser)
|
||||
|
||||
|
||||
def main(argv) -> None:
|
||||
|
@ -1,9 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from gen_operators_yaml import (
|
||||
fill_output,
|
||||
get_parser_options,
|
||||
make_filter_from_options,
|
||||
verify_all_specified_present,
|
||||
)
|
||||
|
||||
|
||||
def _mock_options():
|
||||
options = argparse.Namespace()
|
||||
options.root_ops = "aten::add,aten::cat"
|
||||
options.training_root_ops = []
|
||||
options.output_path = "/tmp"
|
||||
options.dep_graph_yaml_path = "dummy_pytorch_op_deps.yaml"
|
||||
options.model_name = "test_model"
|
||||
options.model_versions = None
|
||||
options.model_assets = None
|
||||
options.model_backends = None
|
||||
options.models_yaml_path = None
|
||||
options.include_all_operators = False
|
||||
options.rule_name = "test_rule"
|
||||
options.not_include_all_overloads_static_root_ops = True
|
||||
options.not_include_all_overloads_closure_ops = True
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _mock_load_op_dep_graph():
|
||||
result = defaultdict(set)
|
||||
result["aten::add"] = {"aten::add", "aten::as_strided_"}
|
||||
result["aten::cat"] = {"aten::cat", "aten::as_strided_"}
|
||||
return dict(result)
|
||||
|
||||
|
||||
class GenOperatorsYAMLTest(unittest.TestCase):
|
||||
@ -186,3 +222,25 @@ class GenOperatorsYAMLTest(unittest.TestCase):
|
||||
model_name="abcd",
|
||||
new_style_rule=True,
|
||||
)
|
||||
|
||||
@patch("gen_operators_yaml.parse_options", return_value=_mock_options())
|
||||
@patch(
|
||||
"gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
|
||||
)
|
||||
def test_fill_output_with_arguments_not_include_all_overloads(
|
||||
self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
|
||||
):
|
||||
parser = argparse.ArgumentParser(description="Generate used operators YAML")
|
||||
options = get_parser_options(parser)
|
||||
|
||||
model_dict = {
|
||||
"model_name": options.model_name,
|
||||
"asset_info": {},
|
||||
"is_new_style_rule": False,
|
||||
}
|
||||
output = {"debug_info": [json.dumps(model_dict)]}
|
||||
|
||||
fill_output(output, options)
|
||||
|
||||
for op_val in output["operators"].values():
|
||||
self.assertFalse(op_val["include_all_overloads"])
|
||||
|
Reference in New Issue
Block a user