Revert "add model test for Android"

This reverts commit 91ef3c82615d6ede05d5b86f1bd5571ea95e4ef1.

Reverted https://github.com/pytorch/pytorch/pull/74793 on behalf of https://github.com/seemethere
This commit is contained in:
PyTorch MergeBot
2022-03-29 22:08:22 +00:00
parent fc47257b30
commit 33b9726e6b
48 changed files with 524 additions and 853 deletions

View File

@ -235,23 +235,6 @@ To load torchscript model for mobile we need some special setup which is placed
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
## Running Instrumentation Test
Test using lite interpreter
```
./android/run_tests.sh
```
Test using full JIT
```
BUILD_LITE_INTERPRETER=0 ./android/run_tests.sh
```
Regenerate all test models
```
python test/mobile/model_test/gen_test_model.py android
```
## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).

View File

@ -1,12 +1,22 @@
from typing import Dict, List, Tuple, Optional
import torch
from torch import Tensor
from typing import Dict, List, Tuple, Optional
OUTPUT_DIR = "src/androidTest/assets/"
class AndroidAPIModule(torch.jit.ScriptModule):
def scriptAndSave(module, fileName):
print('-' * 80)
script_module = torch.jit.script(module)
print(script_module.graph)
outputFileName = OUTPUT_DIR + fileName
# note that the lite interpreter model can also be used in full JIT
script_module._save_for_lite_interpreter(outputFileName)
print("Saved to " + outputFileName)
print('=' * 80)
class Test(torch.jit.ScriptModule):
def __init__(self):
super(AndroidAPIModule, self).__init__()
super(Test, self).__init__()
@torch.jit.script_method
def forward(self, input):
@ -66,9 +76,7 @@ class AndroidAPIModule(torch.jit.ScriptModule):
return res
@torch.jit.script_method
def tupleIntSumReturnTuple(
self, input: Tuple[int, int, int]
) -> Tuple[Tuple[int, int, int], int]:
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
@ -109,7 +117,7 @@ class AndroidAPIModule(torch.jit.ScriptModule):
@torch.jit.script_method
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.nn.functional.conv2d(x, w)
if toChannelsLast:
if (toChannelsLast):
r = r.contiguous(memory_format=torch.channels_last)
else:
r = r.contiguous()
@ -126,3 +134,5 @@ class AndroidAPIModule(torch.jit.ScriptModule):
@torch.jit.script_method
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
return x.contiguous(memory_format=torch.channels_last_3d)
scriptAndSave(Test(), "test.pt")

Binary file not shown.

View File

@ -9,577 +9,385 @@ import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
public abstract class PytorchTestBase {
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
@Test
public void testForwardNull() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[]{1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
@Test
public void testForwardNull() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
}
@Test
public void testEqBool() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.toBool());
}
}
@Test
public void testEqInt() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.toLong());
}
}
@Test
public void testEqFloat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[] {
-Double.MAX_VALUE,
Double.MAX_VALUE,
-Double.MIN_VALUE,
Double.MIN_VALUE,
-Math.exp(1.d),
-Math.sqrt(2.d),
-3.1415f,
3.1415f,
-1,
0,
1,
};
for (double value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.toDouble());
}
}
@Test
public void testEqTensor() throws IOException {
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
}
}
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[] {0, 1, 128}) {
long[] a = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.toTuple().length);
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
@Test
public void testEqBool() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[]{false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.toBool());
}
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test
public void testEqInt() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[]{Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.toLong());
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@Test
public void testEqString() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.toStr()));
}
}
@Test
public void testEqFloat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[]{
-Double.MAX_VALUE,
Double.MAX_VALUE,
-Double.MIN_VALUE,
Double.MIN_VALUE,
-Math.exp(1.d),
-Math.sqrt(2.d),
-3.1415f,
3.1415f,
-1,
0,
1,
};
for (double value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.toDouble());
}
@Test
public void testStr3Concat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput =
new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.toStr()));
}
@Test
public void testEqTensor() throws IOException {
final long[] inputTensorShape = new long[]{1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
}
}
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[]{0, 1, 128}) {
long[] a = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.toTuple().length);
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] shape = new long[]{1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] shape = new long[]{1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@Test
public void testEqString() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[]{
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.toStr()));
}
}
@Test
public void testStr3Concat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[]{
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput =
new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.toStr()));
}
}
@Test
public void testEmptyShape() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[]{someNumber}, new long[]{}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[]{}, value.shape());
assertArrayEquals(new long[]{someNumber}, value.getDataAsLongArray());
}
@Test
public void testAliasWithOffset() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
}
@Test
public void testNonContiguous() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[]{2}, value.shape());
assertArrayEquals(new long[]{100, 300}, value.getDataAsLongArray());
}
@Test
public void testChannelsLast() throws IOException {
long[] inputShape = new long[]{1, 3, 2, 2};
long[] data = new long[]{1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[]{1, 3, 2, 2},
new long[]{1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
}
@Test
public void testChannelsLast3d() throws IOException {
long[] shape = new long[]{1, 2, 2, 2, 2};
long[] dataNCHWD = new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
long[] dataNHWDC = new long[]{1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
final IValue outputNHWDC =
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
}
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[]{1, 3, 2, 2};
long[] dataNCHW = new long[]{1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNHWC = new long[]{1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[]{3, 3, 1, 1};
long[] dataWeightOIHW = new long[]{2, 0, 0, 0, 1, 0, 0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightOHWI = new long[]{2, 0, 0, 0, 1, 0, 0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[]{1, 3, 2, 2},
new long[]{2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
assertIValueTensor(
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[]{1, 3, 2, 2},
new long[]{2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
}
@Test
public void testMobileNetV2() throws IOException {
try {
final Module module = loadModel("mobilenet_v2.ptl");
final IValue inputs = module.runMethod("get_all_bundled_inputs");
assertTrue(inputs.isList());
final IValue input = inputs.toList()[0];
assertTrue(input.isTuple());
module.forward(input.toTuple()[0]);
assertTrue(true);
} catch (Exception ex) {
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
}
}
@Test
public void testPointwiseOps() throws IOException {
runModel("pointwise_ops");
}
@Test
public void testReductionOps() throws IOException {
runModel("reduction_ops");
}
@Test
public void testComparisonOps() throws IOException {
runModel("comparison_ops");
}
@Test
public void testOtherMathOps() throws IOException {
runModel("other_math_ops");
}
@Test
public void testSpectralOps() throws IOException {
runModel("spectral_ops");
}
@Test
public void testBlasLapackOps() throws IOException {
runModel("blas_lapack_ops");
}
@Test
public void testSamplingOps() throws IOException {
runModel("sampling_ops");
}
@Test
public void testTensorOps() throws IOException {
runModel("tensor_general_ops");
}
@Test
public void testTensorCreationOps() throws IOException {
runModel("tensor_creation_ops");
}
@Test
public void testTensorIndexingOps() throws IOException {
runModel("tensor_indexing_ops");
}
@Test
public void testTensorTypingOps() throws IOException {
runModel("tensor_typing_ops");
}
@Test
public void testTensorViewOps() throws IOException {
runModel("tensor_view_ops");
}
@Test
public void testConvolutionOps() throws IOException {
runModel("convolution_ops");
}
@Test
public void testPoolingOps() throws IOException {
runModel("pooling_ops");
}
@Test
public void testPaddingOps() throws IOException {
runModel("padding_ops");
}
@Test
public void testActivationOps() throws IOException {
runModel("activation_ops");
}
@Test
public void testNormalizationOps() throws IOException {
runModel("normalization_ops");
}
@Test
public void testRecurrentOps() throws IOException {
runModel("recurrent_ops");
}
@Test
public void testTransformerOps() throws IOException {
runModel("transformer_ops");
}
@Test
public void testLinearOps() throws IOException {
runModel("linear_ops");
}
@Test
public void testDropoutOps() throws IOException {
runModel("dropout_ops");
}
@Test
public void testSparseOps() throws IOException {
runModel("sparse_ops");
}
@Test
public void testDistanceFunctionOps() throws IOException {
runModel("distance_function_ops");
}
@Test
public void testLossFunctionOps() throws IOException {
runModel("loss_function_ops");
}
@Test
public void testVisionFunctionOps() throws IOException {
runModel("vision_function_ops");
}
@Test
public void testShuffleOps() throws IOException {
runModel("shuffle_ops");
}
@Test
public void testNNUtilsOps() throws IOException {
runModel("nn_utils_ops");
}
@Test
public void testQuantOps() throws IOException {
runModel("general_quant_ops");
}
@Test
public void testDynamicQuantOps() throws IOException {
runModel("dynamic_quant_ops");
}
@Test
public void testStaticQuantOps() throws IOException {
runModel("static_quant_ops");
}
@Test
public void testFusedQuantOps() throws IOException {
runModel("fused_quant_ops");
}
@Test
public void testTorchScriptBuiltinQuantOps() throws IOException {
runModel("torchscript_builtin_ops");
}
@Test
public void testTorchScriptCollectionQuantOps() throws IOException {
runModel("torchscript_collection_ops");
}
static void assertIValueTensor(
final IValue ivalue,
final MemoryFormat memoryFormat,
final long[] expectedShape,
final long[] expectedData) {
assertTrue(ivalue.isTensor());
Tensor t = ivalue.toTensor();
assertEquals(memoryFormat, t.memoryFormat());
assertArrayEquals(expectedShape, t.shape());
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
void runModel(final String name) throws IOException {
final Module storage_module = loadModel(name + ".ptl");
storage_module.forward();
// TODO enable this once the on-the-fly script is ready
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
// on_the_fly_module.forward();
assertTrue(true);
}
protected abstract Module loadModel(String assetName) throws IOException;
}
@Test
public void testEmptyShape() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {}, value.shape());
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
}
@Test
public void testAliasWithOffset() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
}
@Test
public void testNonContiguous() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {2}, value.shape());
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
}
@Test
public void testChannelsLast() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
}
@Test
public void testChannelsLast3d() throws IOException {
long[] shape = new long[] {1, 2, 2, 2, 2};
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
final IValue outputNHWDC =
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
}
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[] {3, 3, 1, 1};
long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
assertIValueTensor(
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[] {1, 3, 2, 2},
new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
}
static void assertIValueTensor(
final IValue ivalue,
final MemoryFormat memoryFormat,
final long[] expectedShape,
final long[] expectedData) {
assertTrue(ivalue.isTensor());
Tensor t = ivalue.toTensor();
assertEquals(memoryFormat, t.memoryFormat());
assertArrayEquals(expectedShape, t.shape());
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
protected abstract Module loadModel(String assetName) throws IOException;
}

View File

@ -16,7 +16,7 @@ class TSBuiltinOpsModule(torch.nn.Module):
l = ["1", "2", "test"]
d = {"key": 1}
d2 = {0: 100}
return len(
return (
# type
bool(x),
bool(x.item()),
@ -111,7 +111,7 @@ class TSCollectionOpsModule(torch.nn.Module):
d2.clear()
d2[0] = 100
return len(
return (
s[torch.tensor(1)],
d["key"],
d2[0],

View File

@ -1,8 +1,3 @@
_coverage: 91.09
_covered_ops: 358
_generated_ops: 705
_production_ops: 393
_uncovered_ops: 35
all_generated_ops:
- aten::Bool.Tensor
- aten::Bool.int
@ -647,8 +642,6 @@ all_generated_ops:
- aten::zeros
- aten::zeros.out
- aten::zeros_like
- prepacked::conv2d_clamp_run
- prepacked::linear_clamp_run
- prim::RaiseException
- prim::TupleIndex
- prim::TupleUnpack
@ -1030,8 +1023,6 @@ covered_ops:
- aten::zeros
- aten::zeros.out
- aten::zeros_like
- prepacked::conv2d_clamp_run
- prepacked::linear_clamp_run
- prim::RaiseException
- prim::TupleIndex
- prim::TupleUnpack
@ -1095,8 +1086,10 @@ uncovered_ops:
- aten::true_divide.Tensor
- aten::upsample_nearest2d
- prepacked::conv2d_clamp_prepack
- prepacked::conv2d_clamp_run
- prepacked::conv2d_transpose_clamp_prepack
- prepacked::conv2d_transpose_clamp_run
- prepacked::linear_clamp_run
- prim::ModuleContainerIndex.list
- prim::NumToTensor.Scalar
- prim::Print

View File

@ -1,9 +1,5 @@
import io
import sys
import torch
import yaml
from android_api_module import AndroidAPIModule
from builtin_ops import (
TSBuiltinOpsModule,
TSCollectionOpsModule,
@ -47,188 +43,92 @@ from tensor_ops import (
TensorTypingOpsModule,
TensorViewOpsModule,
)
from torch.jit.mobile import _load_for_lite_interpreter
from torchvision_models import MobileNetV2Module
test_path_ios = "ios/TestApp/models/"
test_path_android = "android/pytorch_android/src/androidTest/assets/"
output_path = "ios/TestApp/models/"
production_ops_path = "test/mobile/model_test/model_ops.yaml"
coverage_out_path = "test/mobile/model_test/coverage.yaml"
all_modules = {
# math ops
"pointwise_ops": PointwiseOpsModule(),
"reduction_ops": ReductionOpsModule(),
"comparison_ops": ComparisonOpsModule(),
"spectral_ops": SpectralOpsModule(),
"other_math_ops": OtherMathOpsModule(),
"blas_lapack_ops": BlasLapackOpsModule(),
# sampling
"sampling_ops": SamplingOpsModule(),
# tensor ops
"tensor_general_ops": TensorOpsModule(),
"tensor_creation_ops": TensorCreationOpsModule(),
"tensor_indexing_ops": TensorIndexingOpsModule(),
"tensor_typing_ops": TensorTypingOpsModule(),
"tensor_view_ops": TensorViewOpsModule(),
# nn ops
"convolution_ops": NNConvolutionModule(),
"pooling_ops": NNPoolingModule(),
"padding_ops": NNPaddingModule(),
"activation_ops": NNActivationModule(),
"normalization_ops": NNNormalizationModule(),
"recurrent_ops": NNRecurrentModule(),
"transformer_ops": NNTransformerModule(),
"linear_ops": NNLinearModule(),
"dropout_ops": NNDropoutModule(),
"sparse_ops": NNSparseModule(),
"distance_function_ops": NNDistanceModule(),
"loss_function_ops": NNLossFunctionModule(),
"vision_function_ops": NNVisionModule(),
"shuffle_ops": NNShuffleModule(),
"nn_utils_ops": NNUtilsModule(),
# quantization ops
"general_quant_ops": GeneralQuantModule(),
"dynamic_quant_ops": DynamicQuantModule(),
"static_quant_ops": StaticQuantModule(),
"fused_quant_ops": FusedQuantModule(),
# TorchScript buildin ops
"torchscript_builtin_ops": TSBuiltinOpsModule(),
"torchscript_collection_ops": TSCollectionOpsModule(),
# vision
"mobilenet_v2": MobileNetV2Module(),
# android api module
"android_api_module": AndroidAPIModule(),
}
models_need_trace = [
"static_quant_ops",
def scriptAndSave(module, name):
module = torch.jit.script(module)
return save(module, name)
def traceAndSave(module, name):
module = torch.jit.trace(module, [])
return save(module, name)
def save(module, name):
module._save_for_lite_interpreter(output_path + name + ".ptl")
print("model saved to " + output_path + name + ".ptl")
ops = torch.jit.export_opnames(module)
print(ops)
module()
return ops
ops = [
# math ops
scriptAndSave(PointwiseOpsModule(), "pointwise_ops"),
scriptAndSave(ReductionOpsModule(), "reduction_ops"),
scriptAndSave(ComparisonOpsModule(), "comparison_ops"),
scriptAndSave(OtherMathOpsModule(), "other_math_ops"),
scriptAndSave(SpectralOpsModule(), "spectral_ops"),
scriptAndSave(BlasLapackOpsModule(), "blas_lapack_ops"),
# sampling
scriptAndSave(SamplingOpsModule(), "sampling_ops"),
# tensor ops
scriptAndSave(TensorOpsModule(), "tensor_general_ops"),
scriptAndSave(TensorCreationOpsModule(), "tensor_creation_ops"),
scriptAndSave(TensorIndexingOpsModule(), "tensor_indexing_ops"),
scriptAndSave(TensorTypingOpsModule(), "tensor_typing_ops"),
scriptAndSave(TensorViewOpsModule(), "tensor_view_ops"),
# nn ops
scriptAndSave(NNConvolutionModule(), "convolution_ops"),
scriptAndSave(NNPoolingModule(), "pooling_ops"),
scriptAndSave(NNPaddingModule(), "padding_ops"),
scriptAndSave(NNActivationModule(), "activation_ops"),
scriptAndSave(NNNormalizationModule(), "normalization_ops"),
scriptAndSave(NNRecurrentModule(), "recurrent_ops"),
scriptAndSave(NNTransformerModule(), "transformer_ops"),
scriptAndSave(NNLinearModule(), "linear_ops"),
scriptAndSave(NNDropoutModule(), "dropout_ops"),
scriptAndSave(NNSparseModule(), "sparse_ops"),
scriptAndSave(NNDistanceModule(), "distance_function_ops"),
scriptAndSave(NNLossFunctionModule(), "loss_function_ops"),
scriptAndSave(NNVisionModule(), "vision_function_ops"),
scriptAndSave(NNShuffleModule(), "shuffle_ops"),
scriptAndSave(NNUtilsModule(), "nn_utils_ops"),
# quantization ops
scriptAndSave(GeneralQuantModule(), "general_quant_ops"),
scriptAndSave(DynamicQuantModule().getModule(), "dynamic_quant_ops"),
traceAndSave(StaticQuantModule().getModule(), "static_quant_ops"),
scriptAndSave(FusedQuantModule().getModule(), "fused_quant_ops"),
# TorchScript buildin ops
scriptAndSave(TSBuiltinOpsModule(), "torchscript_builtin_ops"),
scriptAndSave(TSCollectionOpsModule(), "torchscript_collection_ops"),
]
def calcOpsCoverage(ops):
with open(production_ops_path) as input_yaml_file:
production_ops = yaml.safe_load(input_yaml_file)
with open(production_ops_path) as input_yaml_file:
production_ops = yaml.safe_load(input_yaml_file)
production_ops = set(production_ops["root_operators"])
all_generated_ops = set(ops)
covered_ops = production_ops.intersection(all_generated_ops)
uncovered_ops = production_ops - covered_ops
coverage = 100 * len(covered_ops) / len(production_ops)
print(
f"\nGenerated {len(all_generated_ops)} ops and covered "
+ f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
production_ops = set(production_ops["root_operators"])
all_generated_ops = set().union(*ops)
covered_ops = production_ops.intersection(all_generated_ops)
uncovered_ops = production_ops - covered_ops
coverage = 100 * len(covered_ops) / len(production_ops)
print(
f"\nGenerated {len(all_generated_ops)} ops and covered "
+ f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
)
with open(coverage_out_path, "w") as f:
yaml.safe_dump(
{
"uncovered_ops": sorted(list(uncovered_ops)),
"covered_ops": sorted(list(covered_ops)),
"all_generated_ops": sorted(list(all_generated_ops)),
},
f,
)
with open(coverage_out_path, "w") as f:
yaml.safe_dump(
{
"_covered_ops": len(covered_ops),
"_production_ops": len(production_ops),
"_generated_ops": len(all_generated_ops),
"_uncovered_ops": len(uncovered_ops),
"_coverage": round(coverage, 2),
"uncovered_ops": sorted(list(uncovered_ops)),
"covered_ops": sorted(list(covered_ops)),
"all_generated_ops": sorted(list(all_generated_ops)),
},
f,
)
def getModuleFromName(model_name):
if model_name not in all_modules:
print("Cannot find test model for " + model_name)
return None, []
module = all_modules[model_name]
if not isinstance(module, torch.nn.Module):
module = module.getModule()
has_bundled_inputs = False # module.find_method("get_all_bundled_inputs")
if model_name in models_need_trace:
module = torch.jit.trace(module, [])
else:
module = torch.jit.script(module)
ops = torch.jit.export_opnames(module)
print(ops)
# try to run the model
runModule(module)
return module, ops
def runModule(module):
buffer = io.BytesIO(module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
lite_module = _load_for_lite_interpreter(buffer)
if lite_module.find_method("get_all_bundled_inputs"):
# run with the first bundled input
input = lite_module.run_method("get_all_bundled_inputs")[0]
lite_module.forward(*input)
else:
# assuming model has no input
lite_module()
# generate all models in the given folder.
# If it's "on the fly" mode, add "_temp" suffix to the model file.
def generateAllModels(folder, on_the_fly=False):
all_ops = []
for name in all_modules:
module, ops = getModuleFromName(name)
all_ops = all_ops + ops
path = folder + name + ("_temp.ptl" if on_the_fly else ".ptl")
module._save_for_lite_interpreter(path)
print("model saved to " + path)
calcOpsCoverage(all_ops)
# generate/update a given model for storage
def generateModel(name):
module, ops = getModuleFromName(name)
if module is None:
return
path_ios = test_path_ios + name + ".ptl"
path_android = test_path_android + name + ".ptl"
module._save_for_lite_interpreter(path_ios)
module._save_for_lite_interpreter(path_android)
print("model saved to " + path_ios + " and " + path_android)
def main(argv):
if argv is None or len(argv) != 1:
print(
"""
This script generate models for mobile test. For each model we have a "storage" version
and an "on-the-fly" version. The "on-the-fly" version will be generated during test,and
should not be committed to the repo.
The "storage" version is for back compatibility # test (a model generated today should
run on master branch in the next 6 months). We can use this script to update a model that
is no longer supported.
- use 'python gen_test_model.py android-test' to generate on-the-fly models for android
- use 'python gen_test_model.py ios-test' to generate on-the-fly models for ios
- use 'python gen_test_model.py android' to generate checked-in models for android
- use 'python gen_test_model.py ios' to generate on-the-fly models for ios
- use 'python gen_test_model.py <model_name_no_suffix>' to update the given storage model
"""
)
return
if argv[0] == "android":
generateAllModels(test_path_android, on_the_fly=False)
elif argv[0] == "ios":
generateAllModels(test_path_ios, on_the_fly=False)
elif argv[0] == "android-test":
generateAllModels(test_path_android, on_the_fly=True)
elif argv[0] == "ios-test":
generateAllModels(test_path_ios, on_the_fly=True)
else:
generateModel(argv[0])
if __name__ == "__main__":
main(sys.argv[1:])

View File

@ -22,7 +22,7 @@ class PointwiseOpsModule(torch.nn.Module):
f = torch.zeros(3)
g = torch.tensor([-1, 0, 1])
w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
return len(
return (
torch.abs(torch.tensor([-1, -2, 3])),
torch.absolute(torch.tensor([-1, -2, 3])),
torch.acos(a),
@ -222,7 +222,7 @@ class ReductionOpsModule(torch.nn.Module):
a = torch.randn(4)
b = torch.randn(4)
c = torch.tensor(0.5)
return len(
return (
torch.argmax(a),
torch.argmin(a),
torch.amax(a),
@ -271,7 +271,7 @@ class ComparisonOpsModule(torch.nn.Module):
def forward(self):
a = torch.tensor(0)
b = torch.tensor(1)
return len(
return (
torch.allclose(a, b),
torch.argsort(a),
torch.eq(a, b),
@ -327,7 +327,7 @@ class OtherMathOpsModule(torch.nn.Module):
f = torch.randn(4, 4, 4)
size = [0, 1]
dims = [0, 1]
return len(
return (
torch.atleast_1d(a),
torch.atleast_2d(a),
torch.atleast_3d(a),
@ -380,7 +380,7 @@ class OtherMathOpsModule(torch.nn.Module):
torch.triu_indices(3, 3),
torch.vander(a),
torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
torch.view_as_complex(torch.randn(4, 2)).real,
torch.view_as_complex(torch.randn(4, 2)),
torch.resolve_conj(a),
torch.resolve_neg(a),
)
@ -396,7 +396,7 @@ class SpectralOpsModule(torch.nn.Module):
def spectral_ops(self):
a = torch.randn(10)
b = torch.randn(10, 8, 4, 2)
return len(
return (
torch.stft(a, 8),
torch.stft(a, torch.tensor(8)),
torch.istft(b, 8),
@ -420,7 +420,7 @@ class BlasLapackOpsModule(torch.nn.Module):
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 3)
v = torch.randn(3)
return len(
return (
torch.addbmm(m, a, b),
torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),

View File

@ -31,11 +31,11 @@ class NNConvolutionModule(torch.nn.Module):
)
def forward(self):
return len((
return (
[module(self.input1d) for i, module in enumerate(self.module1d)],
[module(self.input2d) for i, module in enumerate(self.module2d)],
[module(self.input3d) for i, module in enumerate(self.module3d)],
))
)
class NNPoolingModule(torch.nn.Module):
@ -77,11 +77,11 @@ class NNPoolingModule(torch.nn.Module):
# TODO max_unpool
def forward(self):
return len((
return (
[module(self.input1d) for i, module in enumerate(self.module1d)],
[module(self.input2d) for i, module in enumerate(self.module2d)],
[module(self.input3d) for i, module in enumerate(self.module3d)],
))
)
class NNPaddingModule(torch.nn.Module):
@ -116,11 +116,11 @@ class NNPaddingModule(torch.nn.Module):
)
def forward(self):
return len((
return (
[module(self.input1d) for i, module in enumerate(self.module1d)],
[module(self.input2d) for i, module in enumerate(self.module2d)],
[module(self.input3d) for i, module in enumerate(self.module3d)],
))
)
class NNNormalizationModule(torch.nn.Module):
@ -155,11 +155,11 @@ class NNNormalizationModule(torch.nn.Module):
)
def forward(self):
return len((
return (
[module(self.input1d) for i, module in enumerate(self.module1d)],
[module(self.input2d) for i, module in enumerate(self.module2d)],
[module(self.input3d) for i, module in enumerate(self.module3d)],
))
)
class NNActivationModule(torch.nn.Module):
@ -202,9 +202,9 @@ class NNActivationModule(torch.nn.Module):
def forward(self):
input = torch.randn(2, 3, 4)
return len((
[module(input) for i, module in enumerate(self.activations)],
))
for i, module in enumerate(self.activations):
x = module(input)
return x
class NNRecurrentModule(torch.nn.Module):
@ -228,13 +228,14 @@ class NNRecurrentModule(torch.nn.Module):
input = torch.randn(5, 3, 4)
h = torch.randn(2, 3, 8)
c = torch.randn(2, 3, 8)
r = self.rnn[0](input, h)
r = self.rnn[1](input[0], h[0])
r = self.gru[0](input, h)
r = self.gru[1](input[0], h[0])
r = self.lstm[0](input, (h, c))
r = self.lstm[1](input[0], (h[0], c[0]))
return len(r)
return (
self.rnn[0](input, h),
self.rnn[1](input[0], h[0]),
self.gru[0](input, h),
self.gru[1](input[0], h[0]),
self.lstm[0](input, (h, c)),
self.lstm[1](input[0], (h[0], c[0])),
)
class NNTransformerModule(torch.nn.Module):
@ -257,10 +258,11 @@ class NNTransformerModule(torch.nn.Module):
def forward(self):
input = torch.rand(1, 16, 2)
tgt = torch.rand((1, 16, 2))
r = self.transformers[0](input, tgt)
r = self.transformers[1](input)
r = self.transformers[2](input, tgt)
return len(r)
return (
self.transformers[0](input, tgt),
self.transformers[1](input),
self.transformers[2](input, tgt),
)
class NNLinearModule(torch.nn.Module):
@ -277,10 +279,11 @@ class NNLinearModule(torch.nn.Module):
def forward(self):
input = torch.randn(32, 20)
r = self.linears[0](input)
r = self.linears[1](input)
r = self.linears[2](input, input)
return len(r)
return (
self.linears[0](input),
self.linears[1](input),
self.linears[2](input, input),
)
class NNDropoutModule(torch.nn.Module):
@ -291,7 +294,7 @@ class NNDropoutModule(torch.nn.Module):
a = torch.randn(8, 4)
b = torch.randn(8, 4, 4, 4)
c = torch.randn(8, 4, 4, 4, 4)
return len(
return (
F.dropout(a),
F.dropout2d(b),
F.dropout3d(c),
@ -309,7 +312,7 @@ class NNSparseModule(torch.nn.Module):
input2 = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
embedding_matrix = torch.rand(10, 3)
offsets = torch.tensor([0, 4])
return len(
return (
F.embedding(input, embedding_matrix),
F.embedding_bag(input2, embedding_matrix, offsets),
F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
@ -323,7 +326,7 @@ class NNDistanceModule(torch.nn.Module):
def forward(self):
a = torch.randn(8, 4)
b = torch.randn(8, 4)
return len(
return (
F.pairwise_distance(a, b),
F.cosine_similarity(a, b),
F.pdist(a),
@ -344,7 +347,7 @@ class NNLossFunctionModule(torch.nn.Module):
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
return len(
return (
F.binary_cross_entropy(torch.sigmoid(a), b),
F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
F.poisson_nll_loss(a, b),
@ -389,10 +392,8 @@ class NNVisionModule(torch.nn.Module):
def forward(self):
input = torch.randn(1, 3, 16, 16)
for i, module in enumerate(self.vision_modules):
r = module(self.input)
return len(
r,
return (
[module(self.input) for i, module in enumerate(self.vision_modules)],
self.linear_sample(torch.randn(4, 9, 9)),
self.trilinear_sample(torch.randn(1, 3, 4, 9, 9)),
F.grid_sample(input, torch.ones(1, 4, 4, 2)),
@ -405,7 +406,7 @@ class NNShuffleModule(torch.nn.Module):
self.shuffle = nn.ChannelShuffle(2)
def forward(self):
return len(self.shuffle(torch.randn(1, 4, 2, 2)),)
return (self.shuffle(torch.randn(1, 4, 2, 2)),)
class NNUtilsModule(torch.nn.Module):
@ -418,6 +419,6 @@ class NNUtilsModule(torch.nn.Module):
def forward(self):
input = torch.randn(2, 50)
return len(
return (
self.flatten(input),
)

View File

@ -23,7 +23,7 @@ class GeneralQuantModule(torch.nn.Module):
input1 = torch.randn(1, 16, 4)
input2 = torch.randn(1, 16, 4, 4)
input3 = torch.randn(1, 16, 4, 4, 4)
return len(
return (
self.func.add(a, b),
self.func.cat((a, a), 0),
self.func.mul(a, b),
@ -92,7 +92,7 @@ class DynamicQuantModule:
trans_input = torch.randn(1, 16, 2)
tgt = torch.rand(1, 16, 2)
return len((
return (
self.rnn(input, h),
self.rnncell(input[0], h[0]),
self.gru(input, h),
@ -105,7 +105,7 @@ class DynamicQuantModule:
self.linears[0](linear_input),
self.linears[1](linear_input),
self.linears[2](linear_input, linear_input),
))
)
class StaticQuantModule:

View File

@ -11,7 +11,7 @@ class SamplingOpsModule(torch.nn.Module):
a = torch.empty(3, 3).uniform_(0.0, 1.0)
size = (1, 4)
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
return len(
return (
# torch.seed(),
# torch.manual_seed(0),
torch.bernoulli(a),

View File

@ -15,7 +15,7 @@ class TensorOpsModule(torch.nn.Module):
c = torch.randn(4, dtype=torch.cfloat)
w = torch.rand(4, 4, 4, 4)
v = torch.rand(4, 4, 4, 4)
return len(
return (
# torch.is_tensor(a),
# torch.is_storage(a),
torch.is_complex(a),
@ -120,7 +120,7 @@ class TensorCreationOpsModule(torch.nn.Module):
0,
torch.quint8,
)
return len(
return (
torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
# torch.sparse_coo_tensor(i, v, [2, 3]), # not work for iOS
torch.as_tensor([1, 2, 3]),
@ -171,7 +171,7 @@ class TensorIndexingOpsModule(torch.nn.Module):
t = torch.tensor([[0, 0], [1, 0]])
mask = x.ge(0.5)
i = [0, 1]
return len(
return (
torch.cat((x, x, x), 0),
torch.concat((x, x, x), 0),
torch.conj(x),
@ -233,7 +233,7 @@ class TensorTypingOpsModule(torch.nn.Module):
def tensor_typing_ops(self):
x = torch.randn(1, 3, 4, 4)
return len(
return (
x.to(torch.float),
x.to(torch.double),
x.to(torch.cfloat),
@ -262,7 +262,7 @@ class TensorViewOpsModule(torch.nn.Module):
def tensor_view_ops(self):
x = torch.randn(4, 4, 1)
y = torch.randn(4, 4, 2)
return len(
return (
x[0, 2:],
x.detach(),
x.detach_(),

View File

@ -1,24 +0,0 @@
import torch
import torchvision
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
class MobileNetV2Module:
def __init__(self):
super(MobileNetV2Module, self).__init__()
def getModule(self):
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.zeros(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_module = optimize_for_mobile(traced_script_module)
augment_model_with_bundled_inputs(
optimized_module,
[
(example, ),
],
)
optimized_module(example)
return optimized_module