Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add model tests for Android and iOS
See https://github.com/pytorch/pytorch/pull/74793 for more details; this PR redoes everything in one PR. This diff adds a test model generation script and updates the Android and iOS tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74947. Approved by: https://github.com/kit1980
Committed by: PyTorch MergeBot
Parent: f82b2d4a82
Commit: 28a4b4759a
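
Note: each `.ptl` asset below is a lite-interpreter model; the updated tests load each one and run it, preferring bundled inputs when the model provides them. A minimal Python sketch of that flow, mirroring the `runModule` helper added in `gen_test_model.py` further down (the asset file name here is only an illustrative choice):

    from torch.jit.mobile import _load_for_lite_interpreter

    # Load a generated lite-interpreter model (file name is illustrative).
    module = _load_for_lite_interpreter("pointwise_ops.ptl")

    if module.find_method("get_all_bundled_inputs"):
        # Run with the first set of bundled inputs, as the tests do.
        first_input = module.run_method("get_all_bundled_inputs")[0]
        module.forward(*first_input)
    else:
        # Models without bundled inputs are assumed to take no arguments.
        module()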
BIN  android/pytorch_android/src/androidTest/assets/dropout_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/linear_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/padding_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/pooling_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/reduction_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/sampling_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/sparse_ops.ptl (new file)
BIN  android/pytorch_android/src/androidTest/assets/spectral_ops.ptl (new file)
(Other changed binary files not shown.)
@@ -9,10 +9,11 @@ import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.junit.Test;

public abstract class PytorchTestBase {
-  private static final String TEST_MODULE_ASSET_NAME = "test.pt";
+  private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";

  @Test
  public void testForwardNull() throws IOException {
@@ -377,6 +378,186 @@ public abstract class PytorchTestBase {
        new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
  }

+  @Test
+  public void testMobileNetV2() throws IOException {
+    try {
+      final Module module = loadModel("mobilenet_v2.ptl");
+      final IValue inputs = module.runMethod("get_all_bundled_inputs");
+      assertTrue(inputs.isList());
+      final IValue input = inputs.toList()[0];
+      assertTrue(input.isTuple());
+      module.forward(input.toTuple()[0]);
+      assertTrue(true);
+    } catch (Exception ex) {
+      assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
+    }
+  }
+
+  @Test
+  public void testPointwiseOps() throws IOException {
+    runModel("pointwise_ops");
+  }
+
+  @Test
+  public void testReductionOps() throws IOException {
+    runModel("reduction_ops");
+  }
+
+  @Test
+  public void testComparisonOps() throws IOException {
+    runModel("comparison_ops");
+  }
+
+  @Test
+  public void testOtherMathOps() throws IOException {
+    runModel("other_math_ops");
+  }
+
+  @Test
+  public void testSpectralOps() throws IOException {
+    runModel("spectral_ops");
+  }
+
+  @Test
+  public void testBlasLapackOps() throws IOException {
+    runModel("blas_lapack_ops");
+  }
+
+  @Test
+  public void testSamplingOps() throws IOException {
+    runModel("sampling_ops");
+  }
+
+  @Test
+  public void testTensorOps() throws IOException {
+    runModel("tensor_general_ops");
+  }
+
+  @Test
+  public void testTensorCreationOps() throws IOException {
+    runModel("tensor_creation_ops");
+  }
+
+  @Test
+  public void testTensorIndexingOps() throws IOException {
+    runModel("tensor_indexing_ops");
+  }
+
+  @Test
+  public void testTensorTypingOps() throws IOException {
+    runModel("tensor_typing_ops");
+  }
+
+  @Test
+  public void testTensorViewOps() throws IOException {
+    runModel("tensor_view_ops");
+  }
+
+  @Test
+  public void testConvolutionOps() throws IOException {
+    runModel("convolution_ops");
+  }
+
+  @Test
+  public void testPoolingOps() throws IOException {
+    runModel("pooling_ops");
+  }
+
+  @Test
+  public void testPaddingOps() throws IOException {
+    runModel("padding_ops");
+  }
+
+  @Test
+  public void testActivationOps() throws IOException {
+    runModel("activation_ops");
+  }
+
+  @Test
+  public void testNormalizationOps() throws IOException {
+    runModel("normalization_ops");
+  }
+
+  @Test
+  public void testRecurrentOps() throws IOException {
+    runModel("recurrent_ops");
+  }
+
+  @Test
+  public void testTransformerOps() throws IOException {
+    runModel("transformer_ops");
+  }
+
+  @Test
+  public void testLinearOps() throws IOException {
+    runModel("linear_ops");
+  }
+
+  @Test
+  public void testDropoutOps() throws IOException {
+    runModel("dropout_ops");
+  }
+
+  @Test
+  public void testSparseOps() throws IOException {
+    runModel("sparse_ops");
+  }
+
+  @Test
+  public void testDistanceFunctionOps() throws IOException {
+    runModel("distance_function_ops");
+  }
+
+  @Test
+  public void testLossFunctionOps() throws IOException {
+    runModel("loss_function_ops");
+  }
+
+  @Test
+  public void testVisionFunctionOps() throws IOException {
+    runModel("vision_function_ops");
+  }
+
+  @Test
+  public void testShuffleOps() throws IOException {
+    runModel("shuffle_ops");
+  }
+
+  @Test
+  public void testNNUtilsOps() throws IOException {
+    runModel("nn_utils_ops");
+  }
+
+  @Test
+  public void testQuantOps() throws IOException {
+    runModel("general_quant_ops");
+  }
+
+  @Test
+  public void testDynamicQuantOps() throws IOException {
+    runModel("dynamic_quant_ops");
+  }
+
+  @Test
+  public void testStaticQuantOps() throws IOException {
+    runModel("static_quant_ops");
+  }
+
+  @Test
+  public void testFusedQuantOps() throws IOException {
+    runModel("fused_quant_ops");
+  }
+
+  @Test
+  public void testTorchScriptBuiltinQuantOps() throws IOException {
+    runModel("torchscript_builtin_ops");
+  }
+
+  @Test
+  public void testTorchScriptCollectionQuantOps() throws IOException {
+    runModel("torchscript_collection_ops");
+  }
+
  static void assertIValueTensor(
      final IValue ivalue,
      final MemoryFormat memoryFormat,
@@ -389,5 +570,15 @@ public abstract class PytorchTestBase {
    assertArrayEquals(expectedData, t.getDataAsLongArray());
  }

+  void runModel(final String name) throws IOException {
+    final Module storage_module = loadModel(name + ".ptl");
+    storage_module.forward();
+
+    // TODO: enable this once the on-the-fly script is ready
+    // final Module on_the_fly_module = loadModel(name + "_temp.ptl");
+    // on_the_fly_module.forward();
+    assertTrue(true);
+  }
+
  protected abstract Module loadModel(String assetName) throws IOException;
}
@@ -11,25 +11,6 @@
@implementation TestAppTests {
}

-- (void)runModel:(NSString*)filename {
-  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:filename
-                                                                         ofType:@"ptl"];
-  XCTAssertNotNil(modelPath);
-  c10::InferenceMode mode;
-  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
-  XCTAssertNoThrow(module.forward({}));
-}
-
-- (void)testLiteInterpreter {
-  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_lite"
-                                                                         ofType:@"ptl"];
-  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
-  c10::InferenceMode mode;
-  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
-  auto outputTensor = module.forward({input}).toTensor();
-  XCTAssertTrue(outputTensor.numel() == 1000);
-}
-
- (void)testCoreML {
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_coreml"
                                                                         ofType:@"ptl"];
@@ -40,135 +21,173 @@
  XCTAssertTrue(outputTensor.numel() == 1000);
}

+- (void)testModel:(NSString*)filename {
+  // model generated using the current pytorch revision
+  // [self runModel:[NSString stringWithFormat:@"%@_temp", filename]];
+  // model generated using an older pytorch revision
+  [self runModel:filename];
+}
+
+- (void)runModel:(NSString*)filename {
+  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:filename
+                                                                         ofType:@"ptl"];
+  XCTAssertNotNil(modelPath);
+  c10::InferenceMode mode;
+  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
+  auto has_bundled_input = module.find_method("get_all_bundled_inputs");
+  if (has_bundled_input) {
+    c10::IValue bundled_inputs = module.run_method("get_all_bundled_inputs");
+    c10::List<at::IValue> all_inputs = bundled_inputs.toList();
+    std::vector<std::vector<at::IValue>> inputs;
+    for (at::IValue input : all_inputs) {
+      inputs.push_back(input.toTupleRef().elements());
+    }
+    // run with the first bundled input
+    XCTAssertNoThrow(module.forward(inputs[0]));
+  } else {
+    XCTAssertNoThrow(module.forward({}));
+  }
+}
+
+// TODO: remove this once the test script is updated
+- (void)testLiteInterpreter {
+  XCTAssertTrue(true);
+}
+
+- (void)testMobileNetV2 {
+  [self testModel:@"mobilenet_v2"];
+}
+
- (void)testPointwiseOps {
-  [self runModel:@"pointwise_ops"];
+  [self testModel:@"pointwise_ops"];
}

- (void)testReductionOps {
-  [self runModel:@"reduction_ops"];
+  [self testModel:@"reduction_ops"];
}

- (void)testComparisonOps {
-  [self runModel:@"comparison_ops"];
+  [self testModel:@"comparison_ops"];
}

- (void)testOtherMathOps {
-  [self runModel:@"other_math_ops"];
+  [self testModel:@"other_math_ops"];
}

- (void)testSpectralOps {
-  [self runModel:@"spectral_ops"];
+  [self testModel:@"spectral_ops"];
}

- (void)testBlasLapackOps {
-  [self runModel:@"blas_lapack_ops"];
+  [self testModel:@"blas_lapack_ops"];
}

- (void)testSamplingOps {
-  [self runModel:@"sampling_ops"];
+  [self testModel:@"sampling_ops"];
}

- (void)testTensorOps {
-  [self runModel:@"tensor_general_ops"];
+  [self testModel:@"tensor_general_ops"];
}

- (void)testTensorCreationOps {
-  [self runModel:@"tensor_creation_ops"];
+  [self testModel:@"tensor_creation_ops"];
}

- (void)testTensorIndexingOps {
-  [self runModel:@"tensor_indexing_ops"];
+  [self testModel:@"tensor_indexing_ops"];
}

- (void)testTensorTypingOps {
-  [self runModel:@"tensor_typing_ops"];
+  [self testModel:@"tensor_typing_ops"];
}

- (void)testTensorViewOps {
-  [self runModel:@"tensor_view_ops"];
+  [self testModel:@"tensor_view_ops"];
}

- (void)testConvolutionOps {
-  [self runModel:@"convolution_ops"];
+  [self testModel:@"convolution_ops"];
}

- (void)testPoolingOps {
-  [self runModel:@"pooling_ops"];
+  [self testModel:@"pooling_ops"];
}

- (void)testPaddingOps {
-  [self runModel:@"padding_ops"];
+  [self testModel:@"padding_ops"];
}

- (void)testActivationOps {
-  [self runModel:@"activation_ops"];
+  [self testModel:@"activation_ops"];
}

- (void)testNormalizationOps {
-  [self runModel:@"normalization_ops"];
+  [self testModel:@"normalization_ops"];
}

- (void)testRecurrentOps {
-  [self runModel:@"recurrent_ops"];
+  [self testModel:@"recurrent_ops"];
}

- (void)testTransformerOps {
-  [self runModel:@"transformer_ops"];
+  [self testModel:@"transformer_ops"];
}

- (void)testLinearOps {
-  [self runModel:@"linear_ops"];
+  [self testModel:@"linear_ops"];
}

- (void)testDropoutOps {
-  [self runModel:@"dropout_ops"];
+  [self testModel:@"dropout_ops"];
}

- (void)testSparseOps {
-  [self runModel:@"sparse_ops"];
+  [self testModel:@"sparse_ops"];
}

- (void)testDistanceFunctionOps {
-  [self runModel:@"distance_function_ops"];
+  [self testModel:@"distance_function_ops"];
}

- (void)testLossFunctionOps {
-  [self runModel:@"loss_function_ops"];
+  [self testModel:@"loss_function_ops"];
}

- (void)testVisionFunctionOps {
-  [self runModel:@"vision_function_ops"];
+  [self testModel:@"vision_function_ops"];
}

- (void)testShuffleOps {
-  [self runModel:@"shuffle_ops"];
+  [self testModel:@"shuffle_ops"];
}

- (void)testNNUtilsOps {
-  [self runModel:@"nn_utils_ops"];
+  [self testModel:@"nn_utils_ops"];
}

- (void)testQuantOps {
-  [self runModel:@"general_quant_ops"];
+  [self testModel:@"general_quant_ops"];
}

- (void)testDynamicQuantOps {
-  [self runModel:@"dynamic_quant_ops"];
+  [self testModel:@"dynamic_quant_ops"];
}

- (void)testStaticQuantOps {
-  [self runModel:@"static_quant_ops"];
+  [self testModel:@"static_quant_ops"];
}

- (void)testFusedQuantOps {
-  [self runModel:@"fused_quant_ops"];
+  [self testModel:@"fused_quant_ops"];
}

- (void)testTorchScriptBuiltinQuantOps {
-  [self runModel:@"torchscript_builtin_ops"];
+  [self testModel:@"torchscript_builtin_ops"];
}

- (void)testTorchScriptCollectionQuantOps {
-  [self runModel:@"torchscript_collection_ops"];
+  [self testModel:@"torchscript_collection_ops"];
}

@end
BIN  ios/TestApp/models/android_api_module.ptl (new file)
BIN  ios/TestApp/models/mobilenet_v2.ptl (new file)
(Other changed binary files not shown.)

test/mobile/model_test/android_api_module.py (new file, 128 lines)
@@ -0,0 +1,128 @@
from typing import Dict, List, Tuple, Optional

import torch
from torch import Tensor


class AndroidAPIModule(torch.jit.ScriptModule):
    def __init__(self):
        super(AndroidAPIModule, self).__init__()

    @torch.jit.script_method
    def forward(self, input):
        return None

    @torch.jit.script_method
    def eqBool(self, input: bool) -> bool:
        return input

    @torch.jit.script_method
    def eqInt(self, input: int) -> int:
        return input

    @torch.jit.script_method
    def eqFloat(self, input: float) -> float:
        return input

    @torch.jit.script_method
    def eqStr(self, input: str) -> str:
        return input

    @torch.jit.script_method
    def eqTensor(self, input: Tensor) -> Tensor:
        return input

    @torch.jit.script_method
    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
        return input

    @torch.jit.script_method
    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
        return input

    @torch.jit.script_method
    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
        return input

    @torch.jit.script_method
    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def listBoolConjunction(self, input: List[bool]) -> bool:
        res = True
        for x in input:
            res = res and x
        return res

    @torch.jit.script_method
    def listBoolDisjunction(self, input: List[bool]) -> bool:
        res = False
        for x in input:
            res = res or x
        return res

    @torch.jit.script_method
    def tupleIntSumReturnTuple(
        self, input: Tuple[int, int, int]
    ) -> Tuple[Tuple[int, int, int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
        return input is None

    @torch.jit.script_method
    def intEq0None(self, input: int) -> Optional[int]:
        if input == 0:
            return None
        return input

    @torch.jit.script_method
    def str3Concat(self, input: str) -> str:
        return input + input + input

    @torch.jit.script_method
    def newEmptyShapeWithItem(self, input):
        return torch.tensor([int(input.item())])[0]

    @torch.jit.script_method
    def testAliasWithOffset(self) -> List[Tensor]:
        x = torch.tensor([100, 200])
        a = [x[0], x[1]]
        return a

    @torch.jit.script_method
    def testNonContiguous(self):
        x = torch.tensor([100, 200, 300])[::2]
        assert not x.is_contiguous()
        assert x[0] == 100
        assert x[1] == 300
        return x

    @torch.jit.script_method
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv2d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        return x.contiguous()

    @torch.jit.script_method
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last)

    @torch.jit.script_method
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last_3d)
@@ -16,7 +16,7 @@ class TSBuiltinOpsModule(torch.nn.Module):
        l = ["1", "2", "test"]
        d = {"key": 1}
        d2 = {0: 100}
-        return (
+        return len(
            # type
            bool(x),
            bool(x.item()),
@@ -111,7 +111,7 @@ class TSCollectionOpsModule(torch.nn.Module):
        d2.clear()
        d2[0] = 100

-        return (
+        return len(
            s[torch.tensor(1)],
            d["key"],
            d2[0],
@@ -1,3 +1,8 @@
+_coverage: 91.09
+_covered_ops: 358
+_generated_ops: 705
+_production_ops: 393
+_uncovered_ops: 35
all_generated_ops:
- aten::Bool.Tensor
- aten::Bool.int
@@ -642,6 +647,8 @@ all_generated_ops:
- aten::zeros
- aten::zeros.out
- aten::zeros_like
+- prepacked::conv2d_clamp_run
+- prepacked::linear_clamp_run
- prim::RaiseException
- prim::TupleIndex
- prim::TupleUnpack
@@ -1023,6 +1030,8 @@ covered_ops:
- aten::zeros
- aten::zeros.out
- aten::zeros_like
+- prepacked::conv2d_clamp_run
+- prepacked::linear_clamp_run
- prim::RaiseException
- prim::TupleIndex
- prim::TupleUnpack
@@ -1086,10 +1095,8 @@ uncovered_ops:
- aten::true_divide.Tensor
- aten::upsample_nearest2d
- prepacked::conv2d_clamp_prepack
-- prepacked::conv2d_clamp_run
- prepacked::conv2d_transpose_clamp_prepack
- prepacked::conv2d_transpose_clamp_run
-- prepacked::linear_clamp_run
- prim::ModuleContainerIndex.list
- prim::NumToTensor.Scalar
- prim::Print
@@ -1,5 +1,8 @@
+import io
+import sys
+
import torch
import yaml
+from android_api_module import AndroidAPIModule
from builtin_ops import (
    TSBuiltinOpsModule,
    TSCollectionOpsModule,
@@ -43,92 +46,188 @@ from tensor_ops import (
    TensorTypingOpsModule,
    TensorViewOpsModule,
)
+from torch.jit.mobile import _load_for_lite_interpreter
+from torchvision_models import MobileNetV2Module

-output_path = "ios/TestApp/models/"
+test_path_ios = "ios/TestApp/models/"
+test_path_android = "android/pytorch_android/src/androidTest/assets/"
+
production_ops_path = "test/mobile/model_test/model_ops.yaml"
coverage_out_path = "test/mobile/model_test/coverage.yaml"

-
-def scriptAndSave(module, name):
-    module = torch.jit.script(module)
-    return save(module, name)
-
-
-def traceAndSave(module, name):
-    module = torch.jit.trace(module, [])
-    return save(module, name)
-
-
-def save(module, name):
-    module._save_for_lite_interpreter(output_path + name + ".ptl")
-    print("model saved to " + output_path + name + ".ptl")
-    ops = torch.jit.export_opnames(module)
-    print(ops)
-    module()
-    return ops
-
-
-ops = [
+all_modules = {
    # math ops
-    scriptAndSave(PointwiseOpsModule(), "pointwise_ops"),
-    scriptAndSave(ReductionOpsModule(), "reduction_ops"),
-    scriptAndSave(ComparisonOpsModule(), "comparison_ops"),
-    scriptAndSave(OtherMathOpsModule(), "other_math_ops"),
-    scriptAndSave(SpectralOpsModule(), "spectral_ops"),
-    scriptAndSave(BlasLapackOpsModule(), "blas_lapack_ops"),
+    "pointwise_ops": PointwiseOpsModule(),
+    "reduction_ops": ReductionOpsModule(),
+    "comparison_ops": ComparisonOpsModule(),
+    "spectral_ops": SpectralOpsModule(),
+    "other_math_ops": OtherMathOpsModule(),
+    "blas_lapack_ops": BlasLapackOpsModule(),
    # sampling
-    scriptAndSave(SamplingOpsModule(), "sampling_ops"),
+    "sampling_ops": SamplingOpsModule(),
    # tensor ops
-    scriptAndSave(TensorOpsModule(), "tensor_general_ops"),
-    scriptAndSave(TensorCreationOpsModule(), "tensor_creation_ops"),
-    scriptAndSave(TensorIndexingOpsModule(), "tensor_indexing_ops"),
-    scriptAndSave(TensorTypingOpsModule(), "tensor_typing_ops"),
-    scriptAndSave(TensorViewOpsModule(), "tensor_view_ops"),
+    "tensor_general_ops": TensorOpsModule(),
+    "tensor_creation_ops": TensorCreationOpsModule(),
+    "tensor_indexing_ops": TensorIndexingOpsModule(),
+    "tensor_typing_ops": TensorTypingOpsModule(),
+    "tensor_view_ops": TensorViewOpsModule(),
    # nn ops
-    scriptAndSave(NNConvolutionModule(), "convolution_ops"),
-    scriptAndSave(NNPoolingModule(), "pooling_ops"),
-    scriptAndSave(NNPaddingModule(), "padding_ops"),
-    scriptAndSave(NNActivationModule(), "activation_ops"),
-    scriptAndSave(NNNormalizationModule(), "normalization_ops"),
-    scriptAndSave(NNRecurrentModule(), "recurrent_ops"),
-    scriptAndSave(NNTransformerModule(), "transformer_ops"),
-    scriptAndSave(NNLinearModule(), "linear_ops"),
-    scriptAndSave(NNDropoutModule(), "dropout_ops"),
-    scriptAndSave(NNSparseModule(), "sparse_ops"),
-    scriptAndSave(NNDistanceModule(), "distance_function_ops"),
-    scriptAndSave(NNLossFunctionModule(), "loss_function_ops"),
-    scriptAndSave(NNVisionModule(), "vision_function_ops"),
-    scriptAndSave(NNShuffleModule(), "shuffle_ops"),
-    scriptAndSave(NNUtilsModule(), "nn_utils_ops"),
+    "convolution_ops": NNConvolutionModule(),
+    "pooling_ops": NNPoolingModule(),
+    "padding_ops": NNPaddingModule(),
+    "activation_ops": NNActivationModule(),
+    "normalization_ops": NNNormalizationModule(),
+    "recurrent_ops": NNRecurrentModule(),
+    "transformer_ops": NNTransformerModule(),
+    "linear_ops": NNLinearModule(),
+    "dropout_ops": NNDropoutModule(),
+    "sparse_ops": NNSparseModule(),
+    "distance_function_ops": NNDistanceModule(),
+    "loss_function_ops": NNLossFunctionModule(),
+    "vision_function_ops": NNVisionModule(),
+    "shuffle_ops": NNShuffleModule(),
+    "nn_utils_ops": NNUtilsModule(),
    # quantization ops
-    scriptAndSave(GeneralQuantModule(), "general_quant_ops"),
-    scriptAndSave(DynamicQuantModule().getModule(), "dynamic_quant_ops"),
-    traceAndSave(StaticQuantModule().getModule(), "static_quant_ops"),
-    scriptAndSave(FusedQuantModule().getModule(), "fused_quant_ops"),
+    "general_quant_ops": GeneralQuantModule(),
+    "dynamic_quant_ops": DynamicQuantModule(),
+    "static_quant_ops": StaticQuantModule(),
+    "fused_quant_ops": FusedQuantModule(),
    # TorchScript builtin ops
-    scriptAndSave(TSBuiltinOpsModule(), "torchscript_builtin_ops"),
-    scriptAndSave(TSCollectionOpsModule(), "torchscript_collection_ops"),
+    "torchscript_builtin_ops": TSBuiltinOpsModule(),
+    "torchscript_collection_ops": TSCollectionOpsModule(),
+    # vision
+    "mobilenet_v2": MobileNetV2Module(),
+    # android api module
+    "android_api_module": AndroidAPIModule(),
+}
+
+models_need_trace = [
+    "static_quant_ops",
+]

-with open(production_ops_path) as input_yaml_file:
-    production_ops = yaml.safe_load(input_yaml_file)
-
-production_ops = set(production_ops["root_operators"])
-all_generated_ops = set().union(*ops)
-covered_ops = production_ops.intersection(all_generated_ops)
-uncovered_ops = production_ops - covered_ops
-coverage = 100 * len(covered_ops) / len(production_ops)
-print(
-    f"\nGenerated {len(all_generated_ops)} ops and covered "
-    + f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
-)
-with open(coverage_out_path, "w") as f:
-    yaml.safe_dump(
-        {
-            "uncovered_ops": sorted(list(uncovered_ops)),
-            "covered_ops": sorted(list(covered_ops)),
-            "all_generated_ops": sorted(list(all_generated_ops)),
-        },
-        f,
-    )
+
+def calcOpsCoverage(ops):
+    with open(production_ops_path) as input_yaml_file:
+        production_ops = yaml.safe_load(input_yaml_file)
+
+    production_ops = set(production_ops["root_operators"])
+    all_generated_ops = set(ops)
+    covered_ops = production_ops.intersection(all_generated_ops)
+    uncovered_ops = production_ops - covered_ops
+    coverage = 100 * len(covered_ops) / len(production_ops)
+    print(
+        f"\nGenerated {len(all_generated_ops)} ops and covered "
+        + f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
+    )
+    with open(coverage_out_path, "w") as f:
+        yaml.safe_dump(
+            {
+                "_covered_ops": len(covered_ops),
+                "_production_ops": len(production_ops),
+                "_generated_ops": len(all_generated_ops),
+                "_uncovered_ops": len(uncovered_ops),
+                "_coverage": round(coverage, 2),
+                "uncovered_ops": sorted(list(uncovered_ops)),
+                "covered_ops": sorted(list(covered_ops)),
+                "all_generated_ops": sorted(list(all_generated_ops)),
+            },
+            f,
+        )
+
+
+def getModuleFromName(model_name):
+    if model_name not in all_modules:
+        print("Cannot find test model for " + model_name)
+        return None, []
+
+    module = all_modules[model_name]
+    if not isinstance(module, torch.nn.Module):
+        module = module.getModule()
+
+    has_bundled_inputs = False  # module.find_method("get_all_bundled_inputs")
+
+    if model_name in models_need_trace:
+        module = torch.jit.trace(module, [])
+    else:
+        module = torch.jit.script(module)
+
+    ops = torch.jit.export_opnames(module)
+    print(ops)
+
+    # try to run the model
+    runModule(module)
+
+    return module, ops
+
+
+def runModule(module):
+    buffer = io.BytesIO(module._save_to_buffer_for_lite_interpreter())
+    buffer.seek(0)
+    lite_module = _load_for_lite_interpreter(buffer)
+    if lite_module.find_method("get_all_bundled_inputs"):
+        # run with the first bundled input
+        input = lite_module.run_method("get_all_bundled_inputs")[0]
+        lite_module.forward(*input)
+    else:
+        # assuming the model takes no input
+        lite_module()
+
+
+# Generate all models in the given folder.
+# In "on the fly" mode, add a "_temp" suffix to the model file.
+def generateAllModels(folder, on_the_fly=False):
+    all_ops = []
+    for name in all_modules:
+        module, ops = getModuleFromName(name)
+        all_ops = all_ops + ops
+        path = folder + name + ("_temp.ptl" if on_the_fly else ".ptl")
+        module._save_for_lite_interpreter(path)
+        print("model saved to " + path)
+    calcOpsCoverage(all_ops)
+
+
+# Generate/update a given model for storage.
+def generateModel(name):
+    module, ops = getModuleFromName(name)
+    if module is None:
+        return
+    path_ios = test_path_ios + name + ".ptl"
+    path_android = test_path_android + name + ".ptl"
+    module._save_for_lite_interpreter(path_ios)
+    module._save_for_lite_interpreter(path_android)
+    print("model saved to " + path_ios + " and " + path_android)
+
+
+def main(argv):
+    if argv is None or len(argv) != 1:
+        print(
+            """
+This script generates models for mobile tests. For each model we have a "storage" version
+and an "on-the-fly" version. The "on-the-fly" version is generated during the test, and
+should not be committed to the repo.
+The "storage" version is for backward-compatibility tests (a model generated today should
+run on the master branch in the next 6 months). We can use this script to update a model that
+is no longer supported.
+- use 'python gen_test_model.py android-test' to generate on-the-fly models for android
+- use 'python gen_test_model.py ios-test' to generate on-the-fly models for ios
+- use 'python gen_test_model.py android' to generate checked-in models for android
+- use 'python gen_test_model.py ios' to generate checked-in models for ios
+- use 'python gen_test_model.py <model_name_no_suffix>' to update the given storage model
+"""
+        )
+        return
+
+    if argv[0] == "android":
+        generateAllModels(test_path_android, on_the_fly=False)
+    elif argv[0] == "ios":
+        generateAllModels(test_path_ios, on_the_fly=False)
+    elif argv[0] == "android-test":
+        generateAllModels(test_path_android, on_the_fly=True)
+    elif argv[0] == "ios-test":
+        generateAllModels(test_path_ios, on_the_fly=True)
+    else:
+        generateModel(argv[0])
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
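Given the usage text above, a quick way to exercise the generator without the CLI is to call main() directly; a minimal sketch, assuming it runs from the repo root so the relative test paths resolve:

    from gen_test_model import main

    # Generate the on-the-fly ("_temp") Android test models, then print op coverage.
    main(["android-test"])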
@@ -22,7 +22,7 @@ class PointwiseOpsModule(torch.nn.Module):
        f = torch.zeros(3)
        g = torch.tensor([-1, 0, 1])
        w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
-        return (
+        return len(
            torch.abs(torch.tensor([-1, -2, 3])),
            torch.absolute(torch.tensor([-1, -2, 3])),
            torch.acos(a),
@@ -222,7 +222,7 @@ class ReductionOpsModule(torch.nn.Module):
        a = torch.randn(4)
        b = torch.randn(4)
        c = torch.tensor(0.5)
-        return (
+        return len(
            torch.argmax(a),
            torch.argmin(a),
            torch.amax(a),
@@ -271,7 +271,7 @@ class ComparisonOpsModule(torch.nn.Module):
    def forward(self):
        a = torch.tensor(0)
        b = torch.tensor(1)
-        return (
+        return len(
            torch.allclose(a, b),
            torch.argsort(a),
            torch.eq(a, b),
@@ -327,7 +327,7 @@ class OtherMathOpsModule(torch.nn.Module):
        f = torch.randn(4, 4, 4)
        size = [0, 1]
        dims = [0, 1]
-        return (
+        return len(
            torch.atleast_1d(a),
            torch.atleast_2d(a),
            torch.atleast_3d(a),
@@ -380,7 +380,7 @@ class OtherMathOpsModule(torch.nn.Module):
            torch.triu_indices(3, 3),
            torch.vander(a),
            torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
-            torch.view_as_complex(torch.randn(4, 2)),
+            torch.view_as_complex(torch.randn(4, 2)).real,
            torch.resolve_conj(a),
            torch.resolve_neg(a),
        )
@@ -396,7 +396,7 @@ class SpectralOpsModule(torch.nn.Module):
    def spectral_ops(self):
        a = torch.randn(10)
        b = torch.randn(10, 8, 4, 2)
-        return (
+        return len(
            torch.stft(a, 8),
            torch.stft(a, torch.tensor(8)),
            torch.istft(b, 8),
@@ -420,7 +420,7 @@ class BlasLapackOpsModule(torch.nn.Module):
        a = torch.randn(10, 3, 4)
        b = torch.randn(10, 4, 3)
        v = torch.randn(3)
-        return (
+        return len(
            torch.addbmm(m, a, b),
            torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
            torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
@@ -31,11 +31,11 @@ class NNConvolutionModule(torch.nn.Module):
        )

    def forward(self):
-        return (
+        return len((
            [module(self.input1d) for i, module in enumerate(self.module1d)],
            [module(self.input2d) for i, module in enumerate(self.module2d)],
            [module(self.input3d) for i, module in enumerate(self.module3d)],
-        )
+        ))


class NNPoolingModule(torch.nn.Module):
@@ -77,11 +77,11 @@ class NNPoolingModule(torch.nn.Module):
        # TODO max_unpool

    def forward(self):
-        return (
+        return len((
            [module(self.input1d) for i, module in enumerate(self.module1d)],
            [module(self.input2d) for i, module in enumerate(self.module2d)],
            [module(self.input3d) for i, module in enumerate(self.module3d)],
-        )
+        ))


class NNPaddingModule(torch.nn.Module):
@@ -116,11 +116,11 @@ class NNPaddingModule(torch.nn.Module):
        )

    def forward(self):
-        return (
+        return len((
            [module(self.input1d) for i, module in enumerate(self.module1d)],
            [module(self.input2d) for i, module in enumerate(self.module2d)],
            [module(self.input3d) for i, module in enumerate(self.module3d)],
-        )
+        ))


class NNNormalizationModule(torch.nn.Module):
@@ -155,11 +155,11 @@ class NNNormalizationModule(torch.nn.Module):
        )

    def forward(self):
-        return (
+        return len((
            [module(self.input1d) for i, module in enumerate(self.module1d)],
            [module(self.input2d) for i, module in enumerate(self.module2d)],
            [module(self.input3d) for i, module in enumerate(self.module3d)],
-        )
+        ))


class NNActivationModule(torch.nn.Module):
@@ -202,9 +202,9 @@ class NNActivationModule(torch.nn.Module):

    def forward(self):
        input = torch.randn(2, 3, 4)
-        for i, module in enumerate(self.activations):
-            x = module(input)
-        return x
+        return len((
+            [module(input) for i, module in enumerate(self.activations)],
+        ))


class NNRecurrentModule(torch.nn.Module):
@@ -228,14 +228,13 @@ class NNRecurrentModule(torch.nn.Module):
        input = torch.randn(5, 3, 4)
        h = torch.randn(2, 3, 8)
        c = torch.randn(2, 3, 8)
-        return (
-            self.rnn[0](input, h),
-            self.rnn[1](input[0], h[0]),
-            self.gru[0](input, h),
-            self.gru[1](input[0], h[0]),
-            self.lstm[0](input, (h, c)),
-            self.lstm[1](input[0], (h[0], c[0])),
-        )
+        r = self.rnn[0](input, h)
+        r = self.rnn[1](input[0], h[0])
+        r = self.gru[0](input, h)
+        r = self.gru[1](input[0], h[0])
+        r = self.lstm[0](input, (h, c))
+        r = self.lstm[1](input[0], (h[0], c[0]))
+        return len(r)


class NNTransformerModule(torch.nn.Module):
@@ -258,11 +257,10 @@ class NNTransformerModule(torch.nn.Module):
    def forward(self):
        input = torch.rand(1, 16, 2)
        tgt = torch.rand((1, 16, 2))
-        return (
-            self.transformers[0](input, tgt),
-            self.transformers[1](input),
-            self.transformers[2](input, tgt),
-        )
+        r = self.transformers[0](input, tgt)
+        r = self.transformers[1](input)
+        r = self.transformers[2](input, tgt)
+        return len(r)


class NNLinearModule(torch.nn.Module):
@@ -279,11 +277,10 @@ class NNLinearModule(torch.nn.Module):

    def forward(self):
        input = torch.randn(32, 20)
-        return (
-            self.linears[0](input),
-            self.linears[1](input),
-            self.linears[2](input, input),
-        )
+        r = self.linears[0](input)
+        r = self.linears[1](input)
+        r = self.linears[2](input, input)
+        return len(r)


class NNDropoutModule(torch.nn.Module):
@@ -294,7 +291,7 @@ class NNDropoutModule(torch.nn.Module):
        a = torch.randn(8, 4)
        b = torch.randn(8, 4, 4, 4)
        c = torch.randn(8, 4, 4, 4, 4)
-        return (
+        return len(
            F.dropout(a),
            F.dropout2d(b),
            F.dropout3d(c),
@@ -312,7 +309,7 @@ class NNSparseModule(torch.nn.Module):
        input2 = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        embedding_matrix = torch.rand(10, 3)
        offsets = torch.tensor([0, 4])
-        return (
+        return len(
            F.embedding(input, embedding_matrix),
            F.embedding_bag(input2, embedding_matrix, offsets),
            F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
@@ -326,7 +323,7 @@ class NNDistanceModule(torch.nn.Module):
    def forward(self):
        a = torch.randn(8, 4)
        b = torch.randn(8, 4)
-        return (
+        return len(
            F.pairwise_distance(a, b),
            F.cosine_similarity(a, b),
            F.pdist(a),
@@ -347,7 +344,7 @@ class NNLossFunctionModule(torch.nn.Module):
        targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
        input_lengths = torch.full((16,), 50, dtype=torch.long)
        target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
-        return (
+        return len(
            F.binary_cross_entropy(torch.sigmoid(a), b),
            F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
            F.poisson_nll_loss(a, b),
@@ -392,8 +389,10 @@ class NNVisionModule(torch.nn.Module):

    def forward(self):
        input = torch.randn(1, 3, 16, 16)
-        return (
-            [module(self.input) for i, module in enumerate(self.vision_modules)],
+        for i, module in enumerate(self.vision_modules):
+            r = module(self.input)
+        return len(
+            r,
            self.linear_sample(torch.randn(4, 9, 9)),
            self.trilinear_sample(torch.randn(1, 3, 4, 9, 9)),
            F.grid_sample(input, torch.ones(1, 4, 4, 2)),
@@ -406,7 +405,7 @@ class NNShuffleModule(torch.nn.Module):
        self.shuffle = nn.ChannelShuffle(2)

    def forward(self):
-        return (self.shuffle(torch.randn(1, 4, 2, 2)),)
+        return len(self.shuffle(torch.randn(1, 4, 2, 2)),)


class NNUtilsModule(torch.nn.Module):
@@ -419,6 +418,6 @@ class NNUtilsModule(torch.nn.Module):

    def forward(self):
        input = torch.randn(2, 50)
-        return (
+        return len(
            self.flatten(input),
        )
@@ -23,7 +23,7 @@ class GeneralQuantModule(torch.nn.Module):
        input1 = torch.randn(1, 16, 4)
        input2 = torch.randn(1, 16, 4, 4)
        input3 = torch.randn(1, 16, 4, 4, 4)
-        return (
+        return len(
            self.func.add(a, b),
            self.func.cat((a, a), 0),
            self.func.mul(a, b),
@@ -92,7 +92,7 @@ class DynamicQuantModule:
            trans_input = torch.randn(1, 16, 2)
            tgt = torch.rand(1, 16, 2)

-            return (
+            return len((
                self.rnn(input, h),
                self.rnncell(input[0], h[0]),
                self.gru(input, h),
@@ -105,7 +105,7 @@ class DynamicQuantModule:
                self.linears[0](linear_input),
                self.linears[1](linear_input),
                self.linears[2](linear_input, linear_input),
-            )
+            ))


class StaticQuantModule:
@@ -11,7 +11,7 @@ class SamplingOpsModule(torch.nn.Module):
        a = torch.empty(3, 3).uniform_(0.0, 1.0)
        size = (1, 4)
        weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
-        return (
+        return len(
            # torch.seed(),
            # torch.manual_seed(0),
            torch.bernoulli(a),
@@ -15,7 +15,7 @@ class TensorOpsModule(torch.nn.Module):
        c = torch.randn(4, dtype=torch.cfloat)
        w = torch.rand(4, 4, 4, 4)
        v = torch.rand(4, 4, 4, 4)
-        return (
+        return len(
            # torch.is_tensor(a),
            # torch.is_storage(a),
            torch.is_complex(a),
@@ -120,7 +120,7 @@ class TensorCreationOpsModule(torch.nn.Module):
            0,
            torch.quint8,
        )
-        return (
+        return len(
            torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
            # torch.sparse_coo_tensor(i, v, [2, 3]),  # does not work on iOS
            torch.as_tensor([1, 2, 3]),
@@ -171,7 +171,7 @@ class TensorIndexingOpsModule(torch.nn.Module):
        t = torch.tensor([[0, 0], [1, 0]])
        mask = x.ge(0.5)
        i = [0, 1]
-        return (
+        return len(
            torch.cat((x, x, x), 0),
            torch.concat((x, x, x), 0),
            torch.conj(x),
@@ -233,7 +233,7 @@ class TensorTypingOpsModule(torch.nn.Module):

    def tensor_typing_ops(self):
        x = torch.randn(1, 3, 4, 4)
-        return (
+        return len(
            x.to(torch.float),
            x.to(torch.double),
            x.to(torch.cfloat),
@@ -262,7 +262,7 @@ class TensorViewOpsModule(torch.nn.Module):
    def tensor_view_ops(self):
        x = torch.randn(4, 4, 1)
        y = torch.randn(4, 4, 2)
-        return (
+        return len(
            x[0, 2:],
            x.detach(),
            x.detach_(),
test/mobile/model_test/torchvision_models.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import torch
import torchvision
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile


class MobileNetV2Module:
    def __init__(self):
        super(MobileNetV2Module, self).__init__()

    def getModule(self):
        model = torchvision.models.mobilenet_v2(pretrained=True)
        model.eval()
        example = torch.zeros(1, 3, 224, 224)
        traced_script_module = torch.jit.trace(model, example)
        optimized_module = optimize_for_mobile(traced_script_module)
        augment_model_with_bundled_inputs(
            optimized_module,
            [
                (example, ),
            ],
        )
        optimized_module(example)
        return optimized_module