mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[mobile] Fix lightweight dispatch OOM error by introducing selective build"
This reverts commit 272bdb1442ee3750861d9f2f10690cc3f1521b92. Reverted https://github.com/pytorch/pytorch/pull/78983 on behalf of https://github.com/osalpekar due to broke internal mobile tests
This commit is contained in:
@ -2,8 +2,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api import unboxing
|
||||
@ -157,9 +155,6 @@ def gen_unboxing(
|
||||
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
|
||||
return fn.root_name
|
||||
|
||||
selected_op_num: int = len(selector.operators)
|
||||
# a best practice threshold of operators to enable sharding
|
||||
sharding_threshold: int = 100
|
||||
cpu_fm.write_sharded(
|
||||
"UnboxingFunctions.cpp",
|
||||
native_functions,
|
||||
@ -167,7 +162,7 @@ def gen_unboxing(
|
||||
env_callable=lambda fn: {
|
||||
"definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
|
||||
},
|
||||
num_shards=1 if selected_op_num < sharding_threshold else 5,
|
||||
num_shards=5,
|
||||
sharded_keys={"definitions"},
|
||||
)
|
||||
cpu_fm.write(
|
||||
@ -188,7 +183,7 @@ def gen_unboxing(
|
||||
env_callable=lambda fn: {
|
||||
"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
|
||||
},
|
||||
num_shards=1 if selected_op_num < sharding_threshold else 10,
|
||||
num_shards=10,
|
||||
sharded_keys={"unboxed_ops"},
|
||||
)
|
||||
|
||||
@ -223,19 +218,17 @@ def main() -> None:
|
||||
"The operator names also contain the namespace prefix (e.g. aten::)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--TEST_ONLY_op_registration_allowlist_yaml_path",
|
||||
help="Provide a path to the operator selection (for custom build) YAML "
|
||||
"which contains a list of operators. It is to serve testing purpose and "
|
||||
"--op_registration_allowlist",
|
||||
nargs="*",
|
||||
help="filter op registrations by the allowlist (if set); "
|
||||
"each item is `namespace`::`operator name` without overload name; "
|
||||
"e.g.: aten::empty aten::conv2d ...",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
|
||||
op_registration_allowlist = yaml.safe_load(f)
|
||||
|
||||
selector = get_custom_build_selector(
|
||||
op_registration_allowlist,
|
||||
options.op_registration_allowlist,
|
||||
options.op_selection_yaml_path,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user