[pytorch] Delete TorchScript based Android demo app and point user to ExecuTorch (#153767)

Summary: A retry of #153656. This time we start from co-dev to make sure we capture internal signals.

Test Plan: Rely on CI jobs.

Differential Revision: D74911818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153767
Approved by: https://github.com/kirklandsign, https://github.com/cyyever, https://github.com/Skylion007
Mengwei Liu, 2025-05-19 17:20:36 +00:00
Committed by: PyTorch MergeBot
Parent: 6487ea30b3
Commit: be36bacdaa
26 changed files with 4 additions and 1987 deletions

View File

@@ -2,7 +2,9 @@
## Demo applications and tutorials
Demo applications with code walk-throughs can be found in [this GitHub repo](https://github.com/pytorch/android-demo-app).
Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
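For a first orientation, below is a minimal sketch of running a model with the ExecuTorch Android bindings. The package and class names (`org.pytorch.executorch.Module`, `EValue`, `Tensor`) and the `.pte` model format are assumptions about the ExecuTorch Java API; treat the DeepLabV3Demo linked above as the authoritative reference.

```java
// Minimal sketch (not taken from the demo app). Assumes the ExecuTorch AAR
// exposes Module/EValue/Tensor classes analogous to the old org.pytorch ones.
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Module;
import org.pytorch.executorch.Tensor;

public class ExecuTorchSmokeTest {
  public static float[] run(String modelPath) {
    // modelPath points to a .pte program exported with ExecuTorch's ahead-of-time flow.
    Module module = Module.load(modelPath);
    Tensor input =
        Tensor.fromBlob(new float[3 * 224 * 224], new long[] {1, 3, 224, 224});
    Tensor output = module.forward(EValue.from(input))[0].toTensor();
    return output.getDataAsFloatArray();
  }
}
```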
## Publishing
@@ -119,8 +121,6 @@ We also have to add all transitive dependencies of our aars.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(When using Maven dependencies, they are added automatically from `pom.xml`.)
You can check out the [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
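Concretely, a `build.gradle` that consumes local aars ends up with a dependency block roughly like the sketch below (mirroring the test app configuration; the aar names are whatever you copied into the `aars` folder):

```groovy
repositories {
    flatDir {
        dirs 'aars' // folder containing the locally built pytorch_android*.aar files
    }
}

dependencies {
    implementation(name: 'pytorch_android', ext: 'aar')
    implementation(name: 'pytorch_android_torchvision', ext: 'aar')
    // Transitive dependencies of the aars must be added explicitly.
    implementation 'com.facebook.soloader:nativeloader:0.10.5'
    implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
}
```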
## Linking to prebuilt libtorch library from gradle dependency
In some cases, you may want to use libtorch from your android native build.
@@ -202,7 +202,7 @@ find_library(FBJNI_LIBRARY fbjni
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY})
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY})
```
@@ -233,8 +233,6 @@ void loadAndForwardModel(const std::string& modelPath) {
To load a TorchScript model for mobile we need some special setup, which is placed in `struct JITCallGuard` in this example. It may change in the future; you can track the latest changes by keeping an eye on our [pytorch android jni code](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28).
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).

View File

@@ -1,30 +0,0 @@
#!/bin/bash
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
echo "PYTORCH_DIR:$PYTORCH_DIR"
source "$PYTORCH_ANDROID_DIR/common.sh"
check_android_sdk
check_gradle
parse_abis_list "$@"
build_android
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080
if [ "$CUSTOM_ABIS_LIST" = true ]; then
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi
find "$PYTORCH_ANDROID_DIR" -type f -name '*.apk'
find "$PYTORCH_ANDROID_DIR" -type f -name '*.apk' | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "

View File

@@ -1,32 +0,0 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"
source "${PYTORCH_ANDROID_DIR}/common.sh"
prepare_model_and_dump_root_ops() {
cd "${BUILD_ROOT}"
MODEL="${BUILD_ROOT}/MobileNetV2.pt"
ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}
# Start building
mkdir -p "${BUILD_ROOT}"
check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops
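# Build libtorch for Android including only the ops listed in the dumped YAML (selective build).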
SELECTED_OP_LIST="${ROOT_OPS}" build_android
# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease

View File

@@ -3,4 +3,3 @@ include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_ho
project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')
project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

View File

@@ -1,9 +0,0 @@
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.DS_Store
build
.externalNativeBuild

View File

@@ -1,38 +0,0 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME pytorch_testapp_jni)
project(${PROJECT_NAME} CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
)
add_library(${PROJECT_NAME} SHARED
${pytorch_testapp_SOURCES}
)
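# Headers and JNI libs come from the pytorch_android aar that the Gradle task
# extractAARForNativeBuild unpacks under the build directory.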
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
target_compile_options(${PROJECT_NAME} PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_include_directories(${PROJECT_NAME} PRIVATE
${PYTORCH_INCLUDE_DIRS}
)
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}
log)

View File

@@ -1,190 +0,0 @@
apply plugin: 'com.android.application'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileOptions {
sourceCompatibility 1.8
targetCompatibility 1.8
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
applicationId "org.pytorch.testapp"
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"
ndk {
abiFilters ABI_FILTERS.split(",")
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// abiFilters ABI_FILTERS.split(",")
// arguments "-DANDROID_STL=c++_shared"
// }
//}
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
buildConfigField(
"int",
"BUILD_LITE_INTERPRETER",
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
)
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// path "CMakeLists.txt"
// }
//}
flavorDimensions "model", "build", "activity"
productFlavors {
mnet {
dimension "model"
applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
// NB: This is not working at the moment, see https://github.com/pytorch/pytorch/issues/102966
mnetVulkan {
dimension "model"
applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
}
local {
dimension "build"
}
nightly {
dimension "build"
}
aar {
dimension "build"
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuild {
// dimension "build"
// buildConfigField("boolean", "NATIVE_BUILD", "true")
//}
camera {
dimension "activity"
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
}
base {
dimension "activity"
sourceSets {
main {
java {
exclude 'org/pytorch/testapp/CameraActivity.java'
}
}
}
}
}
packagingOptions {
doNotStrip '**.so'
}
// Filtering for CI
if (!testAppAllVariantsEnabled.toBoolean()) {
variantFilter { variant ->
def names = variant.flavors*.name
if (names.contains("nightly")
|| names.contains("camera")
|| names.contains("aar")
|| names.contains("nativeBuild")) {
setIgnore(true)
}
}
}
}
tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
}
dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
localImplementation project(':pytorch_android')
localImplementation project(':pytorch_android_torchvision')
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
def camerax_version = "1.0.0-alpha05"
cameraImplementation "androidx.camera:camera-core:$camerax_version"
cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
}
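// Unpacks headers/ and jni/ from the pytorch_android aar so the CMake native build can include and link against them.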
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}

View File

@@ -1,27 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.testapp">
<application
android:allowBackup="true"
android:label="${APP_NAME}"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name="${MAIN_ACTIVITY}">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.CAMERA" />
<!--
Permissions required by the Snapdragon Profiler to collect GPU metrics.
-->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
</manifest>

View File

@@ -1,3 +0,0 @@
*
*/
!.gitignore

View File

@@ -1,77 +0,0 @@
#include <android/log.h>
#include <pthread.h>
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <vector>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
#include "jni.h"
#include <torch/script.h>
namespace pytorch_testapp_jni {
namespace {
template <typename T>
void log(const char* m, T t) {
std::ostringstream os;
os << t << std::endl;
ALOGI("%s %s", m, os.str().c_str());
}
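// Guards required to run TorchScript on the mobile build: inference-only mode
// and the graph optimizer disabled (see the comment in loadAndForwardModel).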
struct JITCallGuard {
c10::InferenceMode guard;
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
assert(modelPath);
// To load a TorchScript model for mobile we need to set these guards, because
// the mobile build drops features like autograd to keep the binary small. The
// setup lives in `struct JITCallGuard` in this example. It may change in the
// future; you can track the latest changes by keeping an eye on
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor t = torch::randn({1, 3, 224, 224});
log("input tensor:", t);
c10::IValue t_out = module.forward({t});
log("output tensor:", t_out);
env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_testapp_jni
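// JNI entry point: registers loadAndForwardModel as the native method backing
// org.pytorch.testapp.LibtorchNativeClient$NativePeer.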
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
jclass c =
env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}
static const JNINativeMethod methods[] = {
{"loadAndForwardModel",
"(Ljava/lang/String;)V",
(void*)pytorch_testapp_jni::loadAndForwardModel},
};
int rc = env->RegisterNatives(
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
if (rc != JNI_OK) {
return rc;
}
return JNI_VERSION_1_6;
}

View File

@@ -1,214 +0,0 @@
package org.pytorch.testapp;
import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.view.TextureView;
import android.view.ViewStub;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
public class CameraActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};
private long mLastAnalysisResultTime;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
protected Handler mUIHandler;
private TextView mTextView;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_camera);
mTextView = findViewById(R.id.text);
mUIHandler = new Handler(getMainLooper());
startBackgroundThread();
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
} else {
setupCameraX();
}
}
@Override
protected void onPostCreate(@Nullable Bundle savedInstanceState) {
super.onPostCreate(savedInstanceState);
startBackgroundThread();
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread("ModuleActivity");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error on stopping background thread", e);
}
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
Toast.makeText(
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
.show();
finish();
} else {
setupCameraX();
}
}
}
private static final int TENSOR_WIDTH = 224;
private static final int TENSOR_HEIGHT = 224;
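// Sets up CameraX preview and image analysis; frames are analyzed at most once every 500 ms.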
private void setupCameraX() {
final TextureView textureView =
((ViewStub) findViewById(R.id.camera_texture_view_stub))
.inflate()
.findViewById(R.id.texture_view);
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
final Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
new Preview.OnPreviewOutputUpdateListener() {
@Override
public void onUpdated(Preview.PreviewOutput output) {
textureView.setSurfaceTexture(output.getSurfaceTexture());
}
});
final ImageAnalysisConfig imageAnalysisConfig =
new ImageAnalysisConfig.Builder()
.setTargetResolution(new Size(TENSOR_WIDTH, TENSOR_HEIGHT))
.setCallbackHandler(mBackgroundHandler)
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
imageAnalysis.setAnalyzer(
new ImageAnalysis.Analyzer() {
@Override
public void analyze(ImageProxy image, int rotationDegrees) {
if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
return;
}
final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);
if (result != null) {
mLastAnalysisResultTime = SystemClock.elapsedRealtime();
CameraActivity.this.runOnUiThread(
new Runnable() {
@Override
public void run() {
CameraActivity.this.handleResult(result);
}
});
}
}
});
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
@WorkerThread
@Nullable
protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
if (mModule == null) {
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT);
mInputTensor =
Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT});
}
final long startTime = SystemClock.elapsedRealtime();
TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
image.getImage(),
rotationDegrees,
TENSOR_WIDTH,
TENSOR_HEIGHT,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
mInputTensorBuffer,
0,
MemoryFormat.CHANNELS_LAST);
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
@UiThread
protected void handleResult(Result result) {
int ixs[] = Utils.topK(result.scores, 1);
String message =
String.format(
"forwardDuration:%d class:%s",
result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]);
Log.i(TAG, message);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@@ -1,22 +0,0 @@
package org.pytorch.testapp;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class LibtorchNativeClient {
public static void loadAndForwardModel(final String modelPath) {
NativePeer.loadAndForwardModel(modelPath);
}
private static class NativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_testapp_jni");
}
private static native void loadAndForwardModel(final String modelPath);
}
}

View File

@@ -1,171 +0,0 @@
package org.pytorch.testapp;
import android.content.Context;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
public class MainActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private TextView mTextView;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
private final Runnable mModuleForwardRunnable =
new Runnable() {
@Override
public void run() {
final Result result = doModuleForward();
runOnUiThread(
new Runnable() {
@Override
public void run() {
handleResult(result);
if (mBackgroundHandler != null) {
mBackgroundHandler.post(mModuleForwardRunnable);
}
}
});
}
};
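// Copies an asset out of the APK into the app's files directory and returns its
// absolute path, so it can be loaded from disk (e.g. by the native code path).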
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (BuildConfig.NATIVE_BUILD) {
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
LibtorchNativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
return;
}
setContentView(R.layout.activity_main);
mTextView = findViewById(R.id.text);
startBackgroundThread();
mBackgroundHandler.post(mModuleForwardRunnable);
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread(TAG + "_bg");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error stopping background thread", e);
}
}
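// Loads the module lazily on first call, then measures a single forward pass over the preallocated input tensor.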
@WorkerThread
@Nullable
protected Result doModuleForward() {
if (mModule == null) {
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
long numElements = 1;
for (int i = 0; i < shape.length; i++) {
numElements *= shape[i];
}
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor =
Tensor.fromBlob(
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
PyTorchAndroid.setNumThreads(1);
mModule =
BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}
final long startTime = SystemClock.elapsedRealtime();
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
static class Result {
private final float[] scores;
private final long totalDuration;
private final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}
@UiThread
protected void handleResult(Result result) {
String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@@ -1,14 +0,0 @@
package org.pytorch.testapp;
class Result {
public final float[] scores;
public final long totalDuration;
public final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}

View File

@@ -1,28 +0,0 @@
package org.pytorch.testapp;
import java.util.Arrays;
public class Utils {
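// Returns the indices of the topk largest values in a, ordered from largest to smallest.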
public static int[] topK(float[] a, final int topk) {
float values[] = new float[topk];
Arrays.fill(values, -Float.MAX_VALUE);
int ixs[] = new int[topk];
Arrays.fill(ixs, -1);
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < topk; j++) {
if (a[i] > values[j]) {
for (int k = topk - 1; k >= j + 1; k--) {
values[k] = values[k - 1];
ixs[k] = ixs[k - 1];
}
values[j] = a[i];
ixs[j] = i;
break;
}
}
}
return ixs;
}
}

View File

@@ -1,23 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".CameraActivity">
<ViewStub
android:id="@+id/camera_texture_view_stub"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout="@layout/texture_view"/>
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="16sp"
android:textStyle="bold"
android:textColor="#ff0000"/>
</FrameLayout>

View File

@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="14sp"
android:background="@android:color/black"
android:textColor="@android:color/white" />
</FrameLayout>

View File

@@ -1,5 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<TextureView xmlns:android="http://schemas.android.com/apk/res/android"
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="0dp" />

Binary file not shown (image, before: 2.0 KiB).

Binary file not shown (image, before: 2.7 KiB).

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@@ -1,3 +0,0 @@
<resources>
<string name="app_name">PyTest</string>
</resources>

View File

@@ -1,11 +0,0 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@@ -1,24 +0,0 @@
from torchvision import models
import torch
print(torch.version.__version__)
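# Trace each torchvision model with a dummy 1x3x224x224 input and save the TorchScript file into the app assets.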
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet18.eval()
torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet18.pt"
)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50.eval()
torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet50.pt"
)
mobilenet2q = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
mobilenet2q.eval()
torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/mobilenet2q.pt"
)

View File

@@ -1,27 +0,0 @@
"""
This is a script for PyTorch Android custom selective build test. It prepares
MobileNetV2 TorchScript model, and dumps root ops used by the model for custom
build script to create a tailored build which only contains these used ops.
"""
import yaml
from torchvision import models
import torch
# Download and trace the model.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.eval()
example = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced_script_module = torch.jit.trace(model, example)
# Save traced TorchScript model.
traced_script_module.save("MobileNetV2.pt")
# Dump root ops used by the model (for custom build optimization).
ops = torch.jit.export_opnames(traced_script_module)
with open("MobileNetV2.yaml", "w") as output:
yaml.dump(ops, output)