Add a workflow to release Android binaries (#110976)

This adds 2 jobs to build PyTorch Android with and without lite interpreter:

* Keep the list of currently supported ABI armeabi-v7a, arm64-v8a, x86, x86_64
* Pass all the test on emulator
* Run an the test app on emulator and my Android phone `arm64-v8a` without any issue
![Screenshot_20231010-114453](https://github.com/pytorch/pytorch/assets/475357/57e12188-1675-44d2-a259-9f9577578590)
* Run on AWS https://us-west-2.console.aws.amazon.com/devicefarm/home#/mobile/projects/b531574a-fb82-40ae-b687-8f0b81341ae0/runs/5fce6818-628a-4099-9aab-23e91a212076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110976
Approved by: https://github.com/atalman
This commit is contained in:
Huy Do
2023-10-11 00:19:33 +00:00
committed by PyTorch MergeBot
parent 5aa96fd336
commit 2edc75a669
16 changed files with 283 additions and 55 deletions

View File

@ -41,9 +41,17 @@ jobs:
strategy: strategy:
matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }}
fail-fast: false fail-fast: false
# NB: This job can only run on GitHub Linux runner atm. This is an ok thing though
# because that runner is ephemeral and could access upload secrets
runs-on: ${{ matrix.runner }} runs-on: ${{ matrix.runner }}
env:
# GitHub runner installs Android SDK on this path
ANDROID_ROOT: /usr/local/lib/android
ANDROID_NDK_VERSION: '21.4.7075529'
BUILD_LITE_INTERPRETER: ${{ matrix.use_lite_interpreter }}
# 4 of them are supported atm: armeabi-v7a, arm64-v8a, x86, x86_64
SUPPORT_ABI: '${{ matrix.support_abi }}'
steps: steps:
# [see note: pytorch repo ref]
- name: Checkout PyTorch - name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
@ -51,7 +59,7 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-miniconda@main uses: pytorch/test-infra/.github/actions/setup-miniconda@main
with: with:
python-version: 3.8 python-version: 3.8
environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}.txt
- name: Install NDK - name: Install NDK
uses: nick-fields/retry@v2.8.2 uses: nick-fields/retry@v2.8.2
@ -60,12 +68,12 @@ jobs:
max_attempts: 3 max_attempts: 3
retry_wait_seconds: 90 retry_wait_seconds: 90
command: | command: |
set -eux
# Install NDK 21 after GitHub update # Install NDK 21 after GitHub update
# https://github.com/actions/virtual-environments/issues/5595 # https://github.com/actions/virtual-environments/issues/5595
ANDROID_ROOT="/usr/local/lib/android"
ANDROID_SDK_ROOT="${ANDROID_ROOT}/sdk" ANDROID_SDK_ROOT="${ANDROID_ROOT}/sdk"
ANDROID_NDK="${ANDROID_SDK_ROOT}/ndk-bundle" ANDROID_NDK="${ANDROID_SDK_ROOT}/ndk-bundle"
ANDROID_NDK_VERSION="21.4.7075529"
SDKMANAGER="${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager" SDKMANAGER="${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager"
# NB: This step downloads and installs NDK, thus it could be flaky. # NB: This step downloads and installs NDK, thus it could be flaky.
@ -86,8 +94,10 @@ jobs:
- name: Build PyTorch Android - name: Build PyTorch Android
run: | run: |
set -eux
echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}" echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}"
${CONDA_RUN} ./scripts/build_pytorch_android.sh x86 ${CONDA_RUN} ./scripts/build_pytorch_android.sh "${SUPPORT_ABI}"
- name: Run tests - name: Run tests
uses: reactivecircus/android-emulator-runner@v2 uses: reactivecircus/android-emulator-runner@v2

View File

@ -0,0 +1,48 @@
name: Build Android binaries
on:
push:
branches:
- nightly
tags:
# NOTE: Binary build pipelines should only get triggered on release candidate builds
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
paths:
- .github/workflows/build-android-binaries.yml
- .github/workflows/_run_android_tests.yml
- android/**
pull_request:
paths:
- .github/workflows/build-android-binaries.yml
- .github/workflows/_run_android_tests.yml
- android/**
# NB: We can use this workflow dispatch to test and build the binaries manually
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
android-build-test:
name: android-build-test
uses: ./.github/workflows/_run_android_tests.yml
with:
test-matrix: |
{ include: [
{ config: 'default',
shard: 1,
num_shards: 1,
runner: 'ubuntu-20.04-16x',
use_lite_interpreter: 1,
support_abi: 'armeabi-v7a,arm64-v8a,x86,x86_64',
},
{ config: 'default',
shard: 1,
num_shards: 1,
runner: 'ubuntu-20.04-16x',
use_lite_interpreter: 0,
support_abi: 'armeabi-v7a,arm64-v8a,x86,x86_64',
},
]}

View File

@ -189,7 +189,14 @@ jobs:
with: with:
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "default", shard: 1, num_shards: 1, runner: "ubuntu-20.04-16x" }, { config: "default",
shard: 1,
num_shards: 1,
runner: "ubuntu-20.04-16x"
use_lite_interpreter: 1,
# Just set x86 for testing here
support_abi: x86,
},
]} ]}
linux-vulkan-focal-py3_11-clang10-build: linux-vulkan-focal-py3_11-clang10-build:

4
.gitignore vendored
View File

@ -364,3 +364,7 @@ venv/
# Log files # Log files
*.log *.log
sweep/ sweep/
# Android build artifacts
android/pytorch_android/.cxx
android/pytorch_android_torchvision/.cxx

View File

@ -41,6 +41,7 @@ android {
println 'Build pytorch_jni' println 'Build pytorch_jni'
exclude 'org/pytorch/LiteModuleLoader.java' exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java' exclude 'org/pytorch/LiteNativePeer.java'
exclude 'org/pytorch/LitePyTorchAndroid.java'
} else { } else {
println 'Build pytorch_jni_lite' println 'Build pytorch_jni_lite'
} }

View File

@ -10,6 +10,7 @@ import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.junit.Test; import org.junit.Test;
import org.junit.Ignore;
public abstract class PytorchTestBase { public abstract class PytorchTestBase {
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl"; private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
@ -413,7 +414,10 @@ public abstract class PytorchTestBase {
} }
@Test @Test
@Ignore
public void testSpectralOps() throws IOException { public void testSpectralOps() throws IOException {
// NB: This model fails without lite interpreter. The error is as follows:
// RuntimeError: stft requires the return_complex parameter be given for real inputs
runModel("spectral_ops"); runModel("spectral_ops");
} }

View File

@ -10,12 +10,6 @@
#include <fbjni/fbjni.h> #include <fbjni/fbjni.h>
#include "pytorch_jni_common.h" #include "pytorch_jni_common.h"
#if defined(__ANDROID__)
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif
namespace pytorch_jni { namespace pytorch_jni {
@ -666,32 +660,4 @@ at::IValue JIValue::JIValueToAtIValue(
typeCode); typeCode);
} }
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni } // namespace pytorch_jni

View File

@ -17,6 +17,11 @@
#include <android/asset_manager.h> #include <android/asset_manager.h>
#include <android/asset_manager_jni.h> #include <android/asset_manager_jni.h>
#include <android/log.h> #include <android/log.h>
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif #endif
namespace pytorch_jni { namespace pytorch_jni {
@ -235,6 +240,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
} }
}; };
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni } // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {

View File

@ -18,6 +18,11 @@
#include <android/asset_manager.h> #include <android/asset_manager.h>
#include <android/asset_manager_jni.h> #include <android/asset_manager_jni.h>
#include <android/log.h> #include <android/log.h>
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif #endif
namespace pytorch_jni { namespace pytorch_jni {
@ -199,6 +204,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
} }
}; };
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LitePyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni } // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {

View File

@ -0,0 +1,50 @@
package org.pytorch;
import android.content.res.AssetManager;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class LitePyTorchAndroid {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
/**
* Attention: This is not recommended way of loading production modules, as prepackaged assets
* increase apk size etc. For production usage consider using loading from file on the disk {@link
* org.pytorch.Module#load(String)}.
*
* <p>This method is meant to use in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new LiteNativePeer(assetName, assetManager, device));
}
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
}
/**
* Globally sets the number of threads used on native side. Attention: Has global effect, all
* modules use one thread pool with specified number of threads.
*
* @param numThreads number of threads, must be positive number.
*/
public static void setNumThreads(int numThreads) {
if (numThreads < 1) {
throw new IllegalArgumentException("Number of threads cannot be less than 1");
}
nativeSetNumThreads(numThreads);
}
@DoNotStrip
private static native void nativeSetNumThreads(int numThreads);
}

View File

@ -10,7 +10,7 @@ public final class PyTorchAndroid {
if (!NativeLoader.isInitialized()) { if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate()); NativeLoader.init(new SystemDelegate());
} }
NativeLoader.loadLibrary("pytorch_jni_lite"); NativeLoader.loadLibrary("pytorch_jni");
PyTorchCodegenLoader.loadNativeLibs(); PyTorchCodegenLoader.loadNativeLibs();
} }

View File

@ -41,6 +41,11 @@ android {
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}") buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false') buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false') buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
buildConfigField(
"int",
"BUILD_LITE_INTERPRETER",
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
)
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"]) addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
} }
buildTypes { buildTypes {
@ -63,14 +68,15 @@ android {
mnet { mnet {
dimension "model" dimension "model"
applicationIdSuffix ".mnet" applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"") buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
addManifestPlaceholders([APP_NAME: "MNET"]) addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"") buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
} }
// NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
mnetVulkan { mnetVulkan {
dimension "model" dimension "model"
applicationIdSuffix ".mnet_vulkan" applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"") buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true') buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"]) addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"") buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
@ -78,7 +84,7 @@ android {
resnet18 { resnet18 {
dimension "model" dimension "model"
applicationIdSuffix ".resnet18" applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"") buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
addManifestPlaceholders([APP_NAME: "RN18"]) addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"") buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
} }
@ -149,8 +155,8 @@ dependencies {
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar') //nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar') //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
nightlyImplementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT' nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT' nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
aarImplementation(name:'pytorch_android', ext:'aar') aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar') aarImplementation(name:'pytorch_android_torchvision', ext:'aar')

View File

@ -1,6 +1,7 @@
package org.pytorch.testapp; package org.pytorch.testapp;
import android.content.Context; import android.content.Context;
import android.content.res.AssetManager;
import android.os.Bundle; import android.os.Bundle;
import android.os.Handler; import android.os.Handler;
import android.os.HandlerThread; import android.os.HandlerThread;
@ -16,6 +17,8 @@ import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import org.pytorch.Device; import org.pytorch.Device;
import org.pytorch.IValue; import org.pytorch.IValue;
@ -42,7 +45,13 @@ public class MainActivity extends AppCompatActivity {
new Runnable() { new Runnable() {
@Override @Override
public void run() { public void run() {
final Result result = doModuleForward(); final Result result;
try {
result = doModuleForward();
} catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException |
InvocationTargetException e) {
throw new RuntimeException(e);
}
runOnUiThread( runOnUiThread(
new Runnable() { new Runnable() {
@Override @Override
@ -118,7 +127,7 @@ public class MainActivity extends AppCompatActivity {
@WorkerThread @WorkerThread
@Nullable @Nullable
protected Result doModuleForward() { protected Result doModuleForward() throws ClassNotFoundException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
if (mModule == null) { if (mModule == null) {
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE; final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
long numElements = 1; long numElements = 1;
@ -129,12 +138,29 @@ public class MainActivity extends AppCompatActivity {
mInputTensor = mInputTensor =
Tensor.fromBlob( Tensor.fromBlob(
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST); mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
PyTorchAndroid.setNumThreads(1);
mModule = Class ptAndroid;
BuildConfig.USE_VULKAN_DEVICE if (BuildConfig.BUILD_LITE_INTERPRETER == 1) {
? PyTorchAndroid.loadModuleFromAsset( ptAndroid = Class.forName("org.pytorch.LitePyTorchAndroid");
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN) }
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); else {
ptAndroid = Class.forName("org.pytorch.PyTorchAndroid");
}
Method setNumThreads = ptAndroid.getMethod("setNumThreads", int.class);
setNumThreads.invoke(null,1);
Method loadModuleFromAsset = ptAndroid.getMethod(
"loadModuleFromAsset",
AssetManager.class,
String.class,
Device.class
);
mModule = (Module) (BuildConfig.USE_VULKAN_DEVICE
? loadModuleFromAsset.invoke(
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: loadModuleFromAsset.invoke(
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.CPU));
} }
final long startTime = SystemClock.elapsedRealtime(); final long startTime = SystemClock.elapsedRealtime();

View File

@ -47,7 +47,11 @@ from tensor_ops import (
TensorViewOpsModule, TensorViewOpsModule,
) )
from torch.jit.mobile import _load_for_lite_interpreter from torch.jit.mobile import _load_for_lite_interpreter
from torchvision_models import MobileNetV2Module from torchvision_models import (
MobileNetV2Module,
MobileNetV2VulkanModule,
Resnet18Module,
)
test_path_ios = "ios/TestApp/models/" test_path_ios = "ios/TestApp/models/"
test_path_android = "android/pytorch_android/src/androidTest/assets/" test_path_android = "android/pytorch_android/src/androidTest/assets/"
@ -98,6 +102,8 @@ all_modules = {
"torchscript_collection_ops": TSCollectionOpsModule(), "torchscript_collection_ops": TSCollectionOpsModule(),
# vision # vision
"mobilenet_v2": MobileNetV2Module(), "mobilenet_v2": MobileNetV2Module(),
"mobilenet_v2_vulkan": MobileNetV2VulkanModule(),
"resnet18": Resnet18Module(),
# android api module # android api module
"android_api_module": AndroidAPIModule(), "android_api_module": AndroidAPIModule(),
} }

View File

@ -19,3 +19,37 @@ class MobileNetV2Module:
) )
optimized_module(example) optimized_module(example)
return optimized_module return optimized_module
class MobileNetV2VulkanModule:
def getModule(self):
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.zeros(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_module = optimize_for_mobile(traced_script_module, backend="vulkan")
augment_model_with_bundled_inputs(
optimized_module,
[
(example, ),
],
)
optimized_module(example)
return optimized_module
class Resnet18Module:
def getModule(self):
model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.zeros(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_module = optimize_for_mobile(traced_script_module)
augment_model_with_bundled_inputs(
optimized_module,
[
(example, ),
],
)
optimized_module(example)
return optimized_module