Revert "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)"

This reverts commit b22f01fcb9d69bb7d77e08d69004c7265ef7fa4a.

Reverted https://github.com/pytorch/pytorch/pull/153633 on behalf of https://github.com/malfet due to But libtorch build regressions are real, fbjni is still used for C++ builds ([comment](https://github.com/pytorch/pytorch/pull/153633#issuecomment-2884951805))
PyTorch MergeBot
2025-05-15 20:16:05 +00:00
parent b03e4f53d2
commit ae0e8f0c73
118 changed files with 7669 additions and 3 deletions

@@ -1,5 +1,240 @@
-# Android Demo App
-Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
-Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
+# Android
+## Demo applications and tutorials
+Demo applications with code walk-through can be found in [this GitHub repo](https://github.com/pytorch/android-demo-app).
## Publishing
##### Release
Release artifacts are published to jcenter:

```groovy
repositories {
    jcenter()
}

// lite interpreter build
dependencies {
    implementation 'org.pytorch:pytorch_android_lite:1.10.0'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
}

// full jit build
dependencies {
    implementation 'org.pytorch:pytorch_android:1.10.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
```
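
For orientation, the two flavors differ mainly at the model-loading call site. Below is a minimal Java sketch (the class name and model paths are illustrative, not part of the published docs): `pytorch_android_lite` loads models saved with `_save_for_lite_interpreter()`, while the full JIT artifact loads regular TorchScript files.

```java
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

// Hypothetical example class; model paths are placeholders.
public class LoadFlavors {
    // pytorch_android_lite: load a model saved with _save_for_lite_interpreter().
    static float[] runLite(String modelPath) {
        Module module = LiteModuleLoader.load(modelPath);
        Tensor input = Tensor.fromBlob(
                new float[3 * 224 * 224], new long[] {1, 3, 224, 224});
        return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
    }

    // pytorch_android (full JIT): load a regular TorchScript model.
    static float[] runFullJit(String modelPath) {
        Module module = Module.load(modelPath);
        Tensor input = Tensor.fromBlob(
                new float[3 * 224 * 224], new long[] {1, 3, 224, 224});
        return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
    }
}
```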
##### Nightly
Nightly (snapshot) builds are published every night from the `master` branch to the [nexus sonatype snapshots repository](https://oss.sonatype.org/#nexus-search;quick~pytorch_android).
To use them, the repository must be specified explicitly:
```groovy
repositories {
    maven {
        url "https://oss.sonatype.org/content/repositories/snapshots"
    }
}

// lite interpreter build
dependencies {
    ...
    implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
    ...
}

// full jit build
dependencies {
    ...
    implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
    implementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
    ...
}
```
The current nightly (snapshot) version is the value of `VERSION_NAME` in `gradle.properties` in the current folder; at this moment it is `1.8.0-SNAPSHOT`.
## Building PyTorch Android from Source
In some cases you might want to use a local build of PyTorch Android, for example to build a custom libtorch binary with a different set of operators or to test local changes.
For this you can use the `./scripts/build_pytorch_android.sh` script.
```bash
git clone https://github.com/pytorch/pytorch.git
cd pytorch
git submodule update --init --recursive
bash ./scripts/build_pytorch_android.sh
```
The workflow contains several steps:

1. Build libtorch for Android for all four Android ABIs (armeabi-v7a, arm64-v8a, x86, x86_64).
2. Create symbolic links to the results of those builds:
   `android/pytorch_android/src/main/jniLibs/${abi}` to the directory with the output libraries,
   `android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with the headers. These directories are used to build the `libpytorch.so` library that will be loaded on the Android device.
3. Finally, run `gradle` in the `android/pytorch_android` directory with the task `assembleRelease`.

The script requires that the Android SDK, Android NDK, and Gradle are installed. They are specified as environment variables:

- `ANDROID_HOME` - path to the [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
- `ANDROID_NDK` - path to the [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
- `GRADLE_HOME` - path to [Gradle](https://gradle.org/releases/)
After a successful build you should see the resulting AAR files:

```bash
$ find pytorch_android/build/ -type f -name '*.aar'
pytorch_android/build/outputs/aar/pytorch_android.aar
pytorch_android_torchvision/build/outputs/aar/pytorch_android_torchvision.aar
```
They can be used directly in Android projects as a Gradle dependency:

```groovy
allprojects {
    repositories {
        flatDir {
            dirs 'libs'
        }
    }
}

dependencies {
    implementation(name: 'pytorch_android', ext: 'aar')
    implementation(name: 'pytorch_android_torchvision', ext: 'aar')
    ...
    implementation 'com.facebook.soloader:nativeloader:0.10.5'
    implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
}
```
We also have to add all transitive dependencies of our AARs.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(When using Maven dependencies, they are added automatically from `pom.xml`.)
You can check out the [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses the AARs directly.
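
Models are usually shipped as app assets, while `Module.load` takes a filesystem path, so the asset is first copied into the app's files directory. The instrumented tests restored later in this commit use exactly this pattern; here is a condensed sketch (the helper class name is illustrative):

```java
import android.content.Context;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Illustrative helper: copy an asset to a readable file path once,
// then reuse that file on subsequent calls.
public class AssetPaths {
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }
        try (InputStream is = context.getAssets().open(assetName);
             OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
            os.flush();
        }
        return file.getAbsolutePath();
    }
}
```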
## Linking to prebuilt libtorch library from gradle dependency
In some cases, you may want to use libtorch from your Android native build.
You can do so without building libtorch for Android yourself, by using the native libraries from the PyTorch Android Gradle dependency.
For that, you will need to add the following lines to your Gradle build:
```groovy
android {
    ...
    configurations {
        extractForNativeBuild
    }
    ...
    compileOptions {
        externalNativeBuild {
            cmake {
                arguments "-DANDROID_STL=c++_shared"
            }
        }
    }
    ...
    externalNativeBuild {
        cmake {
            path "CMakeLists.txt"
        }
    }
}

dependencies {
    extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
}

task extractAARForNativeBuild {
    doLast {
        configurations.extractForNativeBuild.files.each {
            def file = it.absoluteFile
            copy {
                from zipTree(file)
                into "$buildDir/$file.name"
                include "headers/**"
                include "jni/**"
            }
        }
    }
}

tasks.whenTaskAdded { task ->
    if (task.name.contains('externalNativeBuild')) {
        task.dependsOn(extractAARForNativeBuild)
    }
}
```
The pytorch_android AAR contains headers to link against in the `headers` folder and native libraries in `jni/$ANDROID_ABI/`.
As the PyTorch native libraries use `ANDROID_STL`, we should build with `ANDROID_STL=c++_shared` so that only one copy of the STL is loaded.
The added task will unpack them into the Gradle build directory.
In your native build you can link to them by adding these lines to your CMakeLists.txt:
```cmake
# Relative path of gradle build directory to CMakeLists.txt
set(build_DIR ${CMAKE_SOURCE_DIR}/build)

file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")

set(BUILD_SUBDIR ${ANDROID_ABI})

target_include_directories(${PROJECT_NAME} PRIVATE
    ${PYTORCH_INCLUDE_DIRS}
)

find_library(PYTORCH_LIBRARY pytorch_jni
    PATHS ${PYTORCH_LINK_DIRS}
    NO_CMAKE_FIND_ROOT_PATH)

find_library(FBJNI_LIBRARY fbjni
    PATHS ${PYTORCH_LINK_DIRS}
    NO_CMAKE_FIND_ROOT_PATH)

target_link_libraries(${PROJECT_NAME}
    ${PYTORCH_LIBRARY}
    ${FBJNI_LIBRARY})
```
If your CMakeLists.txt file is located in the same directory as your build.gradle, `set(build_DIR ${CMAKE_SOURCE_DIR}/build)` should work for you. If it lives elsewhere, you may need to adjust that path.
After that, you can use the libtorch C++ API from your native code.
```cpp
#include <string>

#include <ATen/NativeFunctions.h>
#include <torch/script.h>

namespace pytorch_testapp_jni {
namespace {

struct JITCallGuard {
  c10::InferenceMode guard;
  torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};

} // namespace

void loadAndForwardModel(const std::string& modelPath) {
  JITCallGuard guard;
  torch::jit::Module module = torch::jit::load(modelPath);
  module.eval();
  torch::Tensor t = torch::randn({1, 3, 224, 224});
  c10::IValue t_out = module.forward({t});
}

} // namespace pytorch_testapp_jni
```
To load a TorchScript model for mobile, we need some special setup, which is placed in `struct JITCallGuard` in this example. It may change in the future; you can track the latest changes by keeping an eye on our [pytorch android jni code](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28).
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
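
On the Java side, the app then loads its own native library and declares the entry point. A hypothetical sketch matching the `loadAndForwardModel` example above (the library and class names are assumptions, and the C++ function would still need a JNI registration or export to be callable):

```java
// Hypothetical Java-side bridge for the C++ example above.
public class PytorchNativeBridge {
    static {
        // Loads libpytorch_testapp_jni.so, which links against
        // libpytorch_jni.so and libfbjni.so unpacked from the AAR.
        System.loadLibrary("pytorch_testapp_jni");
    }

    // Must be bound to the native code via JNI; illustrative only.
    public static native void loadAndForwardModel(String modelPath);
}
```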
## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
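
As a quick orientation before diving into the Javadoc, here is a small sketch of the core `Tensor` entry points exercised by the tests restored in this commit (the class name is illustrative):

```java
import org.pytorch.DType;
import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;

// Illustrative sketch of basic Tensor construction and introspection.
public class TensorBasics {
    public static void main(String[] args) {
        long[] shape = {1, 3, 2, 2};
        float[] data = new float[(int) Tensor.numel(shape)];

        // Contiguous (NCHW) by default; an explicit MemoryFormat is optional.
        Tensor nchw = Tensor.fromBlob(data, shape);
        Tensor nhwc = Tensor.fromBlob(data, shape, MemoryFormat.CHANNELS_LAST);

        assert nchw.dtype() == DType.FLOAT32;
        assert nhwc.memoryFormat() == MemoryFormat.CHANNELS_LAST;
        assert nchw.getDataAsFloatArray().length == Tensor.numel(shape);
    }
}
```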

android/build.gradle Normal file
@@ -0,0 +1,40 @@
allprojects {
    buildscript {
        ext {
            minSdkVersion = 21
            targetSdkVersion = 28
            compileSdkVersion = 28
            buildToolsVersion = '28.0.3'

            coreVersion = "1.2.0"
            extJUnitVersion = "1.1.1"
            runnerVersion = "1.2.0"
            rulesVersion = "1.2.0"
            junitVersion = "4.12"

            fbjniJavaOnlyVersion = "0.2.2"
            soLoaderNativeLoaderVersion = "0.10.5"
        }

        repositories {
            google()
            mavenLocal()
            mavenCentral()
            jcenter()
        }

        dependencies {
            classpath 'com.android.tools.build:gradle:4.1.2'
            classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2'
        }
    }

    repositories {
        google()
        jcenter()
    }
}

ext.deps = [
    jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
]

android/build_test_app.sh Executable file
@@ -0,0 +1,30 @@
#!/bin/bash
set -eux

PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android

echo "PYTORCH_DIR:$PYTORCH_DIR"

source "$PYTORCH_ANDROID_DIR/common.sh"

check_android_sdk
check_gradle
parse_abis_list "$@"
build_android

# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080

if [ "$CUSTOM_ABIS_LIST" = true ]; then
  NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
  NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi

find $PYTORCH_ANDROID_DIR -type f -name '*.apk'
find $PYTORCH_ANDROID_DIR -type f -name '*.apk' | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "

@@ -0,0 +1,32 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################

set -eux

PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"

source "${PYTORCH_ANDROID_DIR}/common.sh"

prepare_model_and_dump_root_ops() {
  cd "${BUILD_ROOT}"
  MODEL="${BUILD_ROOT}/MobileNetV2.pt"
  ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"

  python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
  cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}

# Start building
mkdir -p "${BUILD_ROOT}"

check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops

SELECTED_OP_LIST="${ROOT_OPS}" build_android

# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease

android/common.sh Normal file
@@ -0,0 +1,74 @@
#!/bin/bash
set -eux

##############################################################################
# Common util functions for Android build scripts.
##############################################################################

if [ -z "$PYTORCH_DIR" ]; then
  echo "PYTORCH_DIR not set!"
  exit 1
fi

retry () {
  "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
}

check_android_sdk() {
  if [ -z "$ANDROID_HOME" ]; then
    echo "ANDROID_HOME not set; please set it to Android sdk directory"
    exit 1
  fi

  if [ ! -d "$ANDROID_HOME" ]; then
    echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
    exit 1
  fi
  echo "ANDROID_HOME:$ANDROID_HOME"
}

check_gradle() {
  GRADLE_PATH=$PYTORCH_DIR/android/gradlew
  echo "GRADLE_PATH:$GRADLE_PATH"
}

parse_abis_list() {
  # sync with https://github.com/pytorch/pytorch/blob/0ca0e02685a9d033ac4f04e2fa5c8ba6dbc5ae50/android/gradle.properties#L1
  ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
  CUSTOM_ABIS_LIST=false
  if [ $# -gt 0 ]; then
    ABIS_LIST=$1
    CUSTOM_ABIS_LIST=true
  fi

  echo "ABIS_LIST:$ABIS_LIST"
  echo "CUSTOM_ABIS_LIST:$CUSTOM_ABIS_LIST"
}

build_android() {
  PYTORCH_ANDROID_DIR="$PYTORCH_DIR/android"
  BUILD_ROOT="${BUILD_ROOT:-$PYTORCH_DIR}"
  echo "BUILD_ROOT:$BUILD_ROOT"

  LIB_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs"
  INCLUDE_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include"

  # These directories only contain symbolic links.
  rm -rf "$LIB_DIR" && mkdir -p "$LIB_DIR"
  rm -rf "$INCLUDE_DIR" && mkdir -p "$INCLUDE_DIR"

  for abi in $(echo "$ABIS_LIST" | tr ',' '\n')
  do
    echo "abi:$abi"

    ANDROID_BUILD_ROOT="$BUILD_ROOT/build_android_$abi"

    ANDROID_ABI="$abi" \
      BUILD_ROOT="$ANDROID_BUILD_ROOT" \
      "$PYTORCH_DIR/scripts/build_android.sh" \
      -DANDROID_CCACHE="$(which ccache)" \
      -DUSE_LITE_INTERPRETER_PROFILER="OFF"

    echo "$abi build output lib,include at $ANDROID_BUILD_ROOT/install"
    ln -s "$ANDROID_BUILD_ROOT/install/lib" "$LIB_DIR/$abi"
    ln -s "$ANDROID_BUILD_ROOT/install/include" "$INCLUDE_DIR/$abi"
  done
}

android/gradle.properties Normal file
@@ -0,0 +1,26 @@
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
VERSION_NAME=2.2.0-SNAPSHOT
GROUP=org.pytorch
MAVEN_GROUP=org.pytorch
SONATYPE_STAGING_PROFILE=orgpytorch
POM_URL=https://github.com/pytorch/pytorch/tree/master/android
POM_SCM_URL=https://github.com/pytorch/pytorch.git
POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/pytorch
POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/pytorch.git
POM_LICENSE_NAME=BSD 3-Clause
POM_LICENSE_URL=https://github.com/pytorch/pytorch/blob/master/LICENSE
POM_ISSUES_URL=https://github.com/pytorch/pytorch/issues
POM_LICENSE_DIST=repo
POM_DEVELOPER_ID=pytorch
POM_DEVELOPER_NAME=pytorch
# Gradle internals
org.gradle.internal.repository.max.retries=1
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m
android.useAndroidX=true
android.enableJetifier=true
nativeLibsDoNotStrip=false
testAppAllVariantsEnabled=false

@@ -0,0 +1,11 @@
afterEvaluate { project ->
    if (POM_PACKAGING == 'aar') {
        task headersJar(type: Jar) {
            archiveClassifier.set('headers')
            from("$rootDir/cxx/") {
                include '**/*.h'
            }
        }
        artifacts.add('archives', headersJar)
    }
}

@@ -0,0 +1,3 @@
apply from: rootProject.file('gradle/android_tasks.gradle')
apply plugin: 'com.vanniktech.maven.publish'

Binary file not shown.

@@ -0,0 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

android/gradlew vendored Executable file
@@ -0,0 +1,172 @@
#!/usr/bin/env sh
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=$((i+1))
done
case $i in
(0) set -- ;;
(1) set -- "$args0" ;;
(2) set -- "$args0" "$args1" ;;
(3) set -- "$args0" "$args1" "$args2" ;;
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=$(save "$@")
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
cd "$(dirname "$0")"
fi
exec "$JAVACMD" "$@"

android/gradlew.bat vendored Normal file
@@ -0,0 +1,84 @@
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:init
@rem Get command-line arguments, handling Windows variants
if not "%OS%" == "Windows_NT" goto win9xME_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

android/libs/fbjni Submodule

Submodule android/libs/fbjni added at 7e1e1fe385

@@ -0,0 +1,184 @@
cmake_minimum_required(VERSION 3.5)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
message(
STATUS
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")
if(BUILD_LITE_INTERPRETER)
project(pytorch_jni_lite CXX)
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
else()
project(pytorch_jni CXX)
set(PYTORCH_JNI_TARGET pytorch_jni)
endif()
include(GNUInstallDirs)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
set(TRACE_ENABLED OFF)
if(DEFINED ENV{TRACE_ENABLED})
if($ENV{TRACE_ENABLED} STREQUAL "1")
message(STATUS "TRACE_ENABLED ON")
set(TRACE_ENABLED ON)
endif()
endif()
if(NOT TRACE_ENABLED)
message(STATUS "TRACE_ENABLED OFF")
endif()
set(USE_VULKAN OFF)
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
if(ANDROID_ABI)
set(USE_VULKAN ON)
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})
set(BUILD_SUBDIR ${ANDROID_ABI})
elseif(BUILD_LIBTORCH_WITH_JNI)
# Don't need LIBTORCH_HOME if we're building from within PyTorch.
else()
# Building against a pre-built libtorch.
if(NOT LIBTORCH_HOME)
message(FATAL_ERROR
"pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.")
endif()
set(libtorch_include_DIR ${LIBTORCH_HOME}/include)
link_directories(${LIBTORCH_HOME}/lib)
set(BUILD_SUBDIR host)
endif()
message(STATUS "libtorch dir:${libtorch_DIR}")
configure_file(
${pytorch_android_DIR}/cmake_macros.h.in
${pytorch_android_DIR}/cmake_macros.h)
if(BUILD_LITE_INTERPRETER)
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_lite.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
else()
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
endif()
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})
if(APPLE)
# Need to add rpath so dlopen can find dependencies.
add_custom_command(TARGET pytorch_jni
POST_BUILD COMMAND
${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path"
$<TARGET_FILE:pytorch_jni>)
endif()
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
-fexceptions
)
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})
# ---[ Vulkan deps
if(USE_VULKAN)
set(Vulkan_LIBS)
set(Vulkan_INCLUDES)
include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake)
endif()
if(ANDROID_ABI)
function(import_static_lib name)
add_library(${name} STATIC IMPORTED)
set_property(
TARGET ${name}
PROPERTY IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a)
endfunction(import_static_lib)
import_static_lib(libtorch)
import_static_lib(libtorch_cpu)
import_static_lib(libc10)
import_static_lib(libnnpack)
import_static_lib(libXNNPACK)
import_static_lib(libmicrokernels-prod)
import_static_lib(libpytorch_qnnpack)
import_static_lib(libpthreadpool)
import_static_lib(libeigen_blas)
import_static_lib(libcpuinfo)
import_static_lib(libclog)
# Link most things statically on Android.
set(pytorch_jni_LIBS
fbjni
-Wl,--gc-sections
-Wl,--whole-archive
libtorch
libtorch_cpu
-Wl,--no-whole-archive
libc10
libnnpack
libXNNPACK
libmicrokernels-prod
libpytorch_qnnpack
libpthreadpool
libeigen_blas
libcpuinfo
libclog
)
else()
# Prefer dynamic linking on the host
set(pytorch_jni_LIBS
fbjni
torch
torch_cpu
c10
cpuinfo
)
if(USE_NNPACK)
list(APPEND pytorch_jni_LIBS nnpack)
endif()
if(USE_XNNPACK)
list(APPEND pytorch_jni_LIBS XNNPACK)
list(APPEND pytorch_jni_LIBS microkernels-prod)
endif()
if(USE_SYSTEM_PTHREADPOOL)
list(APPEND pytorch_jni_LIBS pthreadpool)
endif()
if(USE_PYTORCH_QNNPACK)
list(APPEND pytorch_jni_LIBS pytorch_qnnpack)
list(APPEND pytorch_jni_LIBS clog)
endif()
endif()
if(USE_VULKAN)
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
endif()
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})
install(TARGETS ${PYTORCH_JNI_TARGET}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows
if(MSVC)
install(FILES $<TARGET_PDB_FILE:pytorch_jni> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()

@@ -0,0 +1,163 @@
apply plugin: 'com.android.library'
apply plugin: 'maven'
android {
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 0
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
externalNativeBuild {
cmake {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
} else {
arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
}
}
}
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
sourceSets {
main {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build pytorch_jni'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
} else {
println 'Build pytorch_jni_lite'
}
}
jniLibs.srcDirs = ['src/main/jniLibs']
manifest.srcFile 'src/main/AndroidManifest.xml'
}
androidTest {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build test for full jit (pytorch_jni)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
} else {
println 'Build test for lite interpreter (pytorch_jni_lite)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
}
}
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
packagingOptions {
if (nativeLibsDoNotStrip.toBoolean()) {
doNotStrip "**/*.so"
logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included')
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}
dependencies {
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
}
apply from: rootProject.file('gradle/release.gradle')
task sourcesJar(type: Jar) {
from android.sourceSets.main.java.srcDirs
classifier = 'sources'
}
def getLibtorchHeadersDir() {
def abi = ABI_FILTERS.split(",")[0]
return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi"
}
afterEvaluate {
if (POM_PACKAGING == 'aar') {
android.libraryVariants.all { variant ->
variant.outputs.each { output ->
File f = output.outputFile
if (f.name.endsWith(".aar")) {
output.assemble.finalizedBy addFolderToAarTask(
"addHeadersToAar" + variant.name,
f.path,
getLibtorchHeadersDir(),
"headers")
}
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) {
doLast {
addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers')
}
}
}
def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) {
return tasks.register(taskName) {
doLast {
addFolderToAar(taskName, aarPath, folderPath, folderPathInAar)
}
}
}
def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) {
def tmpDir = file("${buildDir}/${taskName}")
tmpDir.mkdir()
def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}")
tmpDirFolder.mkdir()
copy {
from zipTree(aarPath)
into tmpDir
}
copy {
from fileTree(folderPath)
into tmpDirFolder
}
ant.zip(destfile: aarPath) {
fileset(dir: tmpDir.path)
}
delete tmpDir
}
artifacts.add('archives', sourcesJar)

@@ -0,0 +1,20 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/jit.h>
#include <torch/script.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

int main(int argc, char* argv[]) {
  if (argc < 3) {
    std::cerr << "Usage: " << argv[0] << " <input_script> <output_model>" << std::endl;
    return 1;
  }
  std::string input_file_path{argv[1]};
  std::string output_file_path{argv[2]};

  std::ifstream ifs(input_file_path);
  std::stringstream buffer;
  buffer << ifs.rdbuf();

  torch::jit::Module m("TestModule");
  m.define(buffer.str());
  m.save(output_file_path);
}

@@ -0,0 +1,151 @@
from typing import Optional

import torch
from torch import Tensor

OUTPUT_DIR = "src/androidTest/assets/"


def scriptAndSave(module, fileName):
    print("-" * 80)
    script_module = torch.jit.script(module)
    print(script_module.graph)
    outputFileName = OUTPUT_DIR + fileName
    # note that the lite interpreter model can also be used in full JIT
    script_module._save_for_lite_interpreter(outputFileName)
    print("Saved to " + outputFileName)
    print("=" * 80)


class Test(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        return None

    @torch.jit.script_method
    def eqBool(self, input: bool) -> bool:
        return input

    @torch.jit.script_method
    def eqInt(self, input: int) -> int:
        return input

    @torch.jit.script_method
    def eqFloat(self, input: float) -> float:
        return input

    @torch.jit.script_method
    def eqStr(self, input: str) -> str:
        return input

    @torch.jit.script_method
    def eqTensor(self, input: Tensor) -> Tensor:
        return input

    @torch.jit.script_method
    def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
        return input

    @torch.jit.script_method
    def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
        return input

    @torch.jit.script_method
    def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
        return input

    @torch.jit.script_method
    def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def listBoolConjunction(self, input: list[bool]) -> bool:
        res = True
        for x in input:
            res = res and x
        return res

    @torch.jit.script_method
    def listBoolDisjunction(self, input: list[bool]) -> bool:
        res = False
        for x in input:
            res = res or x
        return res

    @torch.jit.script_method
    def tupleIntSumReturnTuple(
        self, input: tuple[int, int, int]
    ) -> tuple[tuple[int, int, int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
        return input is None

    @torch.jit.script_method
    def intEq0None(self, input: int) -> Optional[int]:
        if input == 0:
            return None
        return input

    @torch.jit.script_method
    def str3Concat(self, input: str) -> str:
        return input + input + input

    @torch.jit.script_method
    def newEmptyShapeWithItem(self, input):
        return torch.tensor([int(input.item())])[0]

    @torch.jit.script_method
    def testAliasWithOffset(self) -> list[Tensor]:
        x = torch.tensor([100, 200])
        a = [x[0], x[1]]
        return a

    @torch.jit.script_method
    def testNonContiguous(self):
        x = torch.tensor([100, 200, 300])[::2]
        assert not x.is_contiguous()
        assert x[0] == 100
        assert x[1] == 300
        return x

    @torch.jit.script_method
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv2d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last_3d)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        return x.contiguous()

    @torch.jit.script_method
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last)

    @torch.jit.script_method
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last_3d)


scriptAndSave(Test(), "test.pt")
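
For context, the module saved above is what the Java tests in this commit load and drive through `runMethod`; a minimal host-side sketch (the class name and relative path are illustrative):

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

// Illustrative sketch: exercising the generated test.pt from plain Java.
// Per the comment in scriptAndSave, the lite-interpreter file also loads
// with the full JIT Module.load.
public class GeneratedModuleSmoke {
    public static void main(String[] args) {
        Module module = Module.load("src/androidTest/assets/test.pt");

        // eqInt echoes its argument back.
        assert module.runMethod("eqInt", IValue.from(42L)).toLong() == 42L;

        // forward(input) is scripted to return None for any tensor input.
        IValue input = IValue.from(Tensor.fromBlob(new long[] {1}, new long[] {1}));
        assert module.forward(input).isNull();
    }
}
```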

@@ -0,0 +1,4 @@
POM_NAME=pytorch_android_lite pytorch android api
POM_DESCRIPTION=pytorch_android_lite pytorch android api
POM_ARTIFACT_ID=pytorch_android_lite
POM_PACKAGING=aar

@@ -0,0 +1,42 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the Apache-2 license found in the
// LICENSE file in the root directory of this source tree.

plugins {
    id 'java-library'
}

repositories {
    mavenLocal()
    jcenter()
}

sourceSets {
    main {
        java {
            srcDir '../src/main/java'
            exclude 'org/pytorch/PyTorchAndroid.java'
            exclude 'org/pytorch/LitePyTorchAndroid.java'
            exclude 'org/pytorch/LiteModuleLoader.java'
            exclude 'org/pytorch/LiteNativePeer.java'
        }
    }
    test {
        java {
            srcDir '../src/androidTest/java'
            exclude '**/PytorchInstrumented*'
            exclude '**/PytorchLiteInstrumented*'
        }
        resources.srcDirs = ["../src/androidTest/assets"]
    }
}

dependencies {
    compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
    implementation 'com.facebook.soloader:nativeloader:0.10.1'
    implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
    testImplementation 'junit:junit:4.12'
}

apply from: rootProject.file('gradle/release.gradle')

@@ -0,0 +1,4 @@
POM_NAME=pytorch_java_only pytorch java api
POM_DESCRIPTION=pytorch_java_only pytorch java api
POM_ARTIFACT_ID=pytorch_java_only
POM_PACKAGING=jar

Binary file not shown.

@@ -0,0 +1,21 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <gtest/gtest.h>

#include <ATen/core/type_factory.h>
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"

using namespace ::testing;

TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
  auto dict = c10::impl::GenericDict(
      c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
  auto dictCallback = [](auto&&) {
    return facebook::jni::local_ref<pytorch_jni::JIValue>{};
  };
  EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
      dict, dictCallback, dictCallback));
}

@@ -0,0 +1,25 @@
package org.pytorch;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Objects;
public class PytorchHostTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
Path tempFile = Files.createTempFile("test", ".pt");
try (InputStream resource =
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
}
return tempFile.toAbsolutePath().toString();
}
}

@@ -0,0 +1,42 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
throw e;
}
}
}

@@ -0,0 +1,42 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return LiteModuleLoader.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
throw e;
}
}
}

@@ -0,0 +1,694 @@
package org.pytorch;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.Ignore;
public abstract class PytorchTestBase {
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
@Test
public void testForwardNull() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
}
@Test
public void testEqBool() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.toBool());
}
}
@Test
public void testEqInt() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.toLong());
}
}
@Test
public void testEqFloat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[] {
-Double.MAX_VALUE,
Double.MAX_VALUE,
-Double.MIN_VALUE,
Double.MIN_VALUE,
-Math.exp(1.d),
-Math.sqrt(2.d),
-3.1415f,
3.1415f,
-1,
0,
1,
};
for (double value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.toDouble());
}
}
@Test
public void testEqTensor() throws IOException {
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
}
}
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[] {0, 1, 128}) {
long[] a = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.toTuple().length);
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@Test
public void testEqString() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.toStr()));
}
}
@Test
public void testStr3Concat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput =
new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.toStr()));
}
}
@Test
public void testEmptyShape() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {}, value.shape());
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
}
@Test
public void testAliasWithOffset() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
}
@Test
public void testNonContiguous() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {2}, value.shape());
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
}
@Test
public void testChannelsLast() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
}
@Test
public void testChannelsLast3d() throws IOException {
long[] shape = new long[] {1, 2, 2, 2, 2};
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
final IValue outputNHWDC =
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
}
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] dataNCHW = new long[] {
111, 112,
121, 122,
211, 212,
221, 222,
311, 312,
321, 322};
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNHWC = new long[] {
111, 211, 311, 112, 212, 312,
121, 221, 321, 122, 222, 322};
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[] {3, 3, 1, 1};
long[] dataWeightOIHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightOHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 2*112,
2*121, 2*122,
211, 212,
221, 222,
-311, -312,
-321, -322});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
assertIValueTensor(
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 211, -311, 2*112, 212, -312,
2*121, 221, -321, 2*122, 222, -322});
}
@Test
public void testChannelsLastConv3d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2, 2};
long[] dataNCDHW = new long[] {
1111, 1112,
1121, 1122,
1211, 1212,
1221, 1222,
2111, 2112,
2121, 2122,
2211, 2212,
2221, 2222,
3111, 3112,
3121, 3122,
3211, 3212,
3221, 3222};
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNDHWC = new long[] {
1111, 2111, 3111,
1112, 2112, 3112,
1121, 2121, 3121,
1122, 2122, 3122,
1211, 2211, 3211,
1212, 2212, 3212,
1221, 2221, 3221,
1222, 2222, 3222};
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
long[] weightShape = new long[] {3, 3, 1, 1, 1};
long[] dataWeightOIDHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightODHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCDHW =
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
assertIValueTensor(
outputNCDHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2*1112, 2*1121, 2*1122,
2*1211, 2*1212, 2*1221, 2*1222,
2111, 2112, 2121, 2122,
2211, 2212, 2221, 2222,
-3111, -3112, -3121, -3122,
-3211, -3212, -3221, -3222});
final IValue outputNDHWC =
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
assertIValueTensor(
outputNDHWC,
MemoryFormat.CHANNELS_LAST_3D,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2111, -3111, 2*1112, 2112, -3112,
2*1121, 2121, -3121, 2*1122, 2122, -3122,
2*1211, 2211, -3211, 2*1212, 2212, -3212,
2*1221, 2221, -3221, 2*1222, 2222, -3222});
}
@Test
public void testMobileNetV2() throws IOException {
try {
final Module module = loadModel("mobilenet_v2.ptl");
final IValue inputs = module.runMethod("get_all_bundled_inputs");
assertTrue(inputs.isList());
final IValue input = inputs.toList()[0];
assertTrue(input.isTuple());
module.forward(input.toTuple()[0]);
assertTrue(true);
} catch (Exception ex) {
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
}
}
@Test
public void testPointwiseOps() throws IOException {
runModel("pointwise_ops");
}
@Test
public void testReductionOps() throws IOException {
runModel("reduction_ops");
}
@Test
public void testComparisonOps() throws IOException {
runModel("comparison_ops");
}
@Test
public void testOtherMathOps() throws IOException {
runModel("other_math_ops");
}
@Test
@Ignore
public void testSpectralOps() throws IOException {
// NB: This model fails without lite interpreter. The error is as follows:
// RuntimeError: stft requires the return_complex parameter be given for real inputs
runModel("spectral_ops");
}
@Test
public void testBlasLapackOps() throws IOException {
runModel("blas_lapack_ops");
}
@Test
public void testSamplingOps() throws IOException {
runModel("sampling_ops");
}
@Test
public void testTensorOps() throws IOException {
runModel("tensor_general_ops");
}
@Test
public void testTensorCreationOps() throws IOException {
runModel("tensor_creation_ops");
}
@Test
public void testTensorIndexingOps() throws IOException {
runModel("tensor_indexing_ops");
}
@Test
public void testTensorTypingOps() throws IOException {
runModel("tensor_typing_ops");
}
@Test
public void testTensorViewOps() throws IOException {
runModel("tensor_view_ops");
}
@Test
public void testConvolutionOps() throws IOException {
runModel("convolution_ops");
}
@Test
public void testPoolingOps() throws IOException {
runModel("pooling_ops");
}
@Test
public void testPaddingOps() throws IOException {
runModel("padding_ops");
}
@Test
public void testActivationOps() throws IOException {
runModel("activation_ops");
}
@Test
public void testNormalizationOps() throws IOException {
runModel("normalization_ops");
}
@Test
public void testRecurrentOps() throws IOException {
runModel("recurrent_ops");
}
@Test
public void testTransformerOps() throws IOException {
runModel("transformer_ops");
}
@Test
public void testLinearOps() throws IOException {
runModel("linear_ops");
}
@Test
public void testDropoutOps() throws IOException {
runModel("dropout_ops");
}
@Test
public void testSparseOps() throws IOException {
runModel("sparse_ops");
}
@Test
public void testDistanceFunctionOps() throws IOException {
runModel("distance_function_ops");
}
@Test
public void testLossFunctionOps() throws IOException {
runModel("loss_function_ops");
}
@Test
public void testVisionFunctionOps() throws IOException {
runModel("vision_function_ops");
}
@Test
public void testShuffleOps() throws IOException {
runModel("shuffle_ops");
}
@Test
public void testNNUtilsOps() throws IOException {
runModel("nn_utils_ops");
}
@Test
public void testQuantOps() throws IOException {
runModel("general_quant_ops");
}
@Test
public void testDynamicQuantOps() throws IOException {
runModel("dynamic_quant_ops");
}
@Test
public void testStaticQuantOps() throws IOException {
runModel("static_quant_ops");
}
@Test
public void testFusedQuantOps() throws IOException {
runModel("fused_quant_ops");
}
@Test
public void testTorchScriptBuiltinQuantOps() throws IOException {
runModel("torchscript_builtin_ops");
}
@Test
public void testTorchScriptCollectionQuantOps() throws IOException {
runModel("torchscript_collection_ops");
}
static void assertIValueTensor(
final IValue ivalue,
final MemoryFormat memoryFormat,
final long[] expectedShape,
final long[] expectedData) {
assertTrue(ivalue.isTensor());
Tensor t = ivalue.toTensor();
assertEquals(memoryFormat, t.memoryFormat());
assertArrayEquals(expectedShape, t.shape());
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
void runModel(final String name) throws IOException {
final Module storage_module = loadModel(name + ".ptl");
storage_module.forward();
// TODO enable this once the on-the-fly script is ready
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
// on_the_fly_module.forward();
assertTrue(true);
}
protected abstract Module loadModel(String assetName) throws IOException;
}
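
For orientation, the only abstract piece above is `loadModel`. Below is a minimal sketch of a concrete instrumented-test subclass, assuming the `.ptl` files ship as Android assets and that AndroidX test APIs are available; the class name and `assetFilePath` helper are hypothetical, not the repo's actual test classes:

```java
import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;

// Hypothetical names: a concrete subclass of the abstract test base above.
public class LiteAssetTests extends PytorchTestBase {
  @Override
  protected Module loadModel(String assetName) throws IOException {
    // The lite interpreter loads from a file path, so copy the asset out first.
    return LiteModuleLoader.load(assetFilePath(assetName));
  }

  private static String assetFilePath(String assetName) throws IOException {
    Context context = InstrumentationRegistry.getInstrumentation().getTargetContext();
    File file = new File(context.getFilesDir(), assetName);
    try (InputStream is = context.getAssets().open(assetName);
        OutputStream os = new FileOutputStream(file)) {
      byte[] buffer = new byte[4096];
      int read;
      while ((read = is.read(buffer)) != -1) {
        os.write(buffer, 0, read);
      }
    }
    return file.getAbsolutePath();
  }
}
```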

View File

@ -0,0 +1,9 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchInstrumentedTests.class})
public class PytorchInstrumentedTestSuite {}

View File

@ -0,0 +1,9 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchLiteInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
public class PytorchLiteInstrumentedTestSuite {}

View File

@ -0,0 +1 @@
<manifest package="org.pytorch" />

View File

@ -0,0 +1,3 @@
#pragma once
/* #undef TRACE_ENABLED */

View File

@ -0,0 +1,3 @@
#pragma once
#cmakedefine TRACE_ENABLED

View File

@ -0,0 +1,697 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <c10/core/MemoryFormat.h>
#include <c10/util/irange.h>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include "pytorch_jni_common.h"
#if defined(__ANDROID__)
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif
namespace pytorch_jni {
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
if (deviceJniCode == kDeviceCPU) {
return at::kCPU;
} else if (deviceJniCode == kDeviceVulkan) {
return at::kVulkan;
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
}
bool Trace::is_initialized_ = false;
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
#endif
void Trace::init() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
  void* lib = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
if (lib != NULL) {
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
dlsym(lib, "ATrace_beginSection"));
Trace::ATrace_endSection =
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
}
#endif
}
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
constexpr static int kTensorDTypeFloat32 = 4;
constexpr static int kTensorDTypeInt64 = 5;
constexpr static int kTensorDTypeFloat64 = 6;
constexpr static int kTensorMemoryFormatContiguous = 1;
constexpr static int kTensorMemoryFormatChannelsLast = 2;
constexpr static int kTensorMemoryFormatChannelsLast3d = 3;
template <typename K = jobject, typename V = jobject>
struct JHashMap
: facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";
using Super =
facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;
static facebook::jni::local_ref<JHashMap<K, V>> create() {
return Super::newInstance();
}
void put(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
static auto putMethod =
Super::javaClassStatic()
->template getMethod<facebook::jni::alias_ref<
facebook::jni::JObject::javaobject>(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
"put");
putMethod(Super::self(), key, value);
}
};
static at::Tensor newAtTensor(
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
facebook::jni::alias_ref<jlongArray> jshape,
jint jdtype,
jint jmemoryFormat) {
const auto rank = jshape->size();
const auto shapeArr = jshape->getRegion(0, rank);
std::vector<int64_t> shapeVec{};
shapeVec.reserve(rank);
auto numel = 1;
for (const auto i : c10::irange(rank)) {
shapeVec.push_back(shapeArr[i]);
numel *= shapeArr[i];
}
JNIEnv* jni = facebook::jni::Environment::current();
caffe2::TypeMeta typeMeta{};
int dataElementSizeBytes = 0;
if (kTensorDTypeFloat32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorDTypeInt32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeUInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
} else if (kTensorDTypeFloat64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<double>();
} else if (kTensorDTypeInt64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown Tensor jdtype %d",
jdtype);
}
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
if (dataCapacity != numel) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Tensor dimensions(elements number:%d, element byte size:%d, total "
"bytes:%d) inconsistent with buffer capacity(%d)",
numel,
dataElementSizeBytes,
numel * dataElementSizeBytes,
dataCapacity);
}
if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast));
} else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast3d));
}
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
torch::IntArrayRef(shapeVec),
at::TensorOptions(typeMeta));
}
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";
explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}
static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
static const auto jMethodMemoryFormatCode =
cls->getMethod<jint()>("memoryFormatJniCode");
static const auto jFieldShape = cls->getField<jlongArray>("shape");
static const auto jMethodGetDataBuffer = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
at::Tensor tensor = newAtTensor(
jMethodGetDataBuffer(jTensorThis),
jTensorThis->getFieldValue(jFieldShape),
jMethodDTypeCode(jTensorThis),
jMethodMemoryFormatCode(jTensorThis));
return makeCxxInstance(std::move(tensor));
}
static facebook::jni::local_ref<TensorHybrid::javaobject>
newJTensorFromAtTensor(const at::Tensor& input_tensor) {
    // The Java wrapper supports contiguous, channels-last, and channels-last-3d
    // tensors; any other layout is converted to contiguous before wrapping.
int jmemoryFormat = 0;
at::Tensor tensor{};
if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast;
} else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
} else {
tensor = input_tensor.contiguous();
jmemoryFormat = kTensorMemoryFormatContiguous;
}
const auto scalarType = tensor.scalar_type();
int jdtype = 0;
if (at::kFloat == scalarType) {
jdtype = kTensorDTypeFloat32;
} else if (at::kInt == scalarType) {
jdtype = kTensorDTypeInt32;
} else if (at::kByte == scalarType) {
jdtype = kTensorDTypeUInt8;
} else if (at::kChar == scalarType) {
jdtype = kTensorDTypeInt8;
} else if (at::kLong == scalarType) {
jdtype = kTensorDTypeInt64;
} else if (at::kDouble == scalarType) {
jdtype = kTensorDTypeFloat64;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"at::Tensor scalar type %s is not supported on java side",
c10::toString(scalarType));
}
const auto& tensorShape = tensor.sizes();
std::vector<jlong> tensorShapeVec;
for (const auto& s : tensorShape) {
tensorShapeVec.push_back(s);
}
facebook::jni::local_ref<jlongArray> jTensorShape =
facebook::jni::make_long_array(tensorShapeVec.size());
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
static auto cls = TensorHybrid::javaClassStatic();
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
facebook::jni::JByteBuffer::wrapBytes(
(uint8_t*)tensor.data_ptr(), tensor.nbytes());
jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
static const auto jMethodNewTensor =
cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
facebook::jni::alias_ref<jlongArray>,
jint,
jint,
facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
return jMethodNewTensor(
cls,
jTensorBuffer,
jTensorShape,
jdtype,
jmemoryFormat,
makeCxxInstance(tensor));
}
static at::Tensor newAtTensorFromJTensor(
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
jint jdtype = dtypeMethod(jtensor);
static const auto memoryFormatMethod =
cls->getMethod<jint()>("memoryFormatJniCode");
jint jmemoryFormat = memoryFormatMethod(jtensor);
static const auto shapeField = cls->getField<jlongArray>("shape");
auto jshape = jtensor->getFieldValue(shapeField);
static auto dataBufferMethod = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
dataBufferMethod(jtensor);
return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
}
at::Tensor tensor() const {
return tensor_;
}
private:
friend HybridBase;
at::Tensor tensor_;
};
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictStringKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::make_jstring(pair.key().toStringRef()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictLongKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::JLong::valueOf(pair.key().toInt()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback,
DictCallback intDictCallback) {
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
if (ivalue.isNone()) {
static auto jMethodOptionalNull =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
"optionalNull");
return jMethodOptionalNull(JIValue::javaClassStatic());
} else if (ivalue.isTensor()) {
static auto jMethodTensor =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
const auto& tensor = ivalue.toTensor();
return jMethodTensor(
JIValue::javaClassStatic(),
TensorHybrid::newJTensorFromAtTensor(tensor));
} else if (ivalue.isBool()) {
static auto jMethodBool =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
"from");
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
} else if (ivalue.isInt()) {
static auto jMethodInt =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
} else if (ivalue.isDouble()) {
static auto jMethodDouble =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
"from");
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
} else if (ivalue.isString()) {
static auto jMethodString =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
"from");
return jMethodString(
JIValue::javaClassStatic(),
facebook::jni::make_jstring(ivalue.toStringRef()));
} else if (ivalue.isTuple()) {
auto elementsVec = ivalue.toTupleRef().elements();
static auto jMethodTupleArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("tupleFrom");
auto jElementsArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
elementsVec.size());
auto index = 0;
for (const auto& e : elementsVec) {
(*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
} else if (ivalue.isBoolList()) {
auto list = ivalue.toBoolList();
static auto jMethodBoolListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_boolean_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isIntList()) {
auto list = ivalue.toIntList();
static auto jMethodLongListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jlongArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_long_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isDoubleList()) {
auto list = ivalue.toDoubleList();
static auto jMethoDoubleListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_double_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isTensorList()) {
auto list = ivalue.toTensorList();
static auto jMethodTensorListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
}
return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isList()) {
auto list = ivalue.toList();
static auto jMethodListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isGenericDict()) {
auto dict = ivalue.toGenericDict();
const auto keyType = dict.keyType();
if (!keyType) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue-Dict key type");
}
if (*keyType == *c10::StringType::get()) {
return stringDictCallback(std::move(dict));
} else if (*keyType == *c10::IntType::get()) {
return intDictCallback(std::move(dict));
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue-Dict key type: %s",
keyType->str().c_str());
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue type %s",
ivalue.tagKind().c_str());
}
at::IValue JIValue::JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue) {
Trace _s{"jni::JIValue::JIValueToAtIValue"};
static const auto typeCodeField =
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
const auto typeCode = jivalue->getFieldValue(typeCodeField);
if (JIValue::kTypeCodeNull == typeCode) {
return at::IValue{};
} else if (JIValue::kTypeCodeTensor == typeCode) {
static const auto jMethodGetTensor =
JIValue::javaClassStatic()
->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
"toTensor");
return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
} else if (JIValue::kTypeCodeBool == typeCode) {
static const auto jMethodGetBool =
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
    // Explicit cast to bool: jboolean is defined as uint8_t, so without the
    // cast the IValue constructor for int would be selected.
bool b = jMethodGetBool(jivalue);
return at::IValue{b};
} else if (JIValue::kTypeCodeLong == typeCode) {
static const auto jMethodGetLong =
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
return at::IValue{(int64_t)jMethodGetLong(jivalue)};
} else if (JIValue::kTypeCodeDouble == typeCode) {
static const auto jMethodGetDouble =
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
return at::IValue{jMethodGetDouble(jivalue)};
} else if (JIValue::kTypeCodeString == typeCode) {
static const auto jMethodGetString =
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
return at::IValue{jMethodGetString(jivalue)->toStdString()};
} else if (JIValue::kTypeCodeTuple == typeCode) {
static const auto jMethodGetTuple =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toTuple");
auto jarray = jMethodGetTuple(jivalue);
size_t n = jarray->size();
std::vector<at::IValue> elements;
elements.reserve(n);
for (const auto i : c10::irange(n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
elements.push_back(std::move(element));
}
return c10::ivalue::Tuple::create(std::move(elements));
} else if (JIValue::kTypeCodeBoolList == typeCode) {
static const auto jMethodGetBoolList =
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
auto jArray = jMethodGetBoolList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<bool> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeLongList == typeCode) {
static const auto jMethodGetLongList =
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
auto jArray = jMethodGetLongList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<int64_t> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
static const auto jMethodGetDoubleList =
JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
auto jArray = jMethodGetDoubleList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<double> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeTensorList == typeCode) {
static const auto jMethodGetTensorList =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject()>("toTensorList");
auto jArray = jMethodGetTensorList(jivalue);
size_t n = jArray->size();
c10::List<at::Tensor> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(
TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeList == typeCode) {
static const auto jMethodGetList =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toList");
auto jarray = jMethodGetList(jivalue);
size_t n = jarray->size();
if (n == 0) {
return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
}
auto jivalue_first_element = jarray->getElement(0);
auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
c10::impl::GenericList list{c10::unshapedType(first_element.type())};
list.reserve(n);
list.push_back(first_element);
for (const auto i : c10::irange(1, n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
list.push_back(element);
}
return at::IValue{list};
} else if (JIValue::kTypeCodeDictStringKey == typeCode) {
static const auto jMethodGetDictStringKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
javaobject()>("toDictStringKey");
auto jmap = jMethodGetDictStringKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
return at::IValue{c10::impl::GenericDict(
c10::StringType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert(it->first->toStdString(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
} else if (JIValue::kTypeCodeDictLongKey == typeCode) {
static const auto jMethodGetDictLongKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<
facebook::jni::JLong::javaobject,
JIValue::javaobject>::javaobject()>("toDictLongKey");
auto jmap = jMethodGetDictLongKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
return at::IValue{
c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert((int64_t)it->first->longValue(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
(int64_t)it->first->longValue(),
JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue typeCode %d",
typeCode);
}
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni

View File

@ -0,0 +1,137 @@
#pragma once
#include <c10/util/FunctionRef.h>
#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "cmake_macros.h"
#ifdef __ANDROID__
#include <android/log.h>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
#endif
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
#include <android/trace.h>
#include <dlfcn.h>
#endif
namespace pytorch_jni {
constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
typedef void* (*fp_ATrace_endSection)(void);
static fp_ATrace_beginSection ATrace_beginSection;
static fp_ATrace_endSection ATrace_endSection;
#endif
static void ensureInit() {
if (!Trace::is_initialized_) {
init();
Trace::is_initialized_ = true;
}
}
static void beginSection(const char* name) {
Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_beginSection(name);
#endif
}
static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_endSection();
#endif
}
Trace(const char* name) {
ensureInit();
beginSection(name);
}
~Trace() {
endSection();
}
private:
static void init();
static bool is_initialized_;
};
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
public:
explicit MemoryReadAdapter(const void* data, off_t size)
      : data_(data), size_(size) {}
size_t size() const override {
return size_;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override {
memcpy(buf, (int8_t*)(data_) + pos, n);
return n;
}
~MemoryReadAdapter() {}
private:
const void* data_;
off_t size_;
};
class JIValue : public facebook::jni::JavaClass<JIValue> {
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
c10::Dict<c10::IValue, c10::IValue>)>;
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
constexpr static int kTypeCodeNull = 1;
constexpr static int kTypeCodeTensor = 2;
constexpr static int kTypeCodeBool = 3;
constexpr static int kTypeCodeLong = 4;
constexpr static int kTypeCodeDouble = 5;
constexpr static int kTypeCodeString = 6;
constexpr static int kTypeCodeTuple = 7;
constexpr static int kTypeCodeBoolList = 8;
constexpr static int kTypeCodeLongList = 9;
constexpr static int kTypeCodeDoubleList = 10;
constexpr static int kTypeCodeTensorList = 11;
constexpr static int kTypeCodeList = 12;
constexpr static int kTypeCodeDictStringKey = 13;
constexpr static int kTypeCodeDictLongKey = 14;
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback = newJIValueFromStringDict,
DictCallback intDictCallback = newJIValueFromIntDict);
static at::IValue JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue);
private:
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue>);
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue>);
};
void common_registerNatives();
} // namespace pytorch_jni

View File

@ -0,0 +1,245 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include <ATen/record_function.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "pytorch_jni_common.h"
#ifdef __ANDROID__
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#endif
namespace pytorch_jni {
namespace {
struct JITCallGuard {
// Inference only workload.
c10::InferenceMode guard;
  // Disable the graph optimizer to ensure the list of unused ops is not
  // changed for custom mobile builds.
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::Module module_;
c10::DeviceType deviceType_;
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
return makeCxxInstance(modelPath, extraFiles, device);
}
#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif
#ifdef TRACE_ENABLED
static std::unique_ptr<at::ObserverContext> onFunctionEnter(
const at::RecordFunction& fn) {
Trace::beginSection(fn.name().str());
return nullptr;
}
static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
Trace::endSection();
}
#endif
static void preModuleLoadSetupOnce() {
auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
qengines.end()) {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
}
#ifdef __ANDROID__
torch::jit::setPrintHandler([](const std::string& s) {
__android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
});
#endif
#ifdef TRACE_ENABLED
at::addGlobalCallback(
at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
.scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
#endif
}
void preModuleLoadSetup() {
static const int once = []() {
preModuleLoadSetupOnce();
return 0;
}();
((void)once);
}
PytorchJni(
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
preModuleLoadSetup();
JITCallGuard guard;
std::unordered_map<std::string, std::string> extra_files;
const auto has_extra = extraFiles && extraFiles->size() > 0;
if (has_extra) {
for (const auto& e : *extraFiles) {
extra_files[e.first->toStdString()] = "";
}
}
deviceType_ = deviceJniCodeToDeviceType(device);
module_ = torch::jit::load(
std::move(modelPath->toStdString()), std::nullopt, extra_files);
if (has_extra) {
static auto putMethod =
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
javaClassStatic()
->template getMethod<facebook::jni::alias_ref<jobject>(
facebook::jni::alias_ref<jobject>,
facebook::jni::alias_ref<jobject>)>("put");
for (const auto& ef : extra_files) {
putMethod(
extraFiles,
facebook::jni::make_jstring(ef.first),
facebook::jni::make_jstring(ef.second));
}
}
module_.eval();
}
#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
preModuleLoadSetup();
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
if (!mgr) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unable to get asset manager");
}
AAsset* asset = AAssetManager_open(
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
if (!asset) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Failed to open asset '%s'",
assetName->toStdString().c_str());
}
auto assetBuffer = AAsset_getBuffer(asset);
if (!assetBuffer) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Could not get buffer for asset '%s'",
assetName->toStdString().c_str());
}
JITCallGuard guard;
module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
makeNativeMethod(
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
makeNativeMethod("forward", PytorchJni::forward),
makeNativeMethod("runMethod", PytorchJni::runMethod),
});
}
facebook::jni::local_ref<JIValue> forward(
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
Trace _s{"jni::Module::forward"};
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
JITCallGuard guard;
return module_.forward(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::local_ref<JIValue> runMethod(
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::string methodName = jmethodName->toStdString();
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
JITCallGuard guard;
return (*method)(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Undefined method %s",
methodName.c_str());
}
};
} // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
pytorch_jni::common_registerNatives();
pytorch_jni::PytorchJni::registerNatives();
});
}

View File

@ -0,0 +1,209 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "pytorch_jni_common.h"
#ifdef __ANDROID__
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#endif
namespace pytorch_jni {
namespace {
struct LiteJITCallGuard {
  // VariableType dispatch is not included in the default mobile build. We need
  // to set this guard globally to avoid dispatch errors (only for dynamic
  // dispatch). Thanks to the unification of the Variable and Tensor classes,
  // it's no longer required to toggle NonVariableTypeMode per op, so it doesn't
  // hurt to always set it for the inference-only use case.
// TODO: Ideally AutoNonVariableTypeMode in this file should be changed to
// InferenceMode but it's blocked due to typeahead application on Oculus
// (D27943428). To unblock, we need to find out which op is making inplace
// update to an inference tensor outside InferenceMode and properly guard it.
torch::AutoNonVariableTypeMode non_var_guard;
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::mobile::Module module_;
c10::DeviceType deviceType_;
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LiteNativePeer;";
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
return makeCxxInstance(modelPath, extraFiles, device);
}
#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif
PytorchJni(
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
LiteJITCallGuard guard;
std::unordered_map<std::string, std::string> extra_files;
const auto has_extra = extraFiles && extraFiles->size() > 0;
if (has_extra) {
for (const auto& e : *extraFiles) {
extra_files[e.first->toStdString()] = "";
}
}
deviceType_ = deviceJniCodeToDeviceType(device);
module_ = torch::jit::_load_for_mobile(
std::move(modelPath->toStdString()), std::nullopt, extra_files);
torch::jit::_load_extra_only_for_mobile(
std::move(modelPath->toStdString()), std::nullopt, extra_files);
if (has_extra) {
static auto putMethod =
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
javaClassStatic()
->template getMethod<facebook::jni::alias_ref<jobject>(
facebook::jni::alias_ref<jobject>,
facebook::jni::alias_ref<jobject>)>("put");
for (const auto& ef : extra_files) {
putMethod(
extraFiles,
facebook::jni::make_jstring(ef.first),
facebook::jni::make_jstring(ef.second));
}
}
}
#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
if (!mgr) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unable to get asset manager");
}
AAsset* asset = AAssetManager_open(
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
if (!asset) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Failed to open asset '%s'",
assetName->toStdString().c_str());
}
auto assetBuffer = AAsset_getBuffer(asset);
if (!assetBuffer) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Could not get buffer for asset '%s'",
assetName->toStdString().c_str());
}
LiteJITCallGuard guard;
module_ =
torch::jit::_load_for_mobile(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
makeNativeMethod(
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
makeNativeMethod("forward", PytorchJni::forward),
makeNativeMethod("runMethod", PytorchJni::runMethod),
});
}
facebook::jni::local_ref<JIValue> forward(
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
LiteJITCallGuard guard;
return module_.forward(inputs);
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::local_ref<JIValue> runMethod(
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::string methodName = jmethodName->toStdString();
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
LiteJITCallGuard guard;
return module_.get_method(methodName)(inputs);
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Undefined method %s",
methodName.c_str());
}
};
} // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
pytorch_jni::common_registerNatives();
pytorch_jni::PytorchJni::registerNatives();
});
}

View File

@ -0,0 +1,27 @@
package org.pytorch;
/** Codes representing tensor data types. */
public enum DType {
// NOTE: "jniCode" must be kept in sync with pytorch_jni_common.cpp.
// NOTE: Never serialize "jniCode", because it can change between releases.
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
UINT8(1),
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
INT8(2),
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
INT32(3),
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
FLOAT32(4),
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
INT64(5),
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
FLOAT64(6),
;
final int jniCode;
DType(int jniCode) {
this.jniCode = jniCode;
}
}

View File

@ -0,0 +1,15 @@
package org.pytorch;
public enum Device {
// Must be in sync with kDeviceCPU, kDeviceVulkan in
// pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
CPU(1),
VULKAN(2),
;
final int jniCode;
Device(int jniCode) {
this.jniCode = jniCode;
}
}
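
These `jniCode` values are decoded by `deviceJniCodeToDeviceType` in `pytorch_jni_common.cpp` above. A sketch of selecting a device at load time follows; the model path is a placeholder, and Vulkan only works when libtorch was built with Vulkan support:

```java
import org.pytorch.Device;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;

public final class DeviceSelectionExample {
  public static Module loadOnVulkan(String modelPath) {
    // Passing null for extraFiles is allowed; the JNI side checks for it.
    return LiteModuleLoader.load(modelPath, null, Device.VULKAN);
  }
}
```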

View File

@ -0,0 +1,9 @@
package org.pytorch;
interface INativePeer {
void resetNative();
IValue forward(IValue... inputs);
IValue runMethod(String methodName, IValue... inputs);
}

View File

@ -0,0 +1,343 @@
package org.pytorch;
import com.facebook.jni.annotations.DoNotStrip;
import java.util.Locale;
import java.util.Map;
/**
 * Java representation of a TorchScript value, which is implemented as a tagged union that can be one
* of the supported types: https://pytorch.org/docs/stable/jit.html#types .
*
* <p>Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
*
* <p>{@code IValue} objects are constructed with {@code IValue.from(value)}, {@code
* IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, or one of
* the {@code dict} methods, depending on the key type.
*
* <p>Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that {@code
* str}-type IValues must be extracted with {@link #toStr()}, rather than {@link #toString()}.
*
* <p>{@code IValue} objects may retain references to objects passed into their constructors, and
* may return references to their internal state from {@code toX()}.
*/
@DoNotStrip
public class IValue {
private static final int TYPE_CODE_NULL = 1;
private static final int TYPE_CODE_TENSOR = 2;
private static final int TYPE_CODE_BOOL = 3;
private static final int TYPE_CODE_LONG = 4;
private static final int TYPE_CODE_DOUBLE = 5;
private static final int TYPE_CODE_STRING = 6;
private static final int TYPE_CODE_TUPLE = 7;
private static final int TYPE_CODE_BOOL_LIST = 8;
private static final int TYPE_CODE_LONG_LIST = 9;
private static final int TYPE_CODE_DOUBLE_LIST = 10;
private static final int TYPE_CODE_TENSOR_LIST = 11;
private static final int TYPE_CODE_LIST = 12;
private static final int TYPE_CODE_DICT_STRING_KEY = 13;
private static final int TYPE_CODE_DICT_LONG_KEY = 14;
  private static final String[] TYPE_NAMES = {
"Unknown",
"Null",
"Tensor",
"Bool",
"Long",
"Double",
"String",
"Tuple",
"BoolList",
"LongList",
"DoubleList",
"TensorList",
"GenericList",
"DictStringKey",
"DictLongKey",
};
@DoNotStrip private final int mTypeCode;
@DoNotStrip private Object mData;
@DoNotStrip
private IValue(int typeCode) {
this.mTypeCode = typeCode;
}
@DoNotStrip
public boolean isNull() {
return TYPE_CODE_NULL == this.mTypeCode;
}
@DoNotStrip
public boolean isTensor() {
return TYPE_CODE_TENSOR == this.mTypeCode;
}
@DoNotStrip
public boolean isBool() {
return TYPE_CODE_BOOL == this.mTypeCode;
}
@DoNotStrip
public boolean isLong() {
return TYPE_CODE_LONG == this.mTypeCode;
}
@DoNotStrip
public boolean isDouble() {
return TYPE_CODE_DOUBLE == this.mTypeCode;
}
@DoNotStrip
public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}
@DoNotStrip
public boolean isTuple() {
return TYPE_CODE_TUPLE == this.mTypeCode;
}
@DoNotStrip
public boolean isBoolList() {
return TYPE_CODE_BOOL_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isLongList() {
return TYPE_CODE_LONG_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isDoubleList() {
return TYPE_CODE_DOUBLE_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isTensorList() {
return TYPE_CODE_TENSOR_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isList() {
return TYPE_CODE_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isDictStringKey() {
return TYPE_CODE_DICT_STRING_KEY == this.mTypeCode;
}
@DoNotStrip
public boolean isDictLongKey() {
return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode;
}
/** Creates a new {@code IValue} of type {@code Optional} that contains no value. */
@DoNotStrip
public static IValue optionalNull() {
return new IValue(TYPE_CODE_NULL);
}
/** Creates a new {@code IValue} of type {@code Tensor}. */
@DoNotStrip
public static IValue from(Tensor tensor) {
final IValue iv = new IValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
return iv;
}
/** Creates a new {@code IValue} of type {@code bool}. */
@DoNotStrip
public static IValue from(boolean value) {
final IValue iv = new IValue(TYPE_CODE_BOOL);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code int}. */
@DoNotStrip
public static IValue from(long value) {
final IValue iv = new IValue(TYPE_CODE_LONG);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code float}. */
@DoNotStrip
public static IValue from(double value) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code str}. */
@DoNotStrip
public static IValue from(String value) {
final IValue iv = new IValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[bool]}. */
@DoNotStrip
public static IValue listFrom(boolean... list) {
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[int]}. */
@DoNotStrip
public static IValue listFrom(long... list) {
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[float]}. */
@DoNotStrip
public static IValue listFrom(double... list) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[Tensor]}. */
@DoNotStrip
public static IValue listFrom(Tensor... list) {
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type. */
@DoNotStrip
public static IValue listFrom(IValue... array) {
final int size = array.length;
if (size > 0) {
final int typeCode0 = array[0].mTypeCode;
for (int i = 1; i < size; i++) {
if (typeCode0 != array[i].mTypeCode) {
throw new IllegalArgumentException("List must contain items of the same type");
}
}
}
final IValue iv = new IValue(TYPE_CODE_LIST);
iv.mData = array;
return iv;
}
/** Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. */
@DoNotStrip
public static IValue tupleFrom(IValue... array) {
final IValue iv = new IValue(TYPE_CODE_TUPLE);
iv.mData = array;
return iv;
}
/** Creates a new {@code IValue} of type {@code Dict[str, V]}. */
@DoNotStrip
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
iv.mData = map;
return iv;
}
/** Creates a new {@code IValue} of type {@code Dict[int, V]}. */
@DoNotStrip
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
iv.mData = map;
return iv;
}
@DoNotStrip
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}
@DoNotStrip
public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}
@DoNotStrip
public long toLong() {
preconditionType(TYPE_CODE_LONG, mTypeCode);
return (long) mData;
}
@DoNotStrip
public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}
@DoNotStrip
public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
}
@DoNotStrip
public boolean[] toBoolList() {
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
return (boolean[]) mData;
}
@DoNotStrip
public long[] toLongList() {
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
return (long[]) mData;
}
@DoNotStrip
public double[] toDoubleList() {
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
return (double[]) mData;
}
@DoNotStrip
public Tensor[] toTensorList() {
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
return (Tensor[]) mData;
}
@DoNotStrip
public IValue[] toList() {
preconditionType(TYPE_CODE_LIST, mTypeCode);
return (IValue[]) mData;
}
@DoNotStrip
public IValue[] toTuple() {
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
return (IValue[]) mData;
}
@DoNotStrip
public Map<String, IValue> toDictStringKey() {
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
return (Map<String, IValue>) mData;
}
@DoNotStrip
public Map<Long, IValue> toDictLongKey() {
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
return (Map<Long, IValue>) mData;
}
private void preconditionType(int typeCodeExpected, int typeCode) {
if (typeCode != typeCodeExpected) {
throw new IllegalStateException(
String.format(
Locale.US,
"Expected IValue type %s, actual type %s",
getTypeName(typeCodeExpected),
getTypeName(typeCode)));
}
}
private String getTypeName(int typeCode) {
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
}
}
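
Putting the tagged union to work: every `from`/`listFrom`/`tupleFrom` factory stamps the matching type code, and every `toX()` checks it via `preconditionType`. A small usage sketch (the tensor contents are arbitrary):

```java
import org.pytorch.IValue;
import org.pytorch.Tensor;

public final class IValueExample {
  public static float[] roundTrip() {
    // Wrap: each factory sets the type code (TYPE_CODE_TENSOR, TYPE_CODE_TUPLE, ...).
    Tensor t = Tensor.fromBlob(new float[] {1f, 2f, 3f, 4f}, new long[] {2, 2});
    IValue tuple = IValue.tupleFrom(IValue.from(t), IValue.from(42L), IValue.from(true));

    // Unwrap: check the tag first; a mismatched toX() throws IllegalStateException.
    IValue first = tuple.toTuple()[0];
    return first.isTensor() ? first.toTensor().getDataAsFloatArray() : new float[0];
  }
}
```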

View File

@ -0,0 +1,49 @@
package org.pytorch;
import android.content.res.AssetManager;
import java.util.Map;
public class LiteModuleLoader {
/**
   * Loads a serialized TorchScript module from the specified path on disk to run on the specified
   * device. The model should be generated with the {@code _save_for_lite_interpreter()} API.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @param extraFiles map with extra file names as keys; their contents will be loaded into the values.
* @param device {@link org.pytorch.Device} to use for running specified module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module.
*/
public static Module load(
final String modelPath, final Map<String, String> extraFiles, final Device device) {
return new Module(new LiteNativePeer(modelPath, extraFiles, device));
}
/**
   * Loads a serialized TorchScript module from the specified path on disk to run on CPU. The
   * model should be generated with the {@code _save_for_lite_interpreter()} API.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module.
*/
public static Module load(final String modelPath) {
return new Module(new LiteNativePeer(modelPath, null, Device.CPU));
}
/**
   * Attention: This is not the recommended way of loading production modules, as prepackaged
   * assets increase APK size, among other drawbacks. For production usage, consider loading from a
   * file on disk instead: {@link org.pytorch.LiteModuleLoader#load(String)}.
   *
   * <p>This method is meant to be used in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new LiteNativePeer(assetName, assetManager, device));
}
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
}
}
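
End to end, the loader above plugs into the `Module`/`IValue` API like this. A sketch: the path and the 1x3x224x224 input shape are placeholders for whatever the exported model expects:

```java
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

public final class LiteInferenceExample {
  public static float[] run(String modelPath) {
    // Load a model exported with _save_for_lite_interpreter() and run it on CPU.
    Module module = LiteModuleLoader.load(modelPath);
    Tensor input =
        Tensor.fromBlob(new float[1 * 3 * 224 * 224], new long[] {1, 3, 224, 224});
    float[] scores = module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
    module.destroy(); // optional: frees the native module without waiting for GC
    return scores;
  }
}
```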

View File

@ -0,0 +1,64 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
class LiteNativePeer implements INativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
private final HybridData mHybridData;
private static native HybridData initHybrid(
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
int deviceJniCode);
LiteNativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
}
LiteNativePeer(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}
/**
* Explicitly destroys the native torch::jit::mobile::Module. Calling this method is not required,
* as the native object will be destroyed when this object is garbage-collected. However, the
* timing of garbage collection is not guaranteed, so proactively calling {@code resetNative} can
* free memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
*/
public void resetNative() {
mHybridData.resetNative();
}
/**
* Runs the 'forward' method of this module with the specified arguments.
*
* @param inputs arguments for the TorchScript module's 'forward' method.
* @return return value from the 'forward' method.
*/
public native IValue forward(IValue... inputs);
/**
* Runs the specified method of this module with the specified arguments.
*
* @param methodName name of the TorchScript method to run.
* @param inputs arguments that will be passed to TorchScript method.
* @return return value from the method.
*/
public native IValue runMethod(String methodName, IValue... inputs);
}

View File

@ -0,0 +1,14 @@
package org.pytorch;
public enum MemoryFormat {
CONTIGUOUS(1),
CHANNELS_LAST(2),
CHANNELS_LAST_3D(3),
;
final int jniCode;
MemoryFormat(int jniCode) {
this.jniCode = jniCode;
}
}
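
These codes map to `kTensorMemoryFormatContiguous`/`ChannelsLast`/`ChannelsLast3d` in `pytorch_jni_common.cpp`, where the channels-last strides are computed. A sketch of creating the same logical tensor in both layouts, mirroring the conv tests above; note the shape is always given in NCHW order, and only the element order of the flat buffer differs:

```java
import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;

public final class MemoryFormatExample {
  public static Tensor[] makeBoth(long[] nchwData, long[] nhwcData) {
    // Same logical 1x3x2x2 tensor; CHANNELS_LAST expects the flat buffer in
    // NHWC element order while the shape stays in NCHW order.
    long[] shape = {1, 3, 2, 2};
    Tensor contiguous = Tensor.fromBlob(nchwData, shape, MemoryFormat.CONTIGUOUS);
    Tensor channelsLast = Tensor.fromBlob(nhwcData, shape, MemoryFormat.CHANNELS_LAST);
    return new Tensor[] {contiguous, channelsLast};
  }
}
```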

View File

@ -0,0 +1,75 @@
// Copyright 2004-present Facebook. All Rights Reserved.
package org.pytorch;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
/** Java wrapper for torch::jit::Module. */
public class Module {
private INativePeer mNativePeer;
/**
   * Loads a serialized TorchScript module from the specified path on disk to run on the specified
   * device.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @param extraFiles map with extra file names as keys; their contents will be loaded into the values.
* @param device {@link org.pytorch.Device} to use for running specified module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(
final String modelPath, final Map<String, String> extraFiles, final Device device) {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
return new Module(new NativePeer(modelPath, extraFiles, device));
}
/**
* Loads a serialized TorchScript module from the specified path on disk, to run on CPU.
*
* @param modelPath path to file that contains the serialized TorchScript module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(final String modelPath) {
return load(modelPath, null, Device.CPU);
}
Module(INativePeer nativePeer) {
this.mNativePeer = nativePeer;
}
/**
* Runs the 'forward' method of this module with the specified arguments.
*
* @param inputs arguments for the TorchScript module's 'forward' method.
* @return return value from the 'forward' method.
*/
public IValue forward(IValue... inputs) {
return mNativePeer.forward(inputs);
}
/**
* Runs the specified method of this module with the specified arguments.
*
* @param methodName name of the TorchScript method to run.
* @param inputs arguments that will be passed to TorchScript method.
* @return return value from the method.
*/
public IValue runMethod(String methodName, IValue... inputs) {
return mNativePeer.runMethod(methodName, inputs);
}
/**
* Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the
* native object will be destroyed when this object is garbage-collected. However, the timing of
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
*/
public void destroy() {
mNativePeer.resetNative();
}
}
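
As a quick orientation, a hedged usage sketch of the `Module` API above; the model path is a placeholder, and `eqTensor` assumes the loaded model actually exports such a method (one appears in the test script later in this diff):

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class ModuleSketch {
  public static void main(String[] args) {
    Module module = Module.load("/data/local/tmp/model.pt"); // placeholder path
    Tensor input = Tensor.fromBlob(new float[] {1f, 2f, 3f}, new long[] {3});
    // 'forward' is the default entry point...
    float[] out = module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
    // ...while any other exported method goes through runMethod.
    IValue echoed = module.runMethod("eqTensor", IValue.from(input));
    module.destroy(); // optional: frees native memory without waiting for GC
  }
}
```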


@ -0,0 +1,46 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import java.util.Map;
class NativePeer implements INativePeer {
static {
NativeLoader.loadLibrary("pytorch_jni");
PyTorchCodegenLoader.loadNativeLibs();
}
private final HybridData mHybridData;
@DoNotStrip
private static native HybridData initHybrid(
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
@DoNotStrip
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
int deviceJniCode);
NativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
}
NativePeer(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}
public void resetNative() {
mHybridData.resetNative();
}
@DoNotStrip
public native IValue forward(IValue... inputs);
@DoNotStrip
public native IValue runMethod(String methodName, IValue... inputs);
}


@ -0,0 +1,50 @@
package org.pytorch;
import android.content.res.AssetManager;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class PyTorchAndroid {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
/**
* Attention: This is not the recommended way of loading production modules, as prepackaged assets
* increase apk size etc. For production usage consider loading from a file on disk instead, via
* {@link org.pytorch.Module#load(String)}.
*
* <p>This method is meant to be used in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new NativePeer(assetName, assetManager, device));
}
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new NativePeer(assetName, assetManager, Device.CPU));
}
/**
* Globally sets the number of threads used on native side. Attention: Has global effect, all
* modules use one thread pool with specified number of threads.
*
* @param numThreads number of threads, must be positive number.
*/
public static void setNumThreads(int numThreads) {
if (numThreads < 1) {
throw new IllegalArgumentException("Number of threads cannot be less than 1");
}
nativeSetNumThreads(numThreads);
}
@DoNotStrip
private static native void nativeSetNumThreads(int numThreads);
}
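
A short hypothetical sketch tying the two public entry points together inside an Android component; `model.pt` is assumed to be packaged under the app's `assets/` directory:

```java
import android.content.Context;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;

public class AssetSketch {
  static float[] runOnce(Context context, Tensor input) {
    PyTorchAndroid.setNumThreads(1); // global: one shared thread pool for all modules
    Module module =
        PyTorchAndroid.loadModuleFromAsset(context.getAssets(), "model.pt", Device.CPU);
    return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
  }
}
```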


@ -0,0 +1,16 @@
package org.pytorch;
import com.facebook.soloader.nativeloader.NativeLoader;
public class PyTorchCodegenLoader {
public static void loadNativeLibs() {
try {
NativeLoader.loadLibrary("torch-code-gen");
} catch (Throwable t) {
// Loading the codegen lib is best-effort since it's only there for query-based builds.
}
}
private PyTorchCodegenLoader() {}
}


@ -0,0 +1,735 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Locale;
/**
* Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
*
* <p>Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data}
* can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided
* to allocate buffers properly.
*
* <p>To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*}
* methods.
*
* <p>When constructing {@code Tensor} objects with {@code data} as an array, it is not specified
* whether this data is copied or retained as a reference, so it is recommended not to modify it
* after construction. {@code data} passed as a {@link Buffer} is not copied, so it can be modified
* between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects
* may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape}
* is always copied.
*/
public abstract class Tensor {
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must not be null";
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must not be null";
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must not be null";
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non-negative";
private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
"Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
@DoNotStrip final long[] shape;
final MemoryFormat memoryFormat;
private static final int INT_SIZE_BYTES = 4;
private static final int FLOAT_SIZE_BYTES = 4;
private static final int LONG_SIZE_BYTES = 8;
private static final int DOUBLE_SIZE_BYTES = 8;
/**
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link
* Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static ByteBuffer allocateByteBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
}
/**
* Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static IntBuffer allocateIntBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asIntBuffer();
}
/**
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static FloatBuffer allocateFloatBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asFloatBuffer();
}
/**
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static LongBuffer allocateLongBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asLongBuffer();
}
/**
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static DoubleBuffer allocateDoubleBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asDoubleBuffer();
}
/**
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
* bytes.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlobUnsigned(byte[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_uint8(byteBuffer, shape, memoryFormat);
}
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
* bytes.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(byte[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_int8(byteBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(byte[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
* ints.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(int[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape));
intBuffer.put(data);
return new Tensor_int32(intBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(int[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
* of floats.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(float[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape));
floatBuffer.put(data);
return new Tensor_float32(floatBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(float[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
* longs.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(long[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
longBuffer.put(data);
return new Tensor_int64(longBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(long[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array
* of doubles.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(double[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
doubleBuffer.put(data);
return new Tensor_float64(doubleBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(double[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_uint8(data, shape, memoryFormat);
}
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int8(data, shape, memoryFormat);
}
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int32(data, shape, memoryFormat);
}
public static Tensor fromBlob(IntBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float32(data, shape, memoryFormat);
}
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int64(data, shape, memoryFormat);
}
public static Tensor fromBlob(LongBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float64(data, shape, memoryFormat);
}
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
@DoNotStrip private HybridData mHybridData;
private Tensor(long[] shape, MemoryFormat memoryFormat) {
checkShape(shape);
this.shape = Arrays.copyOf(shape, shape.length);
this.memoryFormat = memoryFormat;
}
/** Returns the number of elements in this tensor. */
public long numel() {
return numel(this.shape);
}
/** Calculates the number of elements in a tensor with the specified shape. */
public static long numel(long[] shape) {
checkShape(shape);
long result = 1; // accumulate in long to avoid int overflow for large shapes
for (long s : shape) {
result *= s;
}
return result;
}
/** Returns the shape of this tensor. (The array is a fresh copy.) */
public long[] shape() {
return Arrays.copyOf(shape, shape.length);
}
/** Returns the memory format of this tensor. */
public MemoryFormat memoryFormat() {
return memoryFormat;
}
/** @return data type of this tensor. */
public abstract DType dtype();
// Called from native
@DoNotStrip
int dtypeJniCode() {
return dtype().jniCode;
}
// Called from native
@DoNotStrip
int memoryFormatJniCode() {
return memoryFormat.jniCode;
}
/**
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int8 tensor.
*/
public byte[] getDataAsByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
/**
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-uint8 tensor.
*/
public byte[] getDataAsUnsignedByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
/**
* @return a Java int array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int32 tensor.
*/
public int[] getDataAsIntArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
}
/**
* @return a Java float array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-float32 tensor.
*/
public float[] getDataAsFloatArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
}
/**
* @return a Java long array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int64 tensor.
*/
public long[] getDataAsLongArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as long array.");
}
/**
* @return a Java double array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-float64 tensor.
*/
public double[] getDataAsDoubleArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
}
@DoNotStrip
Buffer getRawDataBuffer() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
}
static class Tensor_uint8 extends Tensor {
private final ByteBuffer data;
private Tensor_uint8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.UINT8;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public byte[] getDataAsUnsignedByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
}
}
static class Tensor_int8 extends Tensor {
private final ByteBuffer data;
private Tensor_int8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT8;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public byte[] getDataAsByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
}
}
static class Tensor_int32 extends Tensor {
private final IntBuffer data;
private Tensor_int32(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT32;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public int[] getDataAsIntArray() {
data.rewind();
int[] arr = new int[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
}
}
static class Tensor_float32 extends Tensor {
private final FloatBuffer data;
Tensor_float32(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public float[] getDataAsFloatArray() {
data.rewind();
float[] arr = new float[data.remaining()];
data.get(arr);
return arr;
}
@Override
public DType dtype() {
return DType.FLOAT32;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
}
}
static class Tensor_int64 extends Tensor {
private final LongBuffer data;
private Tensor_int64(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT64;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public long[] getDataAsLongArray() {
data.rewind();
long[] arr = new long[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
}
}
static class Tensor_float64 extends Tensor {
private final DoubleBuffer data;
private Tensor_float64(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.FLOAT64;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public double[] getDataAsDoubleArray() {
data.rewind();
double[] arr = new double[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
}
}
// region checks
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
if (!expression) {
throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args));
}
}
private static void checkShape(long[] shape) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
for (int i = 0; i < shape.length; i++) {
checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE);
}
}
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
final long numel = numel(shape);
checkArgument(
numel == dataCapacity,
"Inconsistent data capacity:%d and shape number elements:%d shape:%s",
dataCapacity,
numel,
Arrays.toString(shape));
}
// endregion checks
// Called from native
@DoNotStrip
private static Tensor nativeNewTensor(
ByteBuffer data, long[] shape, int dtype, int memoryFormatCode, HybridData hybridData) {
Tensor tensor = null;
MemoryFormat memoryFormat = MemoryFormat.CONTIGUOUS;
if (MemoryFormat.CHANNELS_LAST.jniCode == memoryFormatCode) {
memoryFormat = MemoryFormat.CHANNELS_LAST;
} else if (MemoryFormat.CHANNELS_LAST_3D.jniCode == memoryFormatCode) {
memoryFormat = MemoryFormat.CHANNELS_LAST_3D;
}
if (DType.FLOAT32.jniCode == dtype) {
tensor = new Tensor_float32(data.asFloatBuffer(), shape, memoryFormat);
} else if (DType.INT32.jniCode == dtype) {
tensor = new Tensor_int32(data.asIntBuffer(), shape, memoryFormat);
} else if (DType.INT64.jniCode == dtype) {
tensor = new Tensor_int64(data.asLongBuffer(), shape, memoryFormat);
} else if (DType.FLOAT64.jniCode == dtype) {
tensor = new Tensor_float64(data.asDoubleBuffer(), shape, memoryFormat);
} else if (DType.UINT8.jniCode == dtype) {
tensor = new Tensor_uint8(data, shape, memoryFormat);
} else if (DType.INT8.jniCode == dtype) {
tensor = new Tensor_int8(data, shape, memoryFormat);
} else {
throw new IllegalArgumentException("Unknown Tensor dtype");
}
tensor.mHybridData = hybridData;
return tensor;
}
}
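
The class javadoc's point about direct buffers not being copied enables a reallocation-free inference loop. A hedged sketch of that pattern (the module and frame data are supplied by the caller; both frames are assumed to have the same length):

```java
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class BufferReuseSketch {
  static void runTwice(Module module, float[] frameA, float[] frameB) {
    // The tensor wraps the direct buffer without copying, so refilling the
    // buffer between calls feeds new data through the same Tensor object.
    FloatBuffer buffer = Tensor.allocateFloatBuffer(frameA.length);
    Tensor input = Tensor.fromBlob(buffer, new long[] {1, frameA.length});
    buffer.put(frameA);
    IValue first = module.forward(IValue.from(input));
    buffer.rewind();
    buffer.put(frameB); // mutate in place; no new allocation
    IValue second = module.forward(IValue.from(input));
  }
}
```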


@ -0,0 +1,3 @@
<resources>
<string name="app_name">pytorch_android</string>
</resources>


@ -0,0 +1,105 @@
def forward(self, input):
return None
def eqBool(self, input: bool) -> bool:
return input
def eqInt(self, input: int) -> int:
return input
def eqFloat(self, input: float) -> float:
return input
def eqStr(self, input: str) -> str:
return input
def eqTensor(self, input: Tensor) -> Tensor:
return input
def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
return input
def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
return input
def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
return input
def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
def listBoolConjunction(self, input: List[bool]) -> bool:
res = True
for x in input:
res = res and x
return res
def listBoolDisjunction(self, input: List[bool]) -> bool:
res = False
for x in input:
res = res or x
return res
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
def optionalIntIsNone(self, input: Optional[int]) -> bool:
return input is None
def intEq0None(self, input: int) -> Optional[int]:
if input == 0:
return None
return input
def str3Concat(self, input: str) -> str:
return input + input + input
def newEmptyShapeWithItem(self, input):
return torch.tensor([int(input.item())])[0]
def testAliasWithOffset(self) -> List[Tensor]:
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a
def testNonContiguous(self):
x = torch.tensor([100, 200, 300])[::2]
assert not x.is_contiguous()
assert x[0] == 100
assert x[1] == 300
return x
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.conv2d(x, w)
if (toChannelsLast):
# memory_format=torch.channels_last
r = r.contiguous(memory_format=2)
else:
r = r.contiguous()
return r
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.conv3d(x, w)
if (toChannelsLast):
# memory_format=torch.channels_last_3d
r = r.contiguous(memory_format=3)
else:
r = r.contiguous()
return r
def contiguous(self, x: Tensor) -> Tensor:
return x.contiguous()
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
# memory_format=torch.channels_last
return x.contiguous(memory_format=2)
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
# memory_format=torch.channels_last_3d
return x.contiguous(memory_format=3)
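
These scripted methods are driven from Java through `Module.runMethod`. A hedged sketch for `listIntSumReturnTuple` (module loading elided; the method name mirrors the script above):

```java
import org.pytorch.IValue;
import org.pytorch.Module;

public class ScriptMethodSketch {
  static long sumViaScript(Module module) {
    IValue result = module.runMethod("listIntSumReturnTuple", IValue.listFrom(1L, 2L, 3L));
    IValue[] tuple = result.toTuple(); // (echoed list, sum)
    long[] echoed = tuple[0].toLongList();
    return tuple[1].toLong(); // 6
  }
}
```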


@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.5)
project(pytorch_vision_jni CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
file(GLOB pytorch_vision_SOURCES
${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp
)
add_library(pytorch_vision_jni SHARED
${pytorch_vision_SOURCES}
)
target_compile_options(pytorch_vision_jni PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_link_libraries(pytorch_vision_jni)


@ -0,0 +1,64 @@
apply plugin: 'com.android.library'
apply plugin: 'maven'
android {
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 0
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}
dependencies {
implementation project(':pytorch_android')
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
}
apply from: rootProject.file('gradle/release.gradle')
task sourcesJar(type: Jar) {
from android.sourceSets.main.java.srcDirs
classifier = 'sources'
}
artifacts.add('archives', sourcesJar)


@ -0,0 +1,4 @@
POM_NAME=pytorch_android_torchvision_lite
POM_DESCRIPTION=pytorch_android_torchvision_lite
POM_ARTIFACT_ID=pytorch_android_torchvision_lite
POM_PACKAGING=aar


@ -0,0 +1,24 @@
package org.pytorch.torchvision;
import static org.junit.Assert.assertArrayEquals;
import android.graphics.Bitmap;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.pytorch.Tensor;
@RunWith(AndroidJUnit4.class)
public class TorchVisionInstrumentedTests {
@Test
public void smokeTest() {
Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888);
Tensor tensor =
TensorImageUtils.bitmapToFloat32Tensor(
bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
assertArrayEquals(new long[] {1L, 3L, 240L, 320L}, tensor.shape());
}
}


@ -0,0 +1,9 @@
package org.pytorch.torchvision.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.torchvision.TorchVisionInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({TorchVisionInstrumentedTests.class})
public class TorchVisionInstrumentedTestSuite {}


@ -0,0 +1 @@
<manifest package="org.pytorch.torchvision" />


@ -0,0 +1,187 @@
#include <cassert>
#include <cmath>
#include <vector>
#include "jni.h"
namespace pytorch_vision_jni {
static void imageYUV420CenterCropToFloatBuffer(
JNIEnv* jniEnv,
jclass,
jobject yBuffer,
jint yRowStride,
jint yPixelStride,
jobject uBuffer,
jobject vBuffer,
jint uRowStride,
jint uvPixelStride,
jint imageWidth,
jint imageHeight,
jint rotateCWDegrees,
jint tensorWidth,
jint tensorHeight,
jfloatArray jnormMeanRGB,
jfloatArray jnormStdRGB,
jobject outBuffer,
jint outOffset,
jint memoryFormatCode) {
constexpr static int32_t kMemoryFormatContiguous = 1;
constexpr static int32_t kMemoryFormatChannelsLast = 2;
float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
jfloat normMeanRGB[3];
jfloat normStdRGB[3];
jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB);
jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB);
int widthAfterRtn = imageWidth;
int heightAfterRtn = imageHeight;
bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270;
if (oddRotation) {
widthAfterRtn = imageHeight;
heightAfterRtn = imageWidth;
}
int cropWidthAfterRtn = widthAfterRtn;
int cropHeightAfterRtn = heightAfterRtn;
if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) {
cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight;
} else {
cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth;
}
int cropWidthBeforeRtn = cropWidthAfterRtn;
int cropHeightBeforeRtn = cropHeightAfterRtn;
if (oddRotation) {
cropWidthBeforeRtn = cropHeightAfterRtn;
cropHeightBeforeRtn = cropWidthAfterRtn;
}
const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f;
const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f;
const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer);
const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer);
const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);
float scale = (float)cropWidthAfterRtn / tensorWidth; // float division; int division would truncate the scale
int uvRowStride = uRowStride;
int cropXMult = 1;
int cropYMult = 1;
int cropXAdd = offsetX;
int cropYAdd = offsetY;
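// Map each output pixel back to its pre-rotation source coordinate as
// src = cropAdd + cropMult * scaledIndex; the 90/180/270 branches below flip
// an axis by negating the multiplier and shifting the offset.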
if (rotateCWDegrees == 90) {
cropYMult = -1;
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
} else if (rotateCWDegrees == 180) {
cropXMult = -1;
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
cropYMult = -1;
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
} else if (rotateCWDegrees == 270) {
cropXMult = -1;
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
}
float normMeanRm255 = 255 * normMeanRGB[0];
float normMeanGm255 = 255 * normMeanRGB[1];
float normMeanBm255 = 255 * normMeanRGB[2];
float normStdRm255 = 255 * normStdRGB[0];
float normStdGm255 = 255 * normStdRGB[1];
float normStdBm255 = 255 * normStdRGB[2];
int xBeforeRtn, yBeforeRtn;
int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
int channelSize = tensorWidth * tensorHeight;
// A bit of code duplication to avoid branching inside the inner loops
if (memoryFormatCode == kMemoryFormatContiguous) {
int wr = outOffset;
int wg = wr + channelSize;
int wb = wg + channelSize;
for (int y = 0; y < tensorHeight; y++) {
for (int x = 0; x < tensorWidth; x++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx =
(yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wr++] = (ri - normMeanRm255) / normStdRm255;
outData[wg++] = (gi - normMeanGm255) / normStdGm255;
outData[wb++] = (bi - normMeanBm255) / normStdBm255;
}
}
} else if (memoryFormatCode == kMemoryFormatChannelsLast) {
int wc = outOffset;
for (int y = 0; y < tensorHeight; y++) {
for (int x = 0; x < tensorWidth; x++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx =
(yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wc++] = (ri - normMeanRm255) / normStdRm255;
outData[wc++] = (gi - normMeanGm255) / normStdGm255;
outData[wc++] = (bi - normMeanBm255) / normStdBm255;
}
}
} else {
jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException");
jniEnv->ThrowNew(Exception, "Illegal memory format code");
}
}
} // namespace pytorch_vision_jni
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
jclass c =
env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}
static const JNINativeMethod methods[] = {
{"imageYUV420CenterCropToFloatBuffer",
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;II)V",
(void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
};
int rc = env->RegisterNatives(
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
if (rc != JNI_OK) {
return rc;
}
return JNI_VERSION_1_6;
}
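
For reference, the integer constants above (1192, 1634, 833, 400, 2066, together with the `>> 10` shift) are a fixed-point, divide-by-1024 approximation of the standard BT.601 video-range YUV-to-RGB conversion; with $y' = \max(Y - 16,\, 0)$, $u' = U - 128$, $v' = V - 128$:

$$
\begin{aligned}
R &= \tfrac{1192\,y' + 1634\,v'}{1024} \approx 1.164\,y' + 1.596\,v' \\
G &= \tfrac{1192\,y' - 833\,v' - 400\,u'}{1024} \approx 1.164\,y' - 0.813\,v' - 0.391\,u' \\
B &= \tfrac{1192\,y' + 2066\,u'}{1024} \approx 1.164\,y' + 2.018\,u'
\end{aligned}
$$

Each channel is then clamped to $[0, 255]$ before the mean/std normalization.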


@ -0,0 +1,398 @@
package org.pytorch.torchvision;
import android.graphics.Bitmap;
import android.graphics.ImageFormat;
import android.media.Image;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;
/**
* Contains utility functions for {@link org.pytorch.Tensor} creation from {@link
* android.graphics.Bitmap} or {@link android.media.Image} source.
*/
public final class TensorImageUtils {
public static float[] TORCHVISION_NORM_MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f};
public static float[] TORCHVISION_NORM_STD_RGB = new float[] {0.229f, 0.224f, 0.225f};
/**
* Creates a new {@link org.pytorch.Tensor} from a full {@link android.graphics.Bitmap},
* normalized with the mean and std specified in the parameters.
*
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
* order
*/
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap,
final float[] normMeanRGB,
final float[] normStdRGB,
final MemoryFormat memoryFormat) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
return bitmapToFloat32Tensor(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, memoryFormat);
}
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap, final float[] normMeanRGB, final float[] normStdRGB) {
return bitmapToFloat32Tensor(
bitmap,
0,
0,
bitmap.getWidth(),
bitmap.getHeight(),
normMeanRGB,
normStdRGB,
MemoryFormat.CONTIGUOUS);
}
/**
* Writes tensor content from the specified {@link android.graphics.Bitmap}, normalized with the
* mean and std specified in the parameters, into the specified {@link java.nio.FloatBuffer} at
* the given offset.
*
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
* @param x - x coordinate of top left corner of bitmap's area
* @param y - y coordinate of top left corner of bitmap's area
* @param width - width of bitmap's area
* @param height - height of bitmap's area
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
* order
*/
public static void bitmapToFloatBuffer(
final Bitmap bitmap,
final int x,
final int y,
final int width,
final int height,
final float[] normMeanRGB,
final float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset,
final MemoryFormat memoryFormat) {
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
if (memoryFormat != MemoryFormat.CONTIGUOUS && memoryFormat != MemoryFormat.CHANNELS_LAST) {
throw new IllegalArgumentException("Unsupported memory format " + memoryFormat);
}
final int pixelsCount = height * width;
final int[] pixels = new int[pixelsCount];
bitmap.getPixels(pixels, 0, width, x, y, width, height);
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
final int offset_g = pixelsCount;
final int offset_b = 2 * pixelsCount;
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]);
outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]);
outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]);
}
} else {
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
outBuffer.put(outBufferOffset + 3 * i + 0, (r - normMeanRGB[0]) / normStdRGB[0]);
outBuffer.put(outBufferOffset + 3 * i + 1, (g - normMeanRGB[1]) / normStdRGB[1]);
outBuffer.put(outBufferOffset + 3 * i + 2, (b - normMeanRGB[2]) / normStdRGB[2]);
}
}
}
public static void bitmapToFloatBuffer(
final Bitmap bitmap,
final int x,
final int y,
final int width,
final int height,
final float[] normMeanRGB,
final float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
bitmapToFloatBuffer(
bitmap,
x,
y,
width,
height,
normMeanRGB,
normStdRGB,
outBuffer,
outBufferOffset,
MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new {@link org.pytorch.Tensor} from the specified area of {@link
* android.graphics.Bitmap}, normalized with the mean and std specified in the parameters.
*
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
* @param x - x coordinate of top left corner of bitmap's area
* @param y - y coordinate of top left corner of bitmap's area
* @param width - width of bitmap's area
* @param height - height of bitmap's area
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
* order
*/
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap,
int x,
int y,
int width,
int height,
float[] normMeanRGB,
float[] normStdRGB,
MemoryFormat memoryFormat) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
bitmapToFloatBuffer(
bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0, memoryFormat);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width}, memoryFormat);
}
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap,
int x,
int y,
int width,
int height,
float[] normMeanRGB,
float[] normStdRGB) {
return bitmapToFloat32Tensor(
bitmap, x, y, width, height, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new {@link org.pytorch.Tensor} from the specified area of {@link android.media.Image},
* applying optional rotation, nearest-neighbor scaling and center cropping.
*
* @param image {@link android.media.Image} as a source for Tensor data
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
* upright. Range of valid values: 0, 90, 180, 270
* @param tensorWidth return tensor width, must be positive
* @param tensorHeight return tensor height, must be positive
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
* order
*/
public static Tensor imageYUV420CenterCropToFloat32Tensor(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB,
MemoryFormat memoryFormat) {
if (image.getFormat() != ImageFormat.YUV_420_888) {
throw new IllegalArgumentException(
String.format(
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
}
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
checkRotateCWDegrees(rotateCWDegrees);
checkTensorSize(tensorWidth, tensorHeight);
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
imageYUV420CenterCropToFloatBuffer(
image,
rotateCWDegrees,
tensorWidth,
tensorHeight,
normMeanRGB,
normStdRGB,
floatBuffer,
0,
memoryFormat);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth}, memoryFormat);
}
public static Tensor imageYUV420CenterCropToFloat32Tensor(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB) {
return imageYUV420CenterCropToFloat32Tensor(
image,
rotateCWDegrees,
tensorWidth,
tensorHeight,
normMeanRGB,
normStdRGB,
MemoryFormat.CONTIGUOUS);
}
/**
* Writes tensor content from the specified {@link android.media.Image}, applying optional
* rotation, nearest-neighbor scaling and center cropping, into the specified {@link
* java.nio.FloatBuffer} at the given offset.
*
* @param image {@link android.media.Image} as a source for Tensor data
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
* upright. Range of valid values: 0, 90, 180, 270
* @param tensorWidth return tensor width, must be positive
* @param tensorHeight return tensor height, must be positive
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
* order
* @param outBuffer Output buffer, where tensor content will be written
* @param outBufferOffset Output buffer offset with which tensor content will be written
*/
public static void imageYUV420CenterCropToFloatBuffer(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset,
final MemoryFormat memoryFormat) {
checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
if (image.getFormat() != ImageFormat.YUV_420_888) {
throw new IllegalArgumentException(
String.format(
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
}
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
checkRotateCWDegrees(rotateCWDegrees);
checkTensorSize(tensorWidth, tensorHeight);
Image.Plane[] planes = image.getPlanes();
Image.Plane Y = planes[0];
Image.Plane U = planes[1];
Image.Plane V = planes[2];
int memoryFormatJniCode = 0;
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
memoryFormatJniCode = 1;
} else if (MemoryFormat.CHANNELS_LAST == memoryFormat) {
memoryFormatJniCode = 2;
}
NativePeer.imageYUV420CenterCropToFloatBuffer(
Y.getBuffer(),
Y.getRowStride(),
Y.getPixelStride(),
U.getBuffer(),
V.getBuffer(),
U.getRowStride(),
U.getPixelStride(),
image.getWidth(),
image.getHeight(),
rotateCWDegrees,
tensorWidth,
tensorHeight,
normMeanRGB,
normStdRGB,
outBuffer,
outBufferOffset,
memoryFormatJniCode);
}
public static void imageYUV420CenterCropToFloatBuffer(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
imageYUV420CenterCropToFloatBuffer(
image,
rotateCWDegrees,
tensorWidth,
tensorHeight,
normMeanRGB,
normStdRGB,
outBuffer,
outBufferOffset,
MemoryFormat.CONTIGUOUS);
}
private static class NativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_vision_jni");
}
private static native void imageYUV420CenterCropToFloatBuffer(
ByteBuffer yBuffer,
int yRowStride,
int yPixelStride,
ByteBuffer uBuffer,
ByteBuffer vBuffer,
int uvRowStride,
int uvPixelStride,
int imageWidth,
int imageHeight,
int rotateCWDegrees,
int tensorWidth,
int tensorHeight,
float[] normMeanRgb,
float[] normStdRgb,
Buffer outBuffer,
int outBufferOffset,
int memoryFormatJniCode);
}
private static void checkOutBufferCapacity(
FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
throw new IllegalStateException("Buffer underflow");
}
}
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
if (tensorHeight <= 0 || tensorWidth <= 0) {
throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
}
}
private static void checkRotateCWDegrees(int rotateCWDegrees) {
if (rotateCWDegrees != 0
&& rotateCWDegrees != 90
&& rotateCWDegrees != 180
&& rotateCWDegrees != 270) {
throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
}
}
private static void checkNormStdArg(float[] normStdRGB) {
if (normStdRGB.length != 3) {
throw new IllegalArgumentException("normStdRGB length must be 3");
}
}
private static void checkNormMeanArg(float[] normMeanRGB) {
if (normMeanRGB.length != 3) {
throw new IllegalArgumentException("normMeanRGB length must be 3");
}
}
}
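
A hedged end-to-end sketch of the bitmap path (mirroring the instrumented smoke test earlier in this diff); the module and bitmap come from the caller:

```java
import android.graphics.Bitmap;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class ImageInputSketch {
  static float[] classify(Module module, Bitmap bitmap) {
    // NCHW float tensor of shape {1, 3, height, width}, normalized with the
    // standard torchvision ImageNet mean/std.
    Tensor input =
        TensorImageUtils.bitmapToFloat32Tensor(
            bitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB);
    return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
  }
}
```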


@ -0,0 +1,3 @@
<resources>
<string name="app_name">pytorch_android_torchvision</string>
</resources>

android/run_tests.sh Executable file

@ -0,0 +1,57 @@
#!/bin/bash
set -eux
PYTORCH_DIR="$(cd "$(dirname "$0")"/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
source "$PYTORCH_ANDROID_DIR/common.sh"
check_android_sdk
check_gradle
# Run android instrumented tests on x86 emulator
ADB_PATH=$ANDROID_HOME/platform-tools/adb
echo "Expecting running emulator"
$ADB_PATH devices
DEVICES_COUNT=$($ADB_PATH devices | awk 'NF' | wc -l)
echo "DEVICES_COUNT:$DEVICES_COUNT"
if [ "$DEVICES_COUNT" -eq 1 ]; then
echo "Unable to found connected android emulators"
cat <<- EOF
To start android emulator:
1. Install android sdkmanager packages
$ANDROID_HOME/tools/bin/sdkmanager "system-images;android-25;google_apis;x86"
to specify proxy add params: --proxy=http --proxy_host=fwdproxy --proxy_port=8080
2. Create android virtual device
$ANDROID_HOME/tools/bin/avdmanager create avd --name "x86_android25" --package "system-images;android-25;google_apis;x86"
3. Start emulator in headless mode without audio
$ANDROID_HOME/tools/emulator -avd x86_android25 -no-audio -no-window
4. Check that emulator is running
$ANDROID_HOME/platform-tools/adb devices
If everything is ok the output will be:
List of devices attached
emulator-5554 device
EOF
exit 1
fi
echo "Waiting for emulator boot completed"
$ADB_PATH wait-for-device shell 'while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done;'
{
# The test currently takes about 10 minutes
retry $GRADLE_PATH -PABI_FILTERS=x86 -p $PYTORCH_ANDROID_DIR connectedAndroidTest
} || {
echo "::error::Check https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test to see how to fix the failed mobile test"
exit 1
}

android/settings.gradle Normal file

@ -0,0 +1,6 @@
include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_host', ':test_app'
project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')
project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

android/test_app/.gitignore vendored Normal file

@ -0,0 +1,9 @@
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.DS_Store
build
.externalNativeBuild


@ -0,0 +1,38 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME pytorch_testapp_jni)
project(${PROJECT_NAME} CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
)
add_library(${PROJECT_NAME} SHARED
${pytorch_testapp_SOURCES}
)
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
target_compile_options(${PROJECT_NAME} PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_include_directories(${PROJECT_NAME} PRIVATE
${PYTORCH_INCLUDE_DIRS}
)
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}
log)


@ -0,0 +1,190 @@
apply plugin: 'com.android.application'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileOptions {
sourceCompatibility 1.8
targetCompatibility 1.8
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
applicationId "org.pytorch.testapp"
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"
ndk {
abiFilters ABI_FILTERS.split(",")
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// abiFilters ABI_FILTERS.split(",")
// arguments "-DANDROID_STL=c++_shared"
// }
//}
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
buildConfigField(
"int",
"BUILD_LITE_INTERPRETER",
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
)
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// path "CMakeLists.txt"
// }
//}
flavorDimensions "model", "build", "activity"
productFlavors {
mnet {
dimension "model"
applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
// NB: This is not working at the moment, see https://github.com/pytorch/pytorch/issues/102966
mnetVulkan {
dimension "model"
applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
}
local {
dimension "build"
}
nightly {
dimension "build"
}
aar {
dimension "build"
}
// Commented out because it depends on a local copy of the pytorch_android AAR in the aars folder
//nativeBuild {
// dimension "build"
// buildConfigField("boolean", "NATIVE_BUILD", "true")
//}
camera {
dimension "activity"
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
}
base {
dimension "activity"
sourceSets {
main {
java {
exclude 'org/pytorch/testapp/CameraActivity.java'
}
}
}
}
}
packagingOptions {
doNotStrip '**.so'
}
// Filtering for CI
if (!testAppAllVariantsEnabled.toBoolean()) {
variantFilter { variant ->
def names = variant.flavors*.name
if (names.contains("nightly")
|| names.contains("camera")
|| names.contains("aar")
|| names.contains("nativeBuild")) {
setIgnore(true)
}
}
}
}
tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
}
dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
localImplementation project(':pytorch_android')
localImplementation project(':pytorch_android_torchvision')
// Commented out because it depends on a local copy of the pytorch_android AAR in the aars folder
//nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
def camerax_version = "1.0.0-alpha05"
cameraImplementation "androidx.camera:camera-core:$camerax_version"
cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
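
The three flavor dimensions (model × build × activity) multiply into variant names such as mnetLocalBase. A sketch of a typical local install, assuming a 64-bit ARM device is attached and the local pytorch_android modules have been built for that ABI:

```bash
# Build and install: MobileNet model, locally built pytorch_android,
# plain (no-camera) activity, debug build type.
gradle -p android :test_app:installMnetLocalBaseDebug -PABI_FILTERS=arm64-v8a
```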


@@ -0,0 +1,27 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.testapp">
<application
android:allowBackup="true"
android:label="${APP_NAME}"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name="${MAIN_ACTIVITY}">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.CAMERA" />
<!--
Permissions required by the Snapdragon Profiler to collect GPU metrics.
-->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
</manifest>
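
On API 23+ the CAMERA permission declared here still has to be granted at runtime before the camera flavor's CameraActivity can capture frames; for quick testing it can be granted over adb. The application id below assumes the mnet flavor (the base id plus its applicationIdSuffix from the build.gradle above):

```bash
# Grant the runtime camera permission to the installed mnet test app.
adb shell pm grant org.pytorch.testapp.mnet android.permission.CAMERA
```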


@@ -0,0 +1,3 @@
*
*/
!.gitignore

Some files were not shown because too many files have changed in this diff.