[mobile] Fix lightweight dispatch OOM error by introducing selective build

This PR introduces selective build to lightweight dispatch CI job. By doing so we can't run the `test_lite_intepreter_runtime` test suite anymore because it requires some other operators.

From now on, if we are adding a new unit test in `test_codegen_unboxing`, we will have to export the operators for the unit test model and add them into `lightweight_dispatch_ops.yaml`. This can be automated by introducing tracing based selective build, but that's for next PR to do.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78983

Approved by: https://github.com/kit1980
This commit is contained in:
PyTorch MergeBot
2022-06-07 22:19:30 +00:00
parent 99ffeff949
commit 272bdb1442
7 changed files with 89 additions and 25 deletions

View File

@ -2,6 +2,8 @@
import argparse
import os
import pathlib
import yaml
from dataclasses import dataclass
from torchgen.api import cpp
from torchgen.api import unboxing
@ -155,6 +157,9 @@ def gen_unboxing(
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
return fn.root_name
selected_op_num: int = len(selector.operators)
# a best practice threshold of operators to enable sharding
sharding_threshold: int = 100
cpu_fm.write_sharded(
"UnboxingFunctions.cpp",
native_functions,
@ -162,7 +167,7 @@ def gen_unboxing(
env_callable=lambda fn: {
"definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
},
num_shards=5,
num_shards=1 if selected_op_num < sharding_threshold else 5,
sharded_keys={"definitions"},
)
cpu_fm.write(
@ -183,7 +188,7 @@ def gen_unboxing(
env_callable=lambda fn: {
"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
},
num_shards=10,
num_shards=1 if selected_op_num < sharding_threshold else 10,
sharded_keys={"unboxed_ops"},
)
@ -218,17 +223,19 @@ def main() -> None:
"The operator names also contain the namespace prefix (e.g. aten::)",
)
parser.add_argument(
"--op_registration_allowlist",
nargs="*",
help="filter op registrations by the allowlist (if set); "
"--TEST_ONLY_op_registration_allowlist_yaml_path",
help="Provide a path to the operator selection (for custom build) YAML "
"which contains a list of operators. It is to serve testing purpose and "
"each item is `namespace`::`operator name` without overload name; "
"e.g.: aten::empty aten::conv2d ...",
)
options = parser.parse_args()
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
op_registration_allowlist = yaml.safe_load(f)
selector = get_custom_build_selector(
options.op_registration_allowlist,
op_registration_allowlist,
options.op_selection_yaml_path,
)