mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "add model test for Android"
This reverts commit 91ef3c82615d6ede05d5b86f1bd5571ea95e4ef1. Reverted https://github.com/pytorch/pytorch/pull/74793 on behalf of https://github.com/seemethere
This commit is contained in:
@ -235,23 +235,6 @@ To load torchscript model for mobile we need some special setup which is placed
|
||||
|
||||
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
|
||||
|
||||
|
||||
## Running Instrumentation Test
|
||||
Test using lite interpreter
|
||||
```
|
||||
./android/run_tests.sh
|
||||
```
|
||||
|
||||
Test using full JIT
|
||||
```
|
||||
BUILD_LITE_INTERPRETER=0 ./android/run_tests.sh
|
||||
```
|
||||
|
||||
Regenerate all test models
|
||||
```
|
||||
python test/mobile/model_test/gen_test_model.py android
|
||||
```
|
||||
|
||||
## PyTorch Android API Javadoc
|
||||
|
||||
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
|
||||
|
@ -1,12 +1,22 @@
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
OUTPUT_DIR = "src/androidTest/assets/"
|
||||
|
||||
class AndroidAPIModule(torch.jit.ScriptModule):
|
||||
def scriptAndSave(module, fileName):
|
||||
print('-' * 80)
|
||||
script_module = torch.jit.script(module)
|
||||
print(script_module.graph)
|
||||
outputFileName = OUTPUT_DIR + fileName
|
||||
# note that the lite interpreter model can also be used in full JIT
|
||||
script_module._save_for_lite_interpreter(outputFileName)
|
||||
print("Saved to " + outputFileName)
|
||||
print('=' * 80)
|
||||
|
||||
class Test(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(AndroidAPIModule, self).__init__()
|
||||
super(Test, self).__init__()
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
@ -66,9 +76,7 @@ class AndroidAPIModule(torch.jit.ScriptModule):
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def tupleIntSumReturnTuple(
|
||||
self, input: Tuple[int, int, int]
|
||||
) -> Tuple[Tuple[int, int, int], int]:
|
||||
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
@ -109,7 +117,7 @@ class AndroidAPIModule(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv2d(x, w)
|
||||
if toChannelsLast:
|
||||
if (toChannelsLast):
|
||||
r = r.contiguous(memory_format=torch.channels_last)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
@ -126,3 +134,5 @@ class AndroidAPIModule(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last_3d)
|
||||
|
||||
scriptAndSave(Test(), "test.pt")
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/test.pt
Normal file
BIN
android/pytorch_android/src/androidTest/assets/test.pt
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -9,577 +9,385 @@ import static org.junit.Assert.assertTrue;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[]{1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (boolean value : new boolean[]{false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (long value : new long[]{Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
double[] values =
|
||||
new double[]{
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput =
|
||||
new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[]{1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
for (int n : new int[]{0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[]{1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[]{1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[]{
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[]{
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput =
|
||||
new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyShape() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final long someNumber = 43;
|
||||
final IValue input = IValue.from(Tensor.fromBlob(new long[]{someNumber}, new long[]{}));
|
||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[]{}, value.shape());
|
||||
assertArrayEquals(new long[]{someNumber}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAliasWithOffset() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testAliasWithOffset");
|
||||
assertTrue(output.isTensorList());
|
||||
Tensor[] tensors = output.toTensorList();
|
||||
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
|
||||
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonContiguous() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testNonContiguous");
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[]{2}, value.shape());
|
||||
assertArrayEquals(new long[]{100, 300}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast() throws IOException {
|
||||
long[] inputShape = new long[]{1, 3, 2, 2};
|
||||
long[] data = new long[]{1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[]{1, 3, 2, 2},
|
||||
new long[]{1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
|
||||
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
|
||||
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast3d() throws IOException {
|
||||
long[] shape = new long[]{1, 2, 2, 2, 2};
|
||||
long[] dataNCHWD = new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
|
||||
long[] dataNHWDC = new long[]{1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||
|
||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||
|
||||
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
|
||||
final IValue outputNHWDC =
|
||||
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
|
||||
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv2d() throws IOException {
|
||||
long[] inputShape = new long[]{1, 3, 2, 2};
|
||||
long[] dataNCHW = new long[]{1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[]{1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
long[] weightShape = new long[]{3, 3, 1, 1};
|
||||
long[] dataWeightOIHW = new long[]{2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightOHWI = new long[]{2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCHW =
|
||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[]{1, 3, 2, 2},
|
||||
new long[]{2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
|
||||
|
||||
final IValue outputNHWC =
|
||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNHWC,
|
||||
MemoryFormat.CHANNELS_LAST,
|
||||
new long[]{1, 3, 2, 2},
|
||||
new long[]{2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetV2() throws IOException {
|
||||
try {
|
||||
final Module module = loadModel("mobilenet_v2.ptl");
|
||||
final IValue inputs = module.runMethod("get_all_bundled_inputs");
|
||||
assertTrue(inputs.isList());
|
||||
final IValue input = inputs.toList()[0];
|
||||
assertTrue(input.isTuple());
|
||||
module.forward(input.toTuple()[0]);
|
||||
assertTrue(true);
|
||||
} catch (Exception ex) {
|
||||
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPointwiseOps() throws IOException {
|
||||
runModel("pointwise_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionOps() throws IOException {
|
||||
runModel("reduction_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparisonOps() throws IOException {
|
||||
runModel("comparison_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOtherMathOps() throws IOException {
|
||||
runModel("other_math_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSpectralOps() throws IOException {
|
||||
runModel("spectral_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBlasLapackOps() throws IOException {
|
||||
runModel("blas_lapack_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplingOps() throws IOException {
|
||||
runModel("sampling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorOps() throws IOException {
|
||||
runModel("tensor_general_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorCreationOps() throws IOException {
|
||||
runModel("tensor_creation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorIndexingOps() throws IOException {
|
||||
runModel("tensor_indexing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorTypingOps() throws IOException {
|
||||
runModel("tensor_typing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorViewOps() throws IOException {
|
||||
runModel("tensor_view_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvolutionOps() throws IOException {
|
||||
runModel("convolution_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoolingOps() throws IOException {
|
||||
runModel("pooling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPaddingOps() throws IOException {
|
||||
runModel("padding_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActivationOps() throws IOException {
|
||||
runModel("activation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizationOps() throws IOException {
|
||||
runModel("normalization_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecurrentOps() throws IOException {
|
||||
runModel("recurrent_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransformerOps() throws IOException {
|
||||
runModel("transformer_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearOps() throws IOException {
|
||||
runModel("linear_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDropoutOps() throws IOException {
|
||||
runModel("dropout_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSparseOps() throws IOException {
|
||||
runModel("sparse_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistanceFunctionOps() throws IOException {
|
||||
runModel("distance_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLossFunctionOps() throws IOException {
|
||||
runModel("loss_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVisionFunctionOps() throws IOException {
|
||||
runModel("vision_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShuffleOps() throws IOException {
|
||||
runModel("shuffle_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNNUtilsOps() throws IOException {
|
||||
runModel("nn_utils_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantOps() throws IOException {
|
||||
runModel("general_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicQuantOps() throws IOException {
|
||||
runModel("dynamic_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticQuantOps() throws IOException {
|
||||
runModel("static_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedQuantOps() throws IOException {
|
||||
runModel("fused_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptBuiltinQuantOps() throws IOException {
|
||||
runModel("torchscript_builtin_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptCollectionQuantOps() throws IOException {
|
||||
runModel("torchscript_collection_ops");
|
||||
}
|
||||
|
||||
|
||||
static void assertIValueTensor(
|
||||
final IValue ivalue,
|
||||
final MemoryFormat memoryFormat,
|
||||
final long[] expectedShape,
|
||||
final long[] expectedData) {
|
||||
assertTrue(ivalue.isTensor());
|
||||
Tensor t = ivalue.toTensor();
|
||||
assertEquals(memoryFormat, t.memoryFormat());
|
||||
assertArrayEquals(expectedShape, t.shape());
|
||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
void runModel(final String name) throws IOException {
|
||||
final Module storage_module = loadModel(name + ".ptl");
|
||||
storage_module.forward();
|
||||
|
||||
// TODO enable this once the on-the-fly script is ready
|
||||
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
|
||||
// on_the_fly_module.forward();
|
||||
assertTrue(true);
|
||||
}
|
||||
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyShape() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final long someNumber = 43;
|
||||
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {}, value.shape());
|
||||
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAliasWithOffset() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testAliasWithOffset");
|
||||
assertTrue(output.isTensorList());
|
||||
Tensor[] tensors = output.toTensorList();
|
||||
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
|
||||
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonContiguous() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testNonContiguous");
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {2}, value.shape());
|
||||
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
|
||||
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
|
||||
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast3d() throws IOException {
|
||||
long[] shape = new long[] {1, 2, 2, 2, 2};
|
||||
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
|
||||
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||
|
||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||
|
||||
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
|
||||
final IValue outputNHWDC =
|
||||
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
|
||||
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv2d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
long[] weightShape = new long[] {3, 3, 1, 1};
|
||||
long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCHW =
|
||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
|
||||
|
||||
final IValue outputNHWC =
|
||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNHWC,
|
||||
MemoryFormat.CHANNELS_LAST,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
|
||||
}
|
||||
|
||||
static void assertIValueTensor(
|
||||
final IValue ivalue,
|
||||
final MemoryFormat memoryFormat,
|
||||
final long[] expectedShape,
|
||||
final long[] expectedData) {
|
||||
assertTrue(ivalue.isTensor());
|
||||
Tensor t = ivalue.toTensor();
|
||||
assertEquals(memoryFormat, t.memoryFormat());
|
||||
assertArrayEquals(expectedShape, t.shape());
|
||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ class TSBuiltinOpsModule(torch.nn.Module):
|
||||
l = ["1", "2", "test"]
|
||||
d = {"key": 1}
|
||||
d2 = {0: 100}
|
||||
return len(
|
||||
return (
|
||||
# type
|
||||
bool(x),
|
||||
bool(x.item()),
|
||||
@ -111,7 +111,7 @@ class TSCollectionOpsModule(torch.nn.Module):
|
||||
d2.clear()
|
||||
d2[0] = 100
|
||||
|
||||
return len(
|
||||
return (
|
||||
s[torch.tensor(1)],
|
||||
d["key"],
|
||||
d2[0],
|
||||
|
@ -1,8 +1,3 @@
|
||||
_coverage: 91.09
|
||||
_covered_ops: 358
|
||||
_generated_ops: 705
|
||||
_production_ops: 393
|
||||
_uncovered_ops: 35
|
||||
all_generated_ops:
|
||||
- aten::Bool.Tensor
|
||||
- aten::Bool.int
|
||||
@ -647,8 +642,6 @@ all_generated_ops:
|
||||
- aten::zeros
|
||||
- aten::zeros.out
|
||||
- aten::zeros_like
|
||||
- prepacked::conv2d_clamp_run
|
||||
- prepacked::linear_clamp_run
|
||||
- prim::RaiseException
|
||||
- prim::TupleIndex
|
||||
- prim::TupleUnpack
|
||||
@ -1030,8 +1023,6 @@ covered_ops:
|
||||
- aten::zeros
|
||||
- aten::zeros.out
|
||||
- aten::zeros_like
|
||||
- prepacked::conv2d_clamp_run
|
||||
- prepacked::linear_clamp_run
|
||||
- prim::RaiseException
|
||||
- prim::TupleIndex
|
||||
- prim::TupleUnpack
|
||||
@ -1095,8 +1086,10 @@ uncovered_ops:
|
||||
- aten::true_divide.Tensor
|
||||
- aten::upsample_nearest2d
|
||||
- prepacked::conv2d_clamp_prepack
|
||||
- prepacked::conv2d_clamp_run
|
||||
- prepacked::conv2d_transpose_clamp_prepack
|
||||
- prepacked::conv2d_transpose_clamp_run
|
||||
- prepacked::linear_clamp_run
|
||||
- prim::ModuleContainerIndex.list
|
||||
- prim::NumToTensor.Scalar
|
||||
- prim::Print
|
||||
|
@ -1,9 +1,5 @@
|
||||
import io
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from android_api_module import AndroidAPIModule
|
||||
from builtin_ops import (
|
||||
TSBuiltinOpsModule,
|
||||
TSCollectionOpsModule,
|
||||
@ -47,188 +43,92 @@ from tensor_ops import (
|
||||
TensorTypingOpsModule,
|
||||
TensorViewOpsModule,
|
||||
)
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torchvision_models import MobileNetV2Module
|
||||
|
||||
test_path_ios = "ios/TestApp/models/"
|
||||
test_path_android = "android/pytorch_android/src/androidTest/assets/"
|
||||
|
||||
output_path = "ios/TestApp/models/"
|
||||
production_ops_path = "test/mobile/model_test/model_ops.yaml"
|
||||
coverage_out_path = "test/mobile/model_test/coverage.yaml"
|
||||
|
||||
all_modules = {
|
||||
# math ops
|
||||
"pointwise_ops": PointwiseOpsModule(),
|
||||
"reduction_ops": ReductionOpsModule(),
|
||||
"comparison_ops": ComparisonOpsModule(),
|
||||
"spectral_ops": SpectralOpsModule(),
|
||||
"other_math_ops": OtherMathOpsModule(),
|
||||
"blas_lapack_ops": BlasLapackOpsModule(),
|
||||
# sampling
|
||||
"sampling_ops": SamplingOpsModule(),
|
||||
# tensor ops
|
||||
"tensor_general_ops": TensorOpsModule(),
|
||||
"tensor_creation_ops": TensorCreationOpsModule(),
|
||||
"tensor_indexing_ops": TensorIndexingOpsModule(),
|
||||
"tensor_typing_ops": TensorTypingOpsModule(),
|
||||
"tensor_view_ops": TensorViewOpsModule(),
|
||||
# nn ops
|
||||
"convolution_ops": NNConvolutionModule(),
|
||||
"pooling_ops": NNPoolingModule(),
|
||||
"padding_ops": NNPaddingModule(),
|
||||
"activation_ops": NNActivationModule(),
|
||||
"normalization_ops": NNNormalizationModule(),
|
||||
"recurrent_ops": NNRecurrentModule(),
|
||||
"transformer_ops": NNTransformerModule(),
|
||||
"linear_ops": NNLinearModule(),
|
||||
"dropout_ops": NNDropoutModule(),
|
||||
"sparse_ops": NNSparseModule(),
|
||||
"distance_function_ops": NNDistanceModule(),
|
||||
"loss_function_ops": NNLossFunctionModule(),
|
||||
"vision_function_ops": NNVisionModule(),
|
||||
"shuffle_ops": NNShuffleModule(),
|
||||
"nn_utils_ops": NNUtilsModule(),
|
||||
# quantization ops
|
||||
"general_quant_ops": GeneralQuantModule(),
|
||||
"dynamic_quant_ops": DynamicQuantModule(),
|
||||
"static_quant_ops": StaticQuantModule(),
|
||||
"fused_quant_ops": FusedQuantModule(),
|
||||
# TorchScript buildin ops
|
||||
"torchscript_builtin_ops": TSBuiltinOpsModule(),
|
||||
"torchscript_collection_ops": TSCollectionOpsModule(),
|
||||
# vision
|
||||
"mobilenet_v2": MobileNetV2Module(),
|
||||
# android api module
|
||||
"android_api_module": AndroidAPIModule(),
|
||||
}
|
||||
|
||||
models_need_trace = [
|
||||
"static_quant_ops",
|
||||
def scriptAndSave(module, name):
|
||||
module = torch.jit.script(module)
|
||||
return save(module, name)
|
||||
|
||||
|
||||
def traceAndSave(module, name):
|
||||
module = torch.jit.trace(module, [])
|
||||
return save(module, name)
|
||||
|
||||
|
||||
def save(module, name):
|
||||
module._save_for_lite_interpreter(output_path + name + ".ptl")
|
||||
print("model saved to " + output_path + name + ".ptl")
|
||||
ops = torch.jit.export_opnames(module)
|
||||
print(ops)
|
||||
module()
|
||||
return ops
|
||||
|
||||
|
||||
ops = [
|
||||
# math ops
|
||||
scriptAndSave(PointwiseOpsModule(), "pointwise_ops"),
|
||||
scriptAndSave(ReductionOpsModule(), "reduction_ops"),
|
||||
scriptAndSave(ComparisonOpsModule(), "comparison_ops"),
|
||||
scriptAndSave(OtherMathOpsModule(), "other_math_ops"),
|
||||
scriptAndSave(SpectralOpsModule(), "spectral_ops"),
|
||||
scriptAndSave(BlasLapackOpsModule(), "blas_lapack_ops"),
|
||||
# sampling
|
||||
scriptAndSave(SamplingOpsModule(), "sampling_ops"),
|
||||
# tensor ops
|
||||
scriptAndSave(TensorOpsModule(), "tensor_general_ops"),
|
||||
scriptAndSave(TensorCreationOpsModule(), "tensor_creation_ops"),
|
||||
scriptAndSave(TensorIndexingOpsModule(), "tensor_indexing_ops"),
|
||||
scriptAndSave(TensorTypingOpsModule(), "tensor_typing_ops"),
|
||||
scriptAndSave(TensorViewOpsModule(), "tensor_view_ops"),
|
||||
# nn ops
|
||||
scriptAndSave(NNConvolutionModule(), "convolution_ops"),
|
||||
scriptAndSave(NNPoolingModule(), "pooling_ops"),
|
||||
scriptAndSave(NNPaddingModule(), "padding_ops"),
|
||||
scriptAndSave(NNActivationModule(), "activation_ops"),
|
||||
scriptAndSave(NNNormalizationModule(), "normalization_ops"),
|
||||
scriptAndSave(NNRecurrentModule(), "recurrent_ops"),
|
||||
scriptAndSave(NNTransformerModule(), "transformer_ops"),
|
||||
scriptAndSave(NNLinearModule(), "linear_ops"),
|
||||
scriptAndSave(NNDropoutModule(), "dropout_ops"),
|
||||
scriptAndSave(NNSparseModule(), "sparse_ops"),
|
||||
scriptAndSave(NNDistanceModule(), "distance_function_ops"),
|
||||
scriptAndSave(NNLossFunctionModule(), "loss_function_ops"),
|
||||
scriptAndSave(NNVisionModule(), "vision_function_ops"),
|
||||
scriptAndSave(NNShuffleModule(), "shuffle_ops"),
|
||||
scriptAndSave(NNUtilsModule(), "nn_utils_ops"),
|
||||
# quantization ops
|
||||
scriptAndSave(GeneralQuantModule(), "general_quant_ops"),
|
||||
scriptAndSave(DynamicQuantModule().getModule(), "dynamic_quant_ops"),
|
||||
traceAndSave(StaticQuantModule().getModule(), "static_quant_ops"),
|
||||
scriptAndSave(FusedQuantModule().getModule(), "fused_quant_ops"),
|
||||
# TorchScript buildin ops
|
||||
scriptAndSave(TSBuiltinOpsModule(), "torchscript_builtin_ops"),
|
||||
scriptAndSave(TSCollectionOpsModule(), "torchscript_collection_ops"),
|
||||
]
|
||||
|
||||
|
||||
def calcOpsCoverage(ops):
|
||||
with open(production_ops_path) as input_yaml_file:
|
||||
production_ops = yaml.safe_load(input_yaml_file)
|
||||
with open(production_ops_path) as input_yaml_file:
|
||||
production_ops = yaml.safe_load(input_yaml_file)
|
||||
|
||||
production_ops = set(production_ops["root_operators"])
|
||||
all_generated_ops = set(ops)
|
||||
covered_ops = production_ops.intersection(all_generated_ops)
|
||||
uncovered_ops = production_ops - covered_ops
|
||||
coverage = 100 * len(covered_ops) / len(production_ops)
|
||||
print(
|
||||
f"\nGenerated {len(all_generated_ops)} ops and covered "
|
||||
+ f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
|
||||
production_ops = set(production_ops["root_operators"])
|
||||
all_generated_ops = set().union(*ops)
|
||||
covered_ops = production_ops.intersection(all_generated_ops)
|
||||
uncovered_ops = production_ops - covered_ops
|
||||
coverage = 100 * len(covered_ops) / len(production_ops)
|
||||
print(
|
||||
f"\nGenerated {len(all_generated_ops)} ops and covered "
|
||||
+ f"{len(covered_ops)}/{len(production_ops)} ({round(coverage, 2)}%) production ops. \n"
|
||||
)
|
||||
with open(coverage_out_path, "w") as f:
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"uncovered_ops": sorted(list(uncovered_ops)),
|
||||
"covered_ops": sorted(list(covered_ops)),
|
||||
"all_generated_ops": sorted(list(all_generated_ops)),
|
||||
},
|
||||
f,
|
||||
)
|
||||
with open(coverage_out_path, "w") as f:
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"_covered_ops": len(covered_ops),
|
||||
"_production_ops": len(production_ops),
|
||||
"_generated_ops": len(all_generated_ops),
|
||||
"_uncovered_ops": len(uncovered_ops),
|
||||
"_coverage": round(coverage, 2),
|
||||
"uncovered_ops": sorted(list(uncovered_ops)),
|
||||
"covered_ops": sorted(list(covered_ops)),
|
||||
"all_generated_ops": sorted(list(all_generated_ops)),
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
|
||||
def getModuleFromName(model_name):
|
||||
if model_name not in all_modules:
|
||||
print("Cannot find test model for " + model_name)
|
||||
return None, []
|
||||
|
||||
module = all_modules[model_name]
|
||||
if not isinstance(module, torch.nn.Module):
|
||||
module = module.getModule()
|
||||
|
||||
has_bundled_inputs = False # module.find_method("get_all_bundled_inputs")
|
||||
|
||||
if model_name in models_need_trace:
|
||||
module = torch.jit.trace(module, [])
|
||||
else:
|
||||
module = torch.jit.script(module)
|
||||
|
||||
ops = torch.jit.export_opnames(module)
|
||||
print(ops)
|
||||
|
||||
# try to run the model
|
||||
runModule(module)
|
||||
|
||||
return module, ops
|
||||
|
||||
|
||||
def runModule(module):
|
||||
buffer = io.BytesIO(module._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
lite_module = _load_for_lite_interpreter(buffer)
|
||||
if lite_module.find_method("get_all_bundled_inputs"):
|
||||
# run with the first bundled input
|
||||
input = lite_module.run_method("get_all_bundled_inputs")[0]
|
||||
lite_module.forward(*input)
|
||||
else:
|
||||
# assuming model has no input
|
||||
lite_module()
|
||||
|
||||
|
||||
# generate all models in the given folder.
|
||||
# If it's "on the fly" mode, add "_temp" suffix to the model file.
|
||||
def generateAllModels(folder, on_the_fly=False):
|
||||
all_ops = []
|
||||
for name in all_modules:
|
||||
module, ops = getModuleFromName(name)
|
||||
all_ops = all_ops + ops
|
||||
path = folder + name + ("_temp.ptl" if on_the_fly else ".ptl")
|
||||
module._save_for_lite_interpreter(path)
|
||||
print("model saved to " + path)
|
||||
calcOpsCoverage(all_ops)
|
||||
|
||||
|
||||
# generate/update a given model for storage
|
||||
def generateModel(name):
|
||||
module, ops = getModuleFromName(name)
|
||||
if module is None:
|
||||
return
|
||||
path_ios = test_path_ios + name + ".ptl"
|
||||
path_android = test_path_android + name + ".ptl"
|
||||
module._save_for_lite_interpreter(path_ios)
|
||||
module._save_for_lite_interpreter(path_android)
|
||||
print("model saved to " + path_ios + " and " + path_android)
|
||||
|
||||
|
||||
def main(argv):
|
||||
if argv is None or len(argv) != 1:
|
||||
print(
|
||||
"""
|
||||
This script generate models for mobile test. For each model we have a "storage" version
|
||||
and an "on-the-fly" version. The "on-the-fly" version will be generated during test,and
|
||||
should not be committed to the repo.
|
||||
The "storage" version is for back compatibility # test (a model generated today should
|
||||
run on master branch in the next 6 months). We can use this script to update a model that
|
||||
is no longer supported.
|
||||
- use 'python gen_test_model.py android-test' to generate on-the-fly models for android
|
||||
- use 'python gen_test_model.py ios-test' to generate on-the-fly models for ios
|
||||
- use 'python gen_test_model.py android' to generate checked-in models for android
|
||||
- use 'python gen_test_model.py ios' to generate on-the-fly models for ios
|
||||
- use 'python gen_test_model.py <model_name_no_suffix>' to update the given storage model
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
if argv[0] == "android":
|
||||
generateAllModels(test_path_android, on_the_fly=False)
|
||||
elif argv[0] == "ios":
|
||||
generateAllModels(test_path_ios, on_the_fly=False)
|
||||
elif argv[0] == "android-test":
|
||||
generateAllModels(test_path_android, on_the_fly=True)
|
||||
elif argv[0] == "ios-test":
|
||||
generateAllModels(test_path_ios, on_the_fly=True)
|
||||
else:
|
||||
generateModel(argv[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
||||
|
@ -22,7 +22,7 @@ class PointwiseOpsModule(torch.nn.Module):
|
||||
f = torch.zeros(3)
|
||||
g = torch.tensor([-1, 0, 1])
|
||||
w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
|
||||
return len(
|
||||
return (
|
||||
torch.abs(torch.tensor([-1, -2, 3])),
|
||||
torch.absolute(torch.tensor([-1, -2, 3])),
|
||||
torch.acos(a),
|
||||
@ -222,7 +222,7 @@ class ReductionOpsModule(torch.nn.Module):
|
||||
a = torch.randn(4)
|
||||
b = torch.randn(4)
|
||||
c = torch.tensor(0.5)
|
||||
return len(
|
||||
return (
|
||||
torch.argmax(a),
|
||||
torch.argmin(a),
|
||||
torch.amax(a),
|
||||
@ -271,7 +271,7 @@ class ComparisonOpsModule(torch.nn.Module):
|
||||
def forward(self):
|
||||
a = torch.tensor(0)
|
||||
b = torch.tensor(1)
|
||||
return len(
|
||||
return (
|
||||
torch.allclose(a, b),
|
||||
torch.argsort(a),
|
||||
torch.eq(a, b),
|
||||
@ -327,7 +327,7 @@ class OtherMathOpsModule(torch.nn.Module):
|
||||
f = torch.randn(4, 4, 4)
|
||||
size = [0, 1]
|
||||
dims = [0, 1]
|
||||
return len(
|
||||
return (
|
||||
torch.atleast_1d(a),
|
||||
torch.atleast_2d(a),
|
||||
torch.atleast_3d(a),
|
||||
@ -380,7 +380,7 @@ class OtherMathOpsModule(torch.nn.Module):
|
||||
torch.triu_indices(3, 3),
|
||||
torch.vander(a),
|
||||
torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
|
||||
torch.view_as_complex(torch.randn(4, 2)).real,
|
||||
torch.view_as_complex(torch.randn(4, 2)),
|
||||
torch.resolve_conj(a),
|
||||
torch.resolve_neg(a),
|
||||
)
|
||||
@ -396,7 +396,7 @@ class SpectralOpsModule(torch.nn.Module):
|
||||
def spectral_ops(self):
|
||||
a = torch.randn(10)
|
||||
b = torch.randn(10, 8, 4, 2)
|
||||
return len(
|
||||
return (
|
||||
torch.stft(a, 8),
|
||||
torch.stft(a, torch.tensor(8)),
|
||||
torch.istft(b, 8),
|
||||
@ -420,7 +420,7 @@ class BlasLapackOpsModule(torch.nn.Module):
|
||||
a = torch.randn(10, 3, 4)
|
||||
b = torch.randn(10, 4, 3)
|
||||
v = torch.randn(3)
|
||||
return len(
|
||||
return (
|
||||
torch.addbmm(m, a, b),
|
||||
torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
|
||||
torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
|
||||
|
@ -31,11 +31,11 @@ class NNConvolutionModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self):
|
||||
return len((
|
||||
return (
|
||||
[module(self.input1d) for i, module in enumerate(self.module1d)],
|
||||
[module(self.input2d) for i, module in enumerate(self.module2d)],
|
||||
[module(self.input3d) for i, module in enumerate(self.module3d)],
|
||||
))
|
||||
)
|
||||
|
||||
|
||||
class NNPoolingModule(torch.nn.Module):
|
||||
@ -77,11 +77,11 @@ class NNPoolingModule(torch.nn.Module):
|
||||
# TODO max_unpool
|
||||
|
||||
def forward(self):
|
||||
return len((
|
||||
return (
|
||||
[module(self.input1d) for i, module in enumerate(self.module1d)],
|
||||
[module(self.input2d) for i, module in enumerate(self.module2d)],
|
||||
[module(self.input3d) for i, module in enumerate(self.module3d)],
|
||||
))
|
||||
)
|
||||
|
||||
|
||||
class NNPaddingModule(torch.nn.Module):
|
||||
@ -116,11 +116,11 @@ class NNPaddingModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self):
|
||||
return len((
|
||||
return (
|
||||
[module(self.input1d) for i, module in enumerate(self.module1d)],
|
||||
[module(self.input2d) for i, module in enumerate(self.module2d)],
|
||||
[module(self.input3d) for i, module in enumerate(self.module3d)],
|
||||
))
|
||||
)
|
||||
|
||||
|
||||
class NNNormalizationModule(torch.nn.Module):
|
||||
@ -155,11 +155,11 @@ class NNNormalizationModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self):
|
||||
return len((
|
||||
return (
|
||||
[module(self.input1d) for i, module in enumerate(self.module1d)],
|
||||
[module(self.input2d) for i, module in enumerate(self.module2d)],
|
||||
[module(self.input3d) for i, module in enumerate(self.module3d)],
|
||||
))
|
||||
)
|
||||
|
||||
|
||||
class NNActivationModule(torch.nn.Module):
|
||||
@ -202,9 +202,9 @@ class NNActivationModule(torch.nn.Module):
|
||||
|
||||
def forward(self):
|
||||
input = torch.randn(2, 3, 4)
|
||||
return len((
|
||||
[module(input) for i, module in enumerate(self.activations)],
|
||||
))
|
||||
for i, module in enumerate(self.activations):
|
||||
x = module(input)
|
||||
return x
|
||||
|
||||
|
||||
class NNRecurrentModule(torch.nn.Module):
|
||||
@ -228,13 +228,14 @@ class NNRecurrentModule(torch.nn.Module):
|
||||
input = torch.randn(5, 3, 4)
|
||||
h = torch.randn(2, 3, 8)
|
||||
c = torch.randn(2, 3, 8)
|
||||
r = self.rnn[0](input, h)
|
||||
r = self.rnn[1](input[0], h[0])
|
||||
r = self.gru[0](input, h)
|
||||
r = self.gru[1](input[0], h[0])
|
||||
r = self.lstm[0](input, (h, c))
|
||||
r = self.lstm[1](input[0], (h[0], c[0]))
|
||||
return len(r)
|
||||
return (
|
||||
self.rnn[0](input, h),
|
||||
self.rnn[1](input[0], h[0]),
|
||||
self.gru[0](input, h),
|
||||
self.gru[1](input[0], h[0]),
|
||||
self.lstm[0](input, (h, c)),
|
||||
self.lstm[1](input[0], (h[0], c[0])),
|
||||
)
|
||||
|
||||
|
||||
class NNTransformerModule(torch.nn.Module):
|
||||
@ -257,10 +258,11 @@ class NNTransformerModule(torch.nn.Module):
|
||||
def forward(self):
|
||||
input = torch.rand(1, 16, 2)
|
||||
tgt = torch.rand((1, 16, 2))
|
||||
r = self.transformers[0](input, tgt)
|
||||
r = self.transformers[1](input)
|
||||
r = self.transformers[2](input, tgt)
|
||||
return len(r)
|
||||
return (
|
||||
self.transformers[0](input, tgt),
|
||||
self.transformers[1](input),
|
||||
self.transformers[2](input, tgt),
|
||||
)
|
||||
|
||||
|
||||
class NNLinearModule(torch.nn.Module):
|
||||
@ -277,10 +279,11 @@ class NNLinearModule(torch.nn.Module):
|
||||
|
||||
def forward(self):
|
||||
input = torch.randn(32, 20)
|
||||
r = self.linears[0](input)
|
||||
r = self.linears[1](input)
|
||||
r = self.linears[2](input, input)
|
||||
return len(r)
|
||||
return (
|
||||
self.linears[0](input),
|
||||
self.linears[1](input),
|
||||
self.linears[2](input, input),
|
||||
)
|
||||
|
||||
|
||||
class NNDropoutModule(torch.nn.Module):
|
||||
@ -291,7 +294,7 @@ class NNDropoutModule(torch.nn.Module):
|
||||
a = torch.randn(8, 4)
|
||||
b = torch.randn(8, 4, 4, 4)
|
||||
c = torch.randn(8, 4, 4, 4, 4)
|
||||
return len(
|
||||
return (
|
||||
F.dropout(a),
|
||||
F.dropout2d(b),
|
||||
F.dropout3d(c),
|
||||
@ -309,7 +312,7 @@ class NNSparseModule(torch.nn.Module):
|
||||
input2 = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
|
||||
embedding_matrix = torch.rand(10, 3)
|
||||
offsets = torch.tensor([0, 4])
|
||||
return len(
|
||||
return (
|
||||
F.embedding(input, embedding_matrix),
|
||||
F.embedding_bag(input2, embedding_matrix, offsets),
|
||||
F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
|
||||
@ -323,7 +326,7 @@ class NNDistanceModule(torch.nn.Module):
|
||||
def forward(self):
|
||||
a = torch.randn(8, 4)
|
||||
b = torch.randn(8, 4)
|
||||
return len(
|
||||
return (
|
||||
F.pairwise_distance(a, b),
|
||||
F.cosine_similarity(a, b),
|
||||
F.pdist(a),
|
||||
@ -344,7 +347,7 @@ class NNLossFunctionModule(torch.nn.Module):
|
||||
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
|
||||
input_lengths = torch.full((16,), 50, dtype=torch.long)
|
||||
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
|
||||
return len(
|
||||
return (
|
||||
F.binary_cross_entropy(torch.sigmoid(a), b),
|
||||
F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
|
||||
F.poisson_nll_loss(a, b),
|
||||
@ -389,10 +392,8 @@ class NNVisionModule(torch.nn.Module):
|
||||
|
||||
def forward(self):
|
||||
input = torch.randn(1, 3, 16, 16)
|
||||
for i, module in enumerate(self.vision_modules):
|
||||
r = module(self.input)
|
||||
return len(
|
||||
r,
|
||||
return (
|
||||
[module(self.input) for i, module in enumerate(self.vision_modules)],
|
||||
self.linear_sample(torch.randn(4, 9, 9)),
|
||||
self.trilinear_sample(torch.randn(1, 3, 4, 9, 9)),
|
||||
F.grid_sample(input, torch.ones(1, 4, 4, 2)),
|
||||
@ -405,7 +406,7 @@ class NNShuffleModule(torch.nn.Module):
|
||||
self.shuffle = nn.ChannelShuffle(2)
|
||||
|
||||
def forward(self):
|
||||
return len(self.shuffle(torch.randn(1, 4, 2, 2)),)
|
||||
return (self.shuffle(torch.randn(1, 4, 2, 2)),)
|
||||
|
||||
|
||||
class NNUtilsModule(torch.nn.Module):
|
||||
@ -418,6 +419,6 @@ class NNUtilsModule(torch.nn.Module):
|
||||
|
||||
def forward(self):
|
||||
input = torch.randn(2, 50)
|
||||
return len(
|
||||
return (
|
||||
self.flatten(input),
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ class GeneralQuantModule(torch.nn.Module):
|
||||
input1 = torch.randn(1, 16, 4)
|
||||
input2 = torch.randn(1, 16, 4, 4)
|
||||
input3 = torch.randn(1, 16, 4, 4, 4)
|
||||
return len(
|
||||
return (
|
||||
self.func.add(a, b),
|
||||
self.func.cat((a, a), 0),
|
||||
self.func.mul(a, b),
|
||||
@ -92,7 +92,7 @@ class DynamicQuantModule:
|
||||
trans_input = torch.randn(1, 16, 2)
|
||||
tgt = torch.rand(1, 16, 2)
|
||||
|
||||
return len((
|
||||
return (
|
||||
self.rnn(input, h),
|
||||
self.rnncell(input[0], h[0]),
|
||||
self.gru(input, h),
|
||||
@ -105,7 +105,7 @@ class DynamicQuantModule:
|
||||
self.linears[0](linear_input),
|
||||
self.linears[1](linear_input),
|
||||
self.linears[2](linear_input, linear_input),
|
||||
))
|
||||
)
|
||||
|
||||
|
||||
class StaticQuantModule:
|
||||
|
@ -11,7 +11,7 @@ class SamplingOpsModule(torch.nn.Module):
|
||||
a = torch.empty(3, 3).uniform_(0.0, 1.0)
|
||||
size = (1, 4)
|
||||
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
|
||||
return len(
|
||||
return (
|
||||
# torch.seed(),
|
||||
# torch.manual_seed(0),
|
||||
torch.bernoulli(a),
|
||||
|
@ -15,7 +15,7 @@ class TensorOpsModule(torch.nn.Module):
|
||||
c = torch.randn(4, dtype=torch.cfloat)
|
||||
w = torch.rand(4, 4, 4, 4)
|
||||
v = torch.rand(4, 4, 4, 4)
|
||||
return len(
|
||||
return (
|
||||
# torch.is_tensor(a),
|
||||
# torch.is_storage(a),
|
||||
torch.is_complex(a),
|
||||
@ -120,7 +120,7 @@ class TensorCreationOpsModule(torch.nn.Module):
|
||||
0,
|
||||
torch.quint8,
|
||||
)
|
||||
return len(
|
||||
return (
|
||||
torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
|
||||
# torch.sparse_coo_tensor(i, v, [2, 3]), # not work for iOS
|
||||
torch.as_tensor([1, 2, 3]),
|
||||
@ -171,7 +171,7 @@ class TensorIndexingOpsModule(torch.nn.Module):
|
||||
t = torch.tensor([[0, 0], [1, 0]])
|
||||
mask = x.ge(0.5)
|
||||
i = [0, 1]
|
||||
return len(
|
||||
return (
|
||||
torch.cat((x, x, x), 0),
|
||||
torch.concat((x, x, x), 0),
|
||||
torch.conj(x),
|
||||
@ -233,7 +233,7 @@ class TensorTypingOpsModule(torch.nn.Module):
|
||||
|
||||
def tensor_typing_ops(self):
|
||||
x = torch.randn(1, 3, 4, 4)
|
||||
return len(
|
||||
return (
|
||||
x.to(torch.float),
|
||||
x.to(torch.double),
|
||||
x.to(torch.cfloat),
|
||||
@ -262,7 +262,7 @@ class TensorViewOpsModule(torch.nn.Module):
|
||||
def tensor_view_ops(self):
|
||||
x = torch.randn(4, 4, 1)
|
||||
y = torch.randn(4, 4, 2)
|
||||
return len(
|
||||
return (
|
||||
x[0, 2:],
|
||||
x.detach(),
|
||||
x.detach_(),
|
||||
|
@ -1,24 +0,0 @@
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
||||
|
||||
class MobileNetV2Module:
|
||||
def __init__(self):
|
||||
super(MobileNetV2Module, self).__init__()
|
||||
|
||||
def getModule(self):
|
||||
model = torchvision.models.mobilenet_v2(pretrained=True)
|
||||
model.eval()
|
||||
example = torch.zeros(1, 3, 224, 224)
|
||||
traced_script_module = torch.jit.trace(model, example)
|
||||
optimized_module = optimize_for_mobile(traced_script_module)
|
||||
augment_model_with_bundled_inputs(
|
||||
optimized_module,
|
||||
[
|
||||
(example, ),
|
||||
],
|
||||
)
|
||||
optimized_module(example)
|
||||
return optimized_module
|
Reference in New Issue
Block a user