mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a workflow to release Android binaries (#110976)
This adds 2 jobs to build PyTorch Android with and without lite interpreter: * Keep the list of currently supported ABI armeabi-v7a, arm64-v8a, x86, x86_64 * Pass all the test on emulator * Run an the test app on emulator and my Android phone `arm64-v8a` without any issue  * Run on AWS https://us-west-2.console.aws.amazon.com/devicefarm/home#/mobile/projects/b531574a-fb82-40ae-b687-8f0b81341ae0/runs/5fce6818-628a-4099-9aab-23e91a212076 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110976 Approved by: https://github.com/atalman
This commit is contained in:
@ -41,6 +41,7 @@ android {
|
||||
println 'Build pytorch_jni'
|
||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||
exclude 'org/pytorch/LiteNativePeer.java'
|
||||
exclude 'org/pytorch/LitePyTorchAndroid.java'
|
||||
} else {
|
||||
println 'Build pytorch_jni_lite'
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
import org.junit.Ignore;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||
@ -413,7 +414,10 @@ public abstract class PytorchTestBase {
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testSpectralOps() throws IOException {
|
||||
// NB: This model fails without lite interpreter. The error is as follows:
|
||||
// RuntimeError: stft requires the return_complex parameter be given for real inputs
|
||||
runModel("spectral_ops");
|
||||
}
|
||||
|
||||
|
@ -10,12 +10,6 @@
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
#if defined(__ANDROID__)
|
||||
#ifndef USE_PTHREADPOOL
|
||||
#define USE_PTHREADPOOL
|
||||
#endif /* USE_PTHREADPOOL */
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
@ -666,32 +660,4 @@ at::IValue JIValue::JIValueToAtIValue(
|
||||
typeCode);
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
||||
|
||||
static void registerNatives() {
|
||||
javaClassStatic()->registerNatives({
|
||||
makeNativeMethod(
|
||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||
});
|
||||
}
|
||||
|
||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
void common_registerNatives() {
|
||||
static const int once = []() {
|
||||
#if defined(__ANDROID__)
|
||||
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||
#endif
|
||||
return 0;
|
||||
}();
|
||||
((void)once);
|
||||
}
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
@ -17,6 +17,11 @@
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
|
||||
#ifndef USE_PTHREADPOOL
|
||||
#define USE_PTHREADPOOL
|
||||
#endif /* USE_PTHREADPOOL */
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
@ -235,6 +240,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
||||
|
||||
static void registerNatives() {
|
||||
javaClassStatic()->registerNatives({
|
||||
makeNativeMethod(
|
||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||
});
|
||||
}
|
||||
|
||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
void common_registerNatives() {
|
||||
static const int once = []() {
|
||||
#if defined(__ANDROID__)
|
||||
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||
#endif
|
||||
return 0;
|
||||
}();
|
||||
((void)once);
|
||||
}
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||
|
@ -18,6 +18,11 @@
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
|
||||
#ifndef USE_PTHREADPOOL
|
||||
#define USE_PTHREADPOOL
|
||||
#endif /* USE_PTHREADPOOL */
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
@ -199,6 +204,34 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LitePyTorchAndroid;";
|
||||
|
||||
static void registerNatives() {
|
||||
javaClassStatic()->registerNatives({
|
||||
makeNativeMethod(
|
||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||
});
|
||||
}
|
||||
|
||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
void common_registerNatives() {
|
||||
static const int once = []() {
|
||||
#if defined(__ANDROID__)
|
||||
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||
#endif
|
||||
return 0;
|
||||
}();
|
||||
((void)once);
|
||||
}
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||
|
@ -0,0 +1,50 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
|
||||
public final class LitePyTorchAndroid {
|
||||
static {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_jni_lite");
|
||||
PyTorchCodegenLoader.loadNativeLibs();
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention: This is not recommended way of loading production modules, as prepackaged assets
|
||||
* increase apk size etc. For production usage consider using loading from file on the disk {@link
|
||||
* org.pytorch.Module#load(String)}.
|
||||
*
|
||||
* <p>This method is meant to use in tests and demos.
|
||||
*/
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName, final Device device) {
|
||||
return new Module(new LiteNativePeer(assetName, assetManager, device));
|
||||
}
|
||||
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName) {
|
||||
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
|
||||
}
|
||||
|
||||
/**
|
||||
* Globally sets the number of threads used on native side. Attention: Has global effect, all
|
||||
* modules use one thread pool with specified number of threads.
|
||||
*
|
||||
* @param numThreads number of threads, must be positive number.
|
||||
*/
|
||||
public static void setNumThreads(int numThreads) {
|
||||
if (numThreads < 1) {
|
||||
throw new IllegalArgumentException("Number of threads cannot be less than 1");
|
||||
}
|
||||
|
||||
nativeSetNumThreads(numThreads);
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
private static native void nativeSetNumThreads(int numThreads);
|
||||
}
|
@ -10,7 +10,7 @@ public final class PyTorchAndroid {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_jni_lite");
|
||||
NativeLoader.loadLibrary("pytorch_jni");
|
||||
PyTorchCodegenLoader.loadNativeLibs();
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,11 @@ android {
|
||||
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
||||
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
|
||||
buildConfigField(
|
||||
"int",
|
||||
"BUILD_LITE_INTERPRETER",
|
||||
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
|
||||
)
|
||||
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
||||
}
|
||||
buildTypes {
|
||||
@ -63,14 +68,15 @@ android {
|
||||
mnet {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".mnet"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"")
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
|
||||
addManifestPlaceholders([APP_NAME: "MNET"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
|
||||
}
|
||||
// NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
|
||||
mnetVulkan {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".mnet_vulkan"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"")
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
|
||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
||||
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
|
||||
@ -78,7 +84,7 @@ android {
|
||||
resnet18 {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".resnet18"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
|
||||
addManifestPlaceholders([APP_NAME: "RN18"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
||||
}
|
||||
@ -149,8 +155,8 @@ dependencies {
|
||||
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
|
||||
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
|
||||
|
||||
nightlyImplementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
||||
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
|
||||
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
|
||||
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
|
||||
|
||||
aarImplementation(name:'pytorch_android', ext:'aar')
|
||||
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
|
||||
|
@ -1,6 +1,7 @@
|
||||
package org.pytorch.testapp;
|
||||
|
||||
import android.content.Context;
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.HandlerThread;
|
||||
@ -16,6 +17,8 @@ import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.pytorch.Device;
|
||||
import org.pytorch.IValue;
|
||||
@ -42,7 +45,13 @@ public class MainActivity extends AppCompatActivity {
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
final Result result = doModuleForward();
|
||||
final Result result;
|
||||
try {
|
||||
result = doModuleForward();
|
||||
} catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException |
|
||||
InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
runOnUiThread(
|
||||
new Runnable() {
|
||||
@Override
|
||||
@ -118,7 +127,7 @@ public class MainActivity extends AppCompatActivity {
|
||||
|
||||
@WorkerThread
|
||||
@Nullable
|
||||
protected Result doModuleForward() {
|
||||
protected Result doModuleForward() throws ClassNotFoundException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
|
||||
if (mModule == null) {
|
||||
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
|
||||
long numElements = 1;
|
||||
@ -129,12 +138,29 @@ public class MainActivity extends AppCompatActivity {
|
||||
mInputTensor =
|
||||
Tensor.fromBlob(
|
||||
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
|
||||
PyTorchAndroid.setNumThreads(1);
|
||||
mModule =
|
||||
BuildConfig.USE_VULKAN_DEVICE
|
||||
? PyTorchAndroid.loadModuleFromAsset(
|
||||
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
||||
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
||||
|
||||
Class ptAndroid;
|
||||
if (BuildConfig.BUILD_LITE_INTERPRETER == 1) {
|
||||
ptAndroid = Class.forName("org.pytorch.LitePyTorchAndroid");
|
||||
}
|
||||
else {
|
||||
ptAndroid = Class.forName("org.pytorch.PyTorchAndroid");
|
||||
}
|
||||
|
||||
Method setNumThreads = ptAndroid.getMethod("setNumThreads", int.class);
|
||||
setNumThreads.invoke(null,1);
|
||||
|
||||
Method loadModuleFromAsset = ptAndroid.getMethod(
|
||||
"loadModuleFromAsset",
|
||||
AssetManager.class,
|
||||
String.class,
|
||||
Device.class
|
||||
);
|
||||
mModule = (Module) (BuildConfig.USE_VULKAN_DEVICE
|
||||
? loadModuleFromAsset.invoke(
|
||||
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
||||
: loadModuleFromAsset.invoke(
|
||||
null, getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.CPU));
|
||||
}
|
||||
|
||||
final long startTime = SystemClock.elapsedRealtime();
|
||||
|
Reference in New Issue
Block a user