[retake2][mobile] Fix lightweight dispatch OOM error by introducing selective build (#80791)
To fix #78540 I committed #78983, which was reverted due to an internal CI failure. I then committed #79215, which only fixed that failure and did not carry the full feature of #78983. This PR is another attempt: it adds a script that dumps all operators from the test models and automatically writes them into `lightweight_dispatch_ops.yaml`, so the yaml file no longer has to be updated manually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80791
Approved by: https://github.com/raziel
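A minimal sketch of the operator-dumping step, mirroring the `tests_setup.py` changes below (the helper name `dump_model_ops` and the `ReLU` example are illustrative assumptions, not part of the PR):

import torch
from io import BytesIO
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list

def dump_model_ops(model: torch.nn.Module) -> set:
    # Script the model, round-trip it through the lite-interpreter format,
    # and collect the root operators it requires.
    scripted = torch.jit.script(model)
    buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter())
    buffer.seek(0)
    mobile_module = _load_for_lite_interpreter(buffer)
    return set(_export_operator_list(mobile_module))

# Example: append the collected ops to the selective-build yaml.
ops = dump_model_ops(torch.nn.ReLU())
with open("lightweight_dispatch_ops.yaml", "a") as f:
    for op in sorted(ops):
        f.write(f"- {op}\n")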
Committed by: PyTorch MergeBot
Parent: 5139053e02
Commit: e345138591
.github/workflows/pull.yml (16 lines changed)
@@ -280,15 +280,13 @@ jobs:
       build-environment: linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit
       docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c

-  # https://github.com/pytorch/pytorch/issues/78540
-  # linux-focal-py3_7-gcc7-mobile-lightweight-dispatch-build:
-  #   if: false()
-  #   name: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build
-  #   uses: ./.github/workflows/_linux-build.yml
-  #   with:
-  #     build-environment: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build
-  #     docker-image-name: pytorch-linux-focal-py3.7-gcc7
-  #     build-generates-artifacts: false
+  linux-focal-py3_7-gcc7-mobile-lightweight-dispatch-build:
+    name: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build
+    uses: ./.github/workflows/_linux-build.yml
+    with:
+      build-environment: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build
+      docker-image-name: pytorch-linux-focal-py3.7-gcc7
+      build-generates-artifacts: false

   linux-xenial-cuda11_3-py3_7-gcc7-deploy-build:
     name: linux-xenial-cuda11_3-py3_7-gcc7-deploy
@@ -118,6 +118,10 @@ if(INTERN_BUILD_ATEN_OPS)
       --source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen
       --install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
   )
+  if(SELECTED_OP_LIST)
+    list(APPEND GEN_UNBOXING_COMMAND
+         --TEST_ONLY_op_registration_allowlist_yaml_path "${SELECTED_OP_LIST}")
+  endif()
   set("GEN_UNBOXING_COMMAND_sources"
       ${GEN_UNBOXING_COMMAND}
       --output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_unboxing_sources.cmake
@@ -19,13 +19,16 @@ TEST_SRC_ROOT="$PWD/test/mobile/lightweight_dispatch"
 pushd "$CUSTOM_TEST_ARTIFACT_BUILD_DIR"

 # prepare test
-python "$TEST_SRC_ROOT/tests_setup.py" setup
+OP_LIST="lightweight_dispatch_ops.yaml"
+export SELECTED_OP_LIST=$TEST_SRC_ROOT/$OP_LIST
+python "$TEST_SRC_ROOT/tests_setup.py" setup "$SELECTED_OP_LIST"

 export USE_DISTRIBUTED=0
 export USE_LIGHTWEIGHT_DISPATCH=1
 export STATIC_DISPATCH_BACKEND="CPU"
 export BUILD_LITE_INTERPRETER=1

 export USE_FBGEMM=0
 python "${BUILD_LIBTORCH_PY}"
 ret=$?

@@ -42,13 +45,7 @@ if ! build/bin/test_codegen_unboxing; then
 fi

 # shutdown test
-python "$TEST_SRC_ROOT/tests_setup.py" shutdown
-
-# run lite interpreter tests
-if ! build/bin/test_lite_interpreter_runtime; then
-  echo "test_lite_interpreter_runtime has failure!"
-  exit 1
-fi
+python "$TEST_SRC_ROOT/tests_setup.py" shutdown "$SELECTED_OP_LIST"

 popd
@@ -0,0 +1,6 @@
+# base ops for preparing inputs
+- aten::copy_
+- aten::detach
+- aten::fill_.Tensor
+- aten::to.device
+# model introduced ops begin from here
@@ -60,10 +60,10 @@ namespace jit {
 namespace mobile {
 // covers int[], ScalarType?, Layout?, Device?, bool?
 TEST(LiteInterpreterTest, Ones) {
-  // Load check in model: ones.ptl
-  auto testModelFile = "ones.ptl";
+  // Load check in model: ModelWithDTypeDeviceLayoutPinMemory.ptl
+  auto testModelFile = "ModelWithDTypeDeviceLayoutPinMemory.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
   //   def forward(self, x: int):
   //     a = torch.ones([3, x], dtype=torch.int64, layout=torch.strided, device="cpu")
   //     return a
@@ -75,10 +75,10 @@ TEST(LiteInterpreterTest, Ones) {
 }

 TEST(LiteInterpreterTest, Index) {
-  // Load check in model: index.ptl
-  auto testModelFile = "index.ptl";
+  // Load check in model: ModelWithTensorOptional.ptl
+  auto testModelFile = "ModelWithTensorOptional.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithTensorOptional(torch.nn.Module):
   //   def forward(self, index):
   //     a = torch.zeros(2, 2)
   //     a[0][1] = 1
@@ -98,10 +98,10 @@ TEST(LiteInterpreterTest, Index) {
 }

 TEST(LiteInterpreterTest, Gradient) {
-  // Load check in model: gradient.ptl
-  auto testModelFile = "gradient.ptl";
+  // Load check in model: ModelWithScalarList.ptl
+  auto testModelFile = "ModelWithScalarList.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithScalarList(torch.nn.Module):
   //   def forward(self, a: int):
   //     values = torch.tensor([4., 1., 1., 16.], )
   //     if a == 0:
@@ -120,8 +120,8 @@ TEST(LiteInterpreterTest, Gradient) {
 }

 TEST(LiteInterpreterTest, Upsample) {
-  // Load check in model: upsample.ptl
-  auto testModelFile = "upsample.ptl";
+  // Load check in model: ModelWithFloatList.ptl
+  auto testModelFile = "ModelWithFloatList.ptl";

   // model = torch.nn.Upsample(scale_factor=(2.0,), mode="linear")
   Module bc = _load_for_mobile(testModelFile);
@@ -132,10 +132,10 @@ TEST(LiteInterpreterTest, Upsample) {
 }

 TEST(LiteInterpreterTest, IndexTensor) {
-  // Load check in model: Index_Tensor.ptl
-  auto testModelFile = "index_Tensor.ptl";
+  // Load check in model: ModelWithListOfOptionalTensors.ptl
+  auto testModelFile = "ModelWithListOfOptionalTensors.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithListOfOptionalTensors(torch.nn.Module):
   //   def forward(self, index):
   //     values = torch.tensor([4., 1., 1., 16.], )
   //     return values[[index, torch.tensor(0)]]
@@ -147,8 +147,8 @@ TEST(LiteInterpreterTest, IndexTensor) {
 }

 TEST(LiteInterpreterTest, Conv2d) {
-  // Load check in model: conv2d.ptl
-  auto testModelFile = "conv2d.ptl";
+  // Load check in model: ModelWithArrayOfInt.ptl
+  auto testModelFile = "ModelWithArrayOfInt.ptl";

   // model = torch.nn.Conv2d(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))
   Module bc = _load_for_mobile(testModelFile);
@@ -158,10 +158,10 @@ TEST(LiteInterpreterTest, Conv2d) {
 }

 TEST(LiteInterpreterTest, AddTensor) {
-  // Load check in model: add_Tensor.ptl
-  auto testModelFile = "add_Tensor.ptl";
+  // Load check in model: ModelWithTensors.ptl
+  auto testModelFile = "ModelWithTensors.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithTensors(torch.nn.Module):
   //   def forward(self, a):
   //     values = torch.ones(size=[2, 3], names=['N', 'C'])
   //     values[0][0] = a[0]
@@ -174,10 +174,10 @@ TEST(LiteInterpreterTest, AddTensor) {
 }

 TEST(LiteInterpreterTest, DivideTensor) {
-  // Load check in model: add_Tensor.ptl
-  auto testModelFile = "divide_Tensor.ptl";
+  // Load check in model: ModelWithStringOptional.ptl
+  auto testModelFile = "ModelWithStringOptional.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithStringOptional(torch.nn.Module):
   //   def forward(self, b):
   //     a = torch.tensor(3, dtype=torch.int64)
   //     out = torch.empty(size=[1], dtype=torch.float)
@@ -193,10 +193,10 @@ TEST(LiteInterpreterTest, DivideTensor) {
 }

 TEST(LiteInterpreterTest, MultipleOps) {
-  // Load check in model: multiple_ops.ptl
-  auto testModelFile = "multiple_ops.ptl";
+  // Load check in model: ModelWithMultipleOps.ptl
+  auto testModelFile = "ModelWithMultipleOps.ptl";

-  // class Model(torch.nn.Module):
+  // class ModelWithMultipleOps(torch.nn.Module):
   //   def __init__(self):
   //     super(Model, self).__init__()
   //     self.ops = torch.nn.Sequential(
@@ -1,203 +1,143 @@
+import functools
 import os
+from io import BytesIO
+import shutil

 import sys

 import torch
+from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list

+_OPERATORS = set()
+_FILENAMES = []
+_MODELS = []


-class Setup(object):
-    def setup(self):
-        raise NotImplementedError()
+def save_model(cls):
+    """Save a model and dump all the ops"""

-    def shutdown(self):
-        raise NotImplementedError()
+    @functools.wraps(cls)
+    def wrapper_save():
+        _MODELS.append(cls)
+        model = cls()
+        scripted = torch.jit.script(model)
+        buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        mobile_module = _load_for_lite_interpreter(buffer)
+        ops = _export_operator_list(mobile_module)
+        _OPERATORS.update(ops)
+        path = f"./{cls.__name__}.ptl"
+        _FILENAMES.append(path)
+        scripted._save_for_lite_interpreter(path)

+    return wrapper_save


-class FileSetup(object):
-    path = None

-    def shutdown(self):
-        if os.path.exists(self.path):
-            os.remove(self.path)
-            pass
+@save_model
+class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
+    def forward(self, x: int):
+        a = torch.ones(size=[3, x], dtype=torch.int64, layout=torch.strided, device="cpu", pin_memory=False)
+        return a


-class ModelWithDTypeDeviceLayoutPinMemory(FileSetup):
-    path = 'ones.ptl'

-    def setup(self):
-        class Model(torch.nn.Module):
-            def forward(self, x: int):
-                a = torch.ones(size=[3, x], dtype=torch.int64, layout=torch.strided, device="cpu", pin_memory=False)
-                return a

-        model = Model()

-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)


-class ModelWithTensorOptional(FileSetup):
-    path = 'index.ptl'

-    def setup(self):
-        class Model(torch.nn.Module):
-            def forward(self, index):
-                a = torch.zeros(2, 2)
-                a[0][1] = 1
-                a[1][0] = 2
-                a[1][1] = 3
-                return a[index]

-        model = Model()

-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithTensorOptional(torch.nn.Module):
+    def forward(self, index):
+        a = torch.zeros(2, 2)
+        a[0][1] = 1
+        a[1][0] = 2
+        a[1][1] = 3
+        return a[index]


 # gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
-class ModelWithScalarList(FileSetup):
-    path = 'gradient.ptl'

-    def setup(self):

-        class Model(torch.nn.Module):
-            def forward(self, a: int):
-                values = torch.tensor([4., 1., 1., 16.], )
-                if a == 0:
-                    return torch.gradient(values, spacing=torch.scalar_tensor(2., dtype=torch.float64))
-                elif a == 1:
-                    return torch.gradient(values, spacing=[torch.tensor(1.).item()])

-        model = Model()

-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithScalarList(torch.nn.Module):
+    def forward(self, a: int):
+        values = torch.tensor([4., 1., 1., 16.], )
+        if a == 0:
+            return torch.gradient(values, spacing=torch.scalar_tensor(2., dtype=torch.float64))
+        elif a == 1:
+            return torch.gradient(values, spacing=[torch.tensor(1.).item()])


 # upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
-class ModelWithFloatList(FileSetup):
-    path = 'upsample.ptl'

-    def setup(self):
-        model = torch.nn.Upsample(scale_factor=(2.0,), mode="linear", align_corners=False, recompute_scale_factor=True)

-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithFloatList(torch.nn.Upsample):
+    def __init__(self):
+        super().__init__(scale_factor=(2.0,), mode="linear", align_corners=False, recompute_scale_factor=True)


 # index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
-class ModelWithListOfOptionalTensors(FileSetup):
-    path = 'index_Tensor.ptl'

-    def setup(self):
-        class Model(torch.nn.Module):
-            def forward(self, index):
-                values = torch.tensor([[4., 1., 1., 16.]])
-                return values[torch.tensor(0), index]

-        model = Model()
-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithListOfOptionalTensors(torch.nn.Module):
+    def forward(self, index):
+        values = torch.tensor([[4., 1., 1., 16.]])
+        return values[torch.tensor(0), index]


 # conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1,
 # int groups=1) -> Tensor
-class ModelWithArrayOfInt(FileSetup):
-    path = 'conv2d.ptl'

-    def setup(self):
-        model = torch.nn.Conv2d(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))
-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithArrayOfInt(torch.nn.Conv2d):
+    def __init__(self):
+        super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))


 # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
 # ones_like(Tensor self, *, ScalarType?, dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None,
 # MemoryFormat? memory_format=None) -> Tensor
-class ModelWithTensors(FileSetup):
-    path = 'add_Tensor.ptl'
+@save_model
+class ModelWithTensors(torch.nn.Module):
+    def forward(self, a):
+        b = torch.ones_like(a)
+        return a + b

-    def setup(self):
-        class Model(torch.nn.Module):
-            def forward(self, a):
-                b = torch.ones_like(a)
-                return a + b
-        model = Model()
-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)
+@save_model
+class ModelWithStringOptional(torch.nn.Module):
+    def forward(self, b):
+        a = torch.tensor(3, dtype=torch.int64)
+        out = torch.empty(size=[1], dtype=torch.float)
+        torch.div(b, a, out=out)
+        return [torch.div(b, a, rounding_mode='trunc'), out]


-class ModelWithStringOptional(FileSetup):
-    path = 'divide_Tensor.ptl'
+@save_model
+class ModelWithMultipleOps(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ops = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Flatten(),
+        )

-    def setup(self):
-        class Model(torch.nn.Module):
-            def forward(self, b):
-                a = torch.tensor(3, dtype=torch.int64)
-                out = torch.empty(size=[1], dtype=torch.float)
-                torch.div(b, a, out=out)
-                return [torch.div(b, a, rounding_mode='trunc'), out]
-        model = Model()
-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)


-class ModelWithMultipleOps(FileSetup):
-    path = 'multiple_ops.ptl'

-    def setup(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super(Model, self).__init__()
-                self.ops = torch.nn.Sequential(
-                    torch.nn.ReLU(),
-                    torch.nn.Flatten(),
-                )

-            def forward(self, x):
-                x[1] = -2
-                return self.ops(x)

-        model = Model()
-        # Script the model and save
-        script_model = torch.jit.script(model)
-        script_model._save_for_lite_interpreter(self.path)


-tests = [
-    ModelWithDTypeDeviceLayoutPinMemory(),
-    ModelWithTensorOptional(),
-    ModelWithScalarList(),
-    ModelWithFloatList(),
-    ModelWithListOfOptionalTensors(),
-    ModelWithArrayOfInt(),
-    ModelWithTensors(),
-    ModelWithStringOptional(),
-    ModelWithMultipleOps(),
-]


-def setup():
-    for test in tests:
-        test.setup()


-def shutdown():
-    for test in tests:
-        test.shutdown()
+    def forward(self, x):
+        x[1] = -2
+        return self.ops(x)


 if __name__ == "__main__":
     command = sys.argv[1]
+    ops_yaml = sys.argv[2]
+    backup = ops_yaml + ".bak"
     if command == "setup":
-        setup()
+        tests = [
+            ModelWithDTypeDeviceLayoutPinMemory(),
+            ModelWithTensorOptional(),
+            ModelWithScalarList(),
+            ModelWithFloatList(),
+            ModelWithListOfOptionalTensors(),
+            ModelWithArrayOfInt(),
+            ModelWithTensors(),
+            ModelWithStringOptional(),
+            ModelWithMultipleOps(),
+        ]
+        shutil.copyfile(ops_yaml, backup)
+        with open(ops_yaml, 'a') as f:
+            for op in _OPERATORS:
+                f.write(f"- {op}\n")
     elif command == "shutdown":
-        shutdown()
+        for file in _MODELS:
+            if os.path.isfile(file):
+                os.remove(file)
+        shutil.move(backup, ops_yaml)
@@ -158,6 +158,9 @@ def gen_unboxing(
     def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
         return fn.root_name

+    selected_op_num: int = len(selector.operators)
+    # a best practice threshold of operators to enable sharding
+    sharding_threshold: int = 100
     cpu_fm.write_sharded(
         "UnboxingFunctions.cpp",
         native_functions,
@@ -165,7 +168,7 @@ def gen_unboxing(
         env_callable=lambda fn: {
             "definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
         },
-        num_shards=5,
+        num_shards=1 if selected_op_num < sharding_threshold else 5,
         sharded_keys={"definitions"},
     )
     cpu_fm.write(
@@ -186,7 +189,7 @@ def gen_unboxing(
         env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
         },
-        num_shards=10,
+        num_shards=1 if selected_op_num < sharding_threshold else 10,
         sharded_keys={"unboxed_ops"},
     )

@@ -245,7 +248,7 @@ def main(args: List[str]) -> None:
         op_registration_allowlist = None

     selector = get_custom_build_selector(
-        options.op_registration_allowlist,
+        op_registration_allowlist,
         options.op_selection_yaml_path,
     )

@@ -229,6 +229,7 @@ class CMake:
                     "WERROR",
                     "OPENSSL_ROOT_DIR",
                     "STATIC_DISPATCH_BACKEND",
+                    "SELECTED_OP_LIST",
                 )
             }
         )