mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[mobile] Fix lightweight dispatch OOM error by introducing selective build
This PR introduces selective build to the lightweight dispatch CI job. As a consequence, we can no longer run the `test_lite_interpreter_runtime` test suite, because it requires some other operators. From now on, when adding a new unit test to `test_codegen_unboxing`, we have to export the operators used by the unit test model and add them to `lightweight_dispatch_ops.yaml`. This could be automated by introducing tracing-based selective build, but that is left for a follow-up PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78983 Approved by: https://github.com/kit1980
This commit is contained in:
@ -2,6 +2,8 @@
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api import unboxing
|
||||
@ -155,6 +157,9 @@ def gen_unboxing(
|
||||
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
|
||||
return fn.root_name
|
||||
|
||||
selected_op_num: int = len(selector.operators)
|
||||
# a best practice threshold of operators to enable sharding
|
||||
sharding_threshold: int = 100
|
||||
cpu_fm.write_sharded(
|
||||
"UnboxingFunctions.cpp",
|
||||
native_functions,
|
||||
@ -162,7 +167,7 @@ def gen_unboxing(
|
||||
env_callable=lambda fn: {
|
||||
"definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
|
||||
},
|
||||
num_shards=5,
|
||||
num_shards=1 if selected_op_num < sharding_threshold else 5,
|
||||
sharded_keys={"definitions"},
|
||||
)
|
||||
cpu_fm.write(
|
||||
@ -183,7 +188,7 @@ def gen_unboxing(
|
||||
env_callable=lambda fn: {
|
||||
"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
|
||||
},
|
||||
num_shards=10,
|
||||
num_shards=1 if selected_op_num < sharding_threshold else 10,
|
||||
sharded_keys={"unboxed_ops"},
|
||||
)
|
||||
|
||||
@ -218,17 +223,19 @@ def main() -> None:
|
||||
"The operator names also contain the namespace prefix (e.g. aten::)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--op_registration_allowlist",
|
||||
nargs="*",
|
||||
help="filter op registrations by the allowlist (if set); "
|
||||
"--TEST_ONLY_op_registration_allowlist_yaml_path",
|
||||
help="Provide a path to the operator selection (for custom build) YAML "
|
||||
"which contains a list of operators. It is to serve testing purpose and "
|
||||
"each item is `namespace`::`operator name` without overload name; "
|
||||
"e.g.: aten::empty aten::conv2d ...",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
|
||||
op_registration_allowlist = yaml.safe_load(f)
|
||||
|
||||
selector = get_custom_build_selector(
|
||||
options.op_registration_allowlist,
|
||||
op_registration_allowlist,
|
||||
options.op_selection_yaml_path,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user