mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add models test for android and iOS
See https://github.com/pytorch/pytorch/pull/74793 for more details Redo everything in one pr. This diff added test model generation script and updated android and ios tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74947 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
f82b2d4a82
commit
28a4b4759a
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/dropout_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/dropout_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/linear_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/linear_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/padding_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/padding_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/pooling_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/pooling_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/reduction_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/reduction_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/sampling_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/sampling_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/sparse_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/sparse_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/spectral_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/spectral_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -9,10 +9,11 @@ import static org.junit.Assert.assertTrue;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
|
||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
@ -377,6 +378,186 @@ public abstract class PytorchTestBase {
|
||||
new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetV2() throws IOException {
|
||||
try {
|
||||
final Module module = loadModel("mobilenet_v2.ptl");
|
||||
final IValue inputs = module.runMethod("get_all_bundled_inputs");
|
||||
assertTrue(inputs.isList());
|
||||
final IValue input = inputs.toList()[0];
|
||||
assertTrue(input.isTuple());
|
||||
module.forward(input.toTuple()[0]);
|
||||
assertTrue(true);
|
||||
} catch (Exception ex) {
|
||||
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPointwiseOps() throws IOException {
|
||||
runModel("pointwise_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionOps() throws IOException {
|
||||
runModel("reduction_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparisonOps() throws IOException {
|
||||
runModel("comparison_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOtherMathOps() throws IOException {
|
||||
runModel("other_math_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSpectralOps() throws IOException {
|
||||
runModel("spectral_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBlasLapackOps() throws IOException {
|
||||
runModel("blas_lapack_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplingOps() throws IOException {
|
||||
runModel("sampling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorOps() throws IOException {
|
||||
runModel("tensor_general_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorCreationOps() throws IOException {
|
||||
runModel("tensor_creation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorIndexingOps() throws IOException {
|
||||
runModel("tensor_indexing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorTypingOps() throws IOException {
|
||||
runModel("tensor_typing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorViewOps() throws IOException {
|
||||
runModel("tensor_view_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvolutionOps() throws IOException {
|
||||
runModel("convolution_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoolingOps() throws IOException {
|
||||
runModel("pooling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPaddingOps() throws IOException {
|
||||
runModel("padding_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActivationOps() throws IOException {
|
||||
runModel("activation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizationOps() throws IOException {
|
||||
runModel("normalization_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecurrentOps() throws IOException {
|
||||
runModel("recurrent_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransformerOps() throws IOException {
|
||||
runModel("transformer_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearOps() throws IOException {
|
||||
runModel("linear_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDropoutOps() throws IOException {
|
||||
runModel("dropout_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSparseOps() throws IOException {
|
||||
runModel("sparse_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistanceFunctionOps() throws IOException {
|
||||
runModel("distance_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLossFunctionOps() throws IOException {
|
||||
runModel("loss_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVisionFunctionOps() throws IOException {
|
||||
runModel("vision_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShuffleOps() throws IOException {
|
||||
runModel("shuffle_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNNUtilsOps() throws IOException {
|
||||
runModel("nn_utils_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantOps() throws IOException {
|
||||
runModel("general_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicQuantOps() throws IOException {
|
||||
runModel("dynamic_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticQuantOps() throws IOException {
|
||||
runModel("static_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedQuantOps() throws IOException {
|
||||
runModel("fused_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptBuiltinQuantOps() throws IOException {
|
||||
runModel("torchscript_builtin_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptCollectionQuantOps() throws IOException {
|
||||
runModel("torchscript_collection_ops");
|
||||
}
|
||||
|
||||
static void assertIValueTensor(
|
||||
final IValue ivalue,
|
||||
final MemoryFormat memoryFormat,
|
||||
@ -389,5 +570,15 @@ public abstract class PytorchTestBase {
|
||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
void runModel(final String name) throws IOException {
|
||||
final Module storage_module = loadModel(name + ".ptl");
|
||||
storage_module.forward();
|
||||
|
||||
// TODO enable this once the on-the-fly script is ready
|
||||
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
|
||||
// on_the_fly_module.forward();
|
||||
assertTrue(true);
|
||||
}
|
||||
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
||||
|
Reference in New Issue
Block a user