Revert "[mobile] Fix lightweight dispatch OOM error by introducing selective build"

This reverts commit 272bdb1442ee3750861d9f2f10690cc3f1521b92.

Reverted https://github.com/pytorch/pytorch/pull/78983 on behalf of https://github.com/osalpekar due to broke internal mobile tests
This commit is contained in:
PyTorch MergeBot
2022-06-09 05:16:42 +00:00
parent cb04cc0aa7
commit c3e089a047
7 changed files with 25 additions and 89 deletions

View File

@ -2,8 +2,6 @@
import argparse
import os
import pathlib
import yaml
from dataclasses import dataclass
from torchgen.api import cpp
from torchgen.api import unboxing
@ -157,9 +155,6 @@ def gen_unboxing(
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
return fn.root_name
selected_op_num: int = len(selector.operators)
# a best practice threshold of operators to enable sharding
sharding_threshold: int = 100
cpu_fm.write_sharded(
"UnboxingFunctions.cpp",
native_functions,
@ -167,7 +162,7 @@ def gen_unboxing(
env_callable=lambda fn: {
"definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
},
num_shards=1 if selected_op_num < sharding_threshold else 5,
num_shards=5,
sharded_keys={"definitions"},
)
cpu_fm.write(
@ -188,7 +183,7 @@ def gen_unboxing(
env_callable=lambda fn: {
"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
},
num_shards=1 if selected_op_num < sharding_threshold else 10,
num_shards=10,
sharded_keys={"unboxed_ops"},
)
@ -223,19 +218,17 @@ def main() -> None:
"The operator names also contain the namespace prefix (e.g. aten::)",
)
parser.add_argument(
"--TEST_ONLY_op_registration_allowlist_yaml_path",
help="Provide a path to the operator selection (for custom build) YAML "
"which contains a list of operators. It is to serve testing purpose and "
"--op_registration_allowlist",
nargs="*",
help="filter op registrations by the allowlist (if set); "
"each item is `namespace`::`operator name` without overload name; "
"e.g.: aten::empty aten::conv2d ...",
)
options = parser.parse_args()
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
op_registration_allowlist = yaml.safe_load(f)
selector = get_custom_build_selector(
op_registration_allowlist,
options.op_registration_allowlist,
options.op_selection_yaml_path,
)