mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a workflow to release Android binaries (#110976)
This adds 2 jobs to build PyTorch Android with and without lite interpreter: * Keep the list of currently supported ABI armeabi-v7a, arm64-v8a, x86, x86_64 * Pass all the test on emulator * Run an the test app on emulator and my Android phone `arm64-v8a` without any issue  * Run on AWS https://us-west-2.console.aws.amazon.com/devicefarm/home#/mobile/projects/b531574a-fb82-40ae-b687-8f0b81341ae0/runs/5fce6818-628a-4099-9aab-23e91a212076 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110976 Approved by: https://github.com/atalman
This commit is contained in:
20
.github/workflows/_run_android_tests.yml
vendored
20
.github/workflows/_run_android_tests.yml
vendored
@ -41,9 +41,17 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }}
|
matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }}
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
# NB: This job can only run on GitHub Linux runner atm. This is an ok thing though
|
||||||
|
# because that runner is ephemeral and could access upload secrets
|
||||||
runs-on: ${{ matrix.runner }}
|
runs-on: ${{ matrix.runner }}
|
||||||
|
env:
|
||||||
|
# GitHub runner installs Android SDK on this path
|
||||||
|
ANDROID_ROOT: /usr/local/lib/android
|
||||||
|
ANDROID_NDK_VERSION: '21.4.7075529'
|
||||||
|
BUILD_LITE_INTERPRETER: ${{ matrix.use_lite_interpreter }}
|
||||||
|
# 4 of them are supported atm: armeabi-v7a, arm64-v8a, x86, x86_64
|
||||||
|
SUPPORT_ABI: '${{ matrix.support_abi }}'
|
||||||
steps:
|
steps:
|
||||||
# [see note: pytorch repo ref]
|
|
||||||
- name: Checkout PyTorch
|
- name: Checkout PyTorch
|
||||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||||
|
|
||||||
@ -51,7 +59,7 @@ jobs:
|
|||||||
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
|
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}
|
environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}.txt
|
||||||
|
|
||||||
- name: Install NDK
|
- name: Install NDK
|
||||||
uses: nick-fields/retry@v2.8.2
|
uses: nick-fields/retry@v2.8.2
|
||||||
@ -60,12 +68,12 @@ jobs:
|
|||||||
max_attempts: 3
|
max_attempts: 3
|
||||||
retry_wait_seconds: 90
|
retry_wait_seconds: 90
|
||||||
command: |
|
command: |
|
||||||
|
set -eux
|
||||||
|
|
||||||
# Install NDK 21 after GitHub update
|
# Install NDK 21 after GitHub update
|
||||||
# https://github.com/actions/virtual-environments/issues/5595
|
# https://github.com/actions/virtual-environments/issues/5595
|
||||||
ANDROID_ROOT="/usr/local/lib/android"
|
|
||||||
ANDROID_SDK_ROOT="${ANDROID_ROOT}/sdk"
|
ANDROID_SDK_ROOT="${ANDROID_ROOT}/sdk"
|
||||||
ANDROID_NDK="${ANDROID_SDK_ROOT}/ndk-bundle"
|
ANDROID_NDK="${ANDROID_SDK_ROOT}/ndk-bundle"
|
||||||
ANDROID_NDK_VERSION="21.4.7075529"
|
|
||||||
|
|
||||||
SDKMANAGER="${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager"
|
SDKMANAGER="${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager"
|
||||||
# NB: This step downloads and installs NDK, thus it could be flaky.
|
# NB: This step downloads and installs NDK, thus it could be flaky.
|
||||||
@ -86,8 +94,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Build PyTorch Android
|
- name: Build PyTorch Android
|
||||||
run: |
|
run: |
|
||||||
|
set -eux
|
||||||
|
|
||||||
echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}"
|
echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}"
|
||||||
${CONDA_RUN} ./scripts/build_pytorch_android.sh x86
|
${CONDA_RUN} ./scripts/build_pytorch_android.sh "${SUPPORT_ABI}"
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
uses: reactivecircus/android-emulator-runner@v2
|
uses: reactivecircus/android-emulator-runner@v2
|
||||||
|
48
.github/workflows/build-android-binaries.yml
vendored
Normal file
48
.github/workflows/build-android-binaries.yml
vendored
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
name: Build Android binaries
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- nightly
|
||||||
|
tags:
|
||||||
|
# NOTE: Binary build pipelines should only get triggered on release candidate builds
|
||||||
|
# Release candidate tags look like: v1.11.0-rc1
|
||||||
|
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
|
||||||
|
paths:
|
||||||
|
- .github/workflows/build-android-binaries.yml
|
||||||
|
- .github/workflows/_run_android_tests.yml
|
||||||
|
- android/**
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- .github/workflows/build-android-binaries.yml
|
||||||
|
- .github/workflows/_run_android_tests.yml
|
||||||
|
- android/**
|
||||||
|
# NB: We can use this workflow dispatch to test and build the binaries manually
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
android-build-test:
|
||||||
|
name: android-build-test
|
||||||
|
uses: ./.github/workflows/_run_android_tests.yml
|
||||||
|
with:
|
||||||
|
test-matrix: |
|
||||||
|
{ include: [
|
||||||
|
{ config: 'default',
|
||||||
|
shard: 1,
|
||||||
|
num_shards: 1,
|
||||||
|
runner: 'ubuntu-20.04-16x',
|
||||||
|
use_lite_interpreter: 1,
|
||||||
|
support_abi: 'armeabi-v7a,arm64-v8a,x86,x86_64',
|
||||||
|
},
|
||||||
|
{ config: 'default',
|
||||||
|
shard: 1,
|
||||||
|
num_shards: 1,
|
||||||
|
runner: 'ubuntu-20.04-16x',
|
||||||
|
use_lite_interpreter: 0,
|
||||||
|
support_abi: 'armeabi-v7a,arm64-v8a,x86,x86_64',
|
||||||
|
},
|
||||||
|
]}
|
9
.github/workflows/periodic.yml
vendored
9
.github/workflows/periodic.yml
vendored
@ -189,7 +189,14 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "default", shard: 1, num_shards: 1, runner: "ubuntu-20.04-16x" },
|
{ config: "default",
|
||||||
|
shard: 1,
|
||||||
|
num_shards: 1,
|
||||||
|
runner: "ubuntu-20.04-16x"
|
||||||
|
use_lite_interpreter: 1,
|
||||||
|
# Just set x86 for testing here
|
||||||
|
support_abi: x86,
|
||||||
|
},
|
||||||
]}
|
]}
|
||||||
|
|
||||||
linux-vulkan-focal-py3_11-clang10-build:
|
linux-vulkan-focal-py3_11-clang10-build:
|
||||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -364,3 +364,7 @@ venv/
|
|||||||
# Log files
|
# Log files
|
||||||
*.log
|
*.log
|
||||||
sweep/
|
sweep/
|
||||||
|
|
||||||
|
# Android build artifacts
|
||||||
|
android/pytorch_android/.cxx
|
||||||
|
android/pytorch_android_torchvision/.cxx
|
||||||
|
@ -41,6 +41,7 @@ android {
|
|||||||
println 'Build pytorch_jni'
|
println 'Build pytorch_jni'
|
||||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||||
exclude 'org/pytorch/LiteNativePeer.java'
|
exclude 'org/pytorch/LiteNativePeer.java'
|
||||||
|
exclude 'org/pytorch/LitePyTorchAndroid.java'
|
||||||
} else {
|
} else {
|
||||||
println 'Build pytorch_jni_lite'
|
println 'Build pytorch_jni_lite'
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ import java.io.IOException;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.Ignore;
|
||||||
|
|
||||||
public abstract class PytorchTestBase {
|
public abstract class PytorchTestBase {
|
||||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||||
@ -413,7 +414,10 @@ public abstract class PytorchTestBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Ignore
|
||||||
public void testSpectralOps() throws IOException {
|
public void testSpectralOps() throws IOException {
|
||||||
|
// NB: This model fails without lite interpreter. The error is as follows:
|
||||||
|
// RuntimeError: stft requires the return_complex parameter be given for real inputs
|
||||||
runModel("spectral_ops");
|
runModel("spectral_ops");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,12 +10,6 @@
|
|||||||
#include <fbjni/fbjni.h>
|
#include <fbjni/fbjni.h>
|
||||||
|
|
||||||
#include "pytorch_jni_common.h"
|
#include "pytorch_jni_common.h"
|
||||||
#if defined(__ANDROID__)
|
|
||||||
#ifndef USE_PTHREADPOOL
|
|
||||||
#define USE_PTHREADPOOL
|
|
||||||
#endif /* USE_PTHREADPOOL */
|
|
||||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace pytorch_jni {
|
namespace pytorch_jni {
|
||||||
|
|
||||||
@ -666,32 +660,4 @@ at::IValue JIValue::JIValueToAtIValue(
|
|||||||
typeCode);
|
typeCode);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__ANDROID__)
|
|
||||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
|
||||||
public:
|
|
||||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
|
||||||
|
|
||||||
static void registerNatives() {
|
|
||||||
javaClassStatic()->registerNatives({
|
|
||||||
makeNativeMethod(
|
|
||||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
|
||||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void common_registerNatives() {
|
|
||||||
static const int once = []() {
|
|
||||||
#if defined(__ANDROID__)
|
|
||||||
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
|
||||||
#endif
|
|
||||||
return 0;
|
|
||||||
}();
|
|
||||||
((void)once);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace pytorch_jni
|
} // namespace pytorch_jni
|
||||||
|
@ -17,6 +17,11 @@
|
|||||||
#include <android/asset_manager.h>
|
#include <android/asset_manager.h>
|
||||||
#include <android/asset_manager_jni.h>
|
#include <android/asset_manager_jni.h>
|
||||||
#include <android/log.h>
|
#include <android/log.h>
|
||||||
|
|
||||||
|
#ifndef USE_PTHREADPOOL
|
||||||
|
#define USE_PTHREADPOOL
|
||||||
|
#endif /* USE_PTHREADPOOL */
|
||||||
|
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace pytorch_jni {
|
namespace pytorch_jni {
|
||||||
@ -235,6 +240,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||||
|
public:
|
||||||
|
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
||||||
|
|
||||||
|
static void registerNatives() {
|
||||||
|
javaClassStatic()->registerNatives({
|
||||||
|
makeNativeMethod(
|
||||||
|
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||||
|
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void common_registerNatives() {
|
||||||
|
static const int once = []() {
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||||
|
#endif
|
||||||
|
return 0;
|
||||||
|
}();
|
||||||
|
((void)once);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace pytorch_jni
|
} // namespace pytorch_jni
|
||||||
|
|
||||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||||
|
@ -18,6 +18,11 @@
|
|||||||
#include <android/asset_manager.h>
|
#include <android/asset_manager.h>
|
||||||
#include <android/asset_manager_jni.h>
|
#include <android/asset_manager_jni.h>
|
||||||
#include <android/log.h>
|
#include <android/log.h>
|
||||||
|
|
||||||
|
#ifndef USE_PTHREADPOOL
|
||||||
|
#define USE_PTHREADPOOL
|
||||||
|
#endif /* USE_PTHREADPOOL */
|
||||||
|
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace pytorch_jni {
|
namespace pytorch_jni {
|
||||||
@ -199,6 +204,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||||
|
public:
|
||||||
|
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LitePyTorchAndroid;";
|
||||||
|
|
||||||
|
static void registerNatives() {
|
||||||
|
javaClassStatic()->registerNatives({
|
||||||
|
makeNativeMethod(
|
||||||
|
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||||
|
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void common_registerNatives() {
|
||||||
|
static const int once = []() {
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||||
|
#endif
|
||||||
|
return 0;
|
||||||
|
}();
|
||||||
|
((void)once);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace pytorch_jni
|
} // namespace pytorch_jni
|
||||||
|
|
||||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
package org.pytorch;
|
||||||
|
|
||||||
|
import android.content.res.AssetManager;
|
||||||
|
import com.facebook.jni.annotations.DoNotStrip;
|
||||||
|
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||||
|
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||||
|
|
||||||
|
public final class LitePyTorchAndroid {
|
||||||
|
static {
|
||||||
|
if (!NativeLoader.isInitialized()) {
|
||||||
|
NativeLoader.init(new SystemDelegate());
|
||||||
|
}
|
||||||
|
NativeLoader.loadLibrary("pytorch_jni_lite");
|
||||||
|
PyTorchCodegenLoader.loadNativeLibs();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attention: This is not recommended way of loading production modules, as prepackaged assets
|
||||||
|
* increase apk size etc. For production usage consider using loading from file on the disk {@link
|
||||||
|
* org.pytorch.Module#load(String)}.
|
||||||
|
*
|
||||||
|
* <p>This method is meant to use in tests and demos.
|
||||||
|
*/
|
||||||
|
public static Module loadModuleFromAsset(
|
||||||
|
final AssetManager assetManager, final String assetName, final Device device) {
|
||||||
|
return new Module(new LiteNativePeer(assetName, assetManager, device));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Module loadModuleFromAsset(
|
||||||
|
final AssetManager assetManager, final String assetName) {
|
||||||
|
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Globally sets the number of threads used on native side. Attention: Has global effect, all
|
||||||
|
* modules use one thread pool with specified number of threads.
|
||||||
|
*
|
||||||
|
* @param numThreads number of threads, must be positive number.
|
||||||
|
*/
|
||||||
|
public static void setNumThreads(int numThreads) {
|
||||||
|
if (numThreads < 1) {
|
||||||
|
throw new IllegalArgumentException("Number of threads cannot be less than 1");
|
||||||
|
}
|
||||||
|
|
||||||
|
nativeSetNumThreads(numThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
@DoNotStrip
|
||||||
|
private static native void nativeSetNumThreads(int numThreads);
|
||||||
|
}
|
@ -10,7 +10,7 @@ public final class PyTorchAndroid {
|
|||||||
if (!NativeLoader.isInitialized()) {
|
if (!NativeLoader.isInitialized()) {
|
||||||
NativeLoader.init(new SystemDelegate());
|
NativeLoader.init(new SystemDelegate());
|
||||||
}
|
}
|
||||||
NativeLoader.loadLibrary("pytorch_jni_lite");
|
NativeLoader.loadLibrary("pytorch_jni");
|
||||||
PyTorchCodegenLoader.loadNativeLibs();
|
PyTorchCodegenLoader.loadNativeLibs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,6 +41,11 @@ android {
|
|||||||
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
||||||
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
||||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
|
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
|
||||||
|
buildConfigField(
|
||||||
|
"int",
|
||||||
|
"BUILD_LITE_INTERPRETER",
|
||||||
|
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
|
||||||
|
)
|
||||||
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
||||||
}
|
}
|
||||||
buildTypes {
|
buildTypes {
|
||||||
@ -63,14 +68,15 @@ android {
|
|||||||
mnet {
|
mnet {
|
||||||
dimension "model"
|
dimension "model"
|
||||||
applicationIdSuffix ".mnet"
|
applicationIdSuffix ".mnet"
|
||||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"")
|
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
|
||||||
addManifestPlaceholders([APP_NAME: "MNET"])
|
addManifestPlaceholders([APP_NAME: "MNET"])
|
||||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
|
||||||
}
|
}
|
||||||
|
// NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
|
||||||
mnetVulkan {
|
mnetVulkan {
|
||||||
dimension "model"
|
dimension "model"
|
||||||
applicationIdSuffix ".mnet_vulkan"
|
applicationIdSuffix ".mnet_vulkan"
|
||||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"")
|
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
|
||||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
||||||
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
|
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
|
||||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
|
||||||
@ -78,7 +84,7 @@ android {
|
|||||||
resnet18 {
|
resnet18 {
|
||||||
dimension "model"
|
dimension "model"
|
||||||
applicationIdSuffix ".resnet18"
|
applicationIdSuffix ".resnet18"
|
||||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
|
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
|
||||||
addManifestPlaceholders([APP_NAME: "RN18"])
|
addManifestPlaceholders([APP_NAME: "RN18"])
|
||||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
||||||
}
|
}
|
||||||
@ -149,8 +155,8 @@ dependencies {
|
|||||||
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
|
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
|
||||||
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
|
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
|
||||||
|
|
||||||
nightlyImplementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
|
||||||
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
|
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
|
||||||
|
|
||||||
aarImplementation(name:'pytorch_android', ext:'aar')
|
aarImplementation(name:'pytorch_android', ext:'aar')
|
||||||
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
|
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package org.pytorch.testapp;
|
package org.pytorch.testapp;
|
||||||
|
|
||||||
import android.content.Context;
|
import android.content.Context;
|
||||||
|
import android.content.res.AssetManager;
|
||||||
import android.os.Bundle;
|
import android.os.Bundle;
|
||||||
import android.os.Handler;
|
import android.os.Handler;
|
||||||
import android.os.HandlerThread;
|
import android.os.HandlerThread;
|
||||||
@ -16,6 +17,8 @@ import java.io.FileOutputStream;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
import java.nio.FloatBuffer;
|
import java.nio.FloatBuffer;
|
||||||
import org.pytorch.Device;
|
import org.pytorch.Device;
|
||||||
import org.pytorch.IValue;
|
import org.pytorch.IValue;
|
||||||
@ -42,7 +45,13 @@ public class MainActivity extends AppCompatActivity {
|
|||||||
new Runnable() {
|
new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
final Result result = doModuleForward();
|
final Result result;
|
||||||
|
try {
|
||||||
|
result = doModuleForward();
|
||||||
|
} catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException |
|
||||||
|
InvocationTargetException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
runOnUiThread(
|
runOnUiThread(
|
||||||
new Runnable() {
|
new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
@ -118,7 +127,7 @@ public class MainActivity extends AppCompatActivity {
|
|||||||
|
|
||||||
@WorkerThread
|
@WorkerThread
|
||||||
@Nullable
|
@Nullable
|
||||||
protected Result doModuleForward() {
|
protected Result doModuleForward() throws ClassNotFoundException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
|
||||||
if (mModule == null) {
|
if (mModule == null) {
|
||||||
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
|
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
|
||||||
long numElements = 1;
|
long numElements = 1;
|
||||||
@ -129,12 +138,29 @@ public class MainActivity extends AppCompatActivity {
|
|||||||
mInputTensor =
|
mInputTensor =
|
||||||
Tensor.fromBlob(
|
Tensor.fromBlob(
|
||||||
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
|
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
|
||||||
PyTorchAndroid.setNumThreads(1);
|
|
||||||
mModule =
|
Class ptAndroid;
|
||||||
BuildConfig.USE_VULKAN_DEVICE
|
if (BuildConfig.BUILD_LITE_INTERPRETER == 1) {
|
||||||
? PyTorchAndroid.loadModuleFromAsset(
|
ptAndroid = Class.forName("org.pytorch.LitePyTorchAndroid");
|
||||||
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
}
|
||||||
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
else {
|
||||||
|
ptAndroid = Class.forName("org.pytorch.PyTorchAndroid");
|
||||||
|
}
|
||||||
|
|
||||||
|
Method setNumThreads = ptAndroid.getMethod("setNumThreads", int.class);
|
||||||
|
setNumThreads.invoke(null,1);
|
||||||
|
|
||||||
|
Method loadModuleFromAsset = ptAndroid.getMethod(
|
||||||
|
"loadModuleFromAsset",
|
||||||
|
AssetManager.class,
|
||||||
|
String.class,
|
||||||
|
Device.class
|
||||||
|
);
|
||||||
|
mModule = (Module) (BuildConfig.USE_VULKAN_DEVICE
|
||||||
|
? loadModuleFromAsset.invoke(
|
||||||
|
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
||||||
|
: loadModuleFromAsset.invoke(
|
||||||
|
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.CPU));
|
||||||
}
|
}
|
||||||
|
|
||||||
final long startTime = SystemClock.elapsedRealtime();
|
final long startTime = SystemClock.elapsedRealtime();
|
||||||
|
@ -47,7 +47,11 @@ from tensor_ops import (
|
|||||||
TensorViewOpsModule,
|
TensorViewOpsModule,
|
||||||
)
|
)
|
||||||
from torch.jit.mobile import _load_for_lite_interpreter
|
from torch.jit.mobile import _load_for_lite_interpreter
|
||||||
from torchvision_models import MobileNetV2Module
|
from torchvision_models import (
|
||||||
|
MobileNetV2Module,
|
||||||
|
MobileNetV2VulkanModule,
|
||||||
|
Resnet18Module,
|
||||||
|
)
|
||||||
|
|
||||||
test_path_ios = "ios/TestApp/models/"
|
test_path_ios = "ios/TestApp/models/"
|
||||||
test_path_android = "android/pytorch_android/src/androidTest/assets/"
|
test_path_android = "android/pytorch_android/src/androidTest/assets/"
|
||||||
@ -98,6 +102,8 @@ all_modules = {
|
|||||||
"torchscript_collection_ops": TSCollectionOpsModule(),
|
"torchscript_collection_ops": TSCollectionOpsModule(),
|
||||||
# vision
|
# vision
|
||||||
"mobilenet_v2": MobileNetV2Module(),
|
"mobilenet_v2": MobileNetV2Module(),
|
||||||
|
"mobilenet_v2_vulkan": MobileNetV2VulkanModule(),
|
||||||
|
"resnet18": Resnet18Module(),
|
||||||
# android api module
|
# android api module
|
||||||
"android_api_module": AndroidAPIModule(),
|
"android_api_module": AndroidAPIModule(),
|
||||||
}
|
}
|
||||||
|
@ -19,3 +19,37 @@ class MobileNetV2Module:
|
|||||||
)
|
)
|
||||||
optimized_module(example)
|
optimized_module(example)
|
||||||
return optimized_module
|
return optimized_module
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV2VulkanModule:
|
||||||
|
def getModule(self):
|
||||||
|
model = torchvision.models.mobilenet_v2(pretrained=True)
|
||||||
|
model.eval()
|
||||||
|
example = torch.zeros(1, 3, 224, 224)
|
||||||
|
traced_script_module = torch.jit.trace(model, example)
|
||||||
|
optimized_module = optimize_for_mobile(traced_script_module, backend="vulkan")
|
||||||
|
augment_model_with_bundled_inputs(
|
||||||
|
optimized_module,
|
||||||
|
[
|
||||||
|
(example, ),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
optimized_module(example)
|
||||||
|
return optimized_module
|
||||||
|
|
||||||
|
|
||||||
|
class Resnet18Module:
|
||||||
|
def getModule(self):
|
||||||
|
model = torchvision.models.resnet18(pretrained=True)
|
||||||
|
model.eval()
|
||||||
|
example = torch.zeros(1, 3, 224, 224)
|
||||||
|
traced_script_module = torch.jit.trace(model, example)
|
||||||
|
optimized_module = optimize_for_mobile(traced_script_module)
|
||||||
|
augment_model_with_bundled_inputs(
|
||||||
|
optimized_module,
|
||||||
|
[
|
||||||
|
(example, ),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
optimized_module(example)
|
||||||
|
return optimized_module
|
||||||
|
Reference in New Issue
Block a user