diff --git a/.gitmodules b/.gitmodules index 6aa37c656728..bf1fc38c6797 100644 --- a/.gitmodules +++ b/.gitmodules @@ -70,6 +70,10 @@ ignore = dirty path = third_party/fbgemm url = https://github.com/pytorch/fbgemm +[submodule "android/libs/fbjni"] + ignore = dirty + path = android/libs/fbjni + url = https://github.com/facebookincubator/fbjni.git [submodule "third_party/XNNPACK"] ignore = dirty path = third_party/XNNPACK diff --git a/android/README.md b/android/README.md index af89ce8b1ada..d6a1ba1d4479 100644 --- a/android/README.md +++ b/android/README.md @@ -1,5 +1,240 @@ -# Android Demo App +# Android -Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). +## Demo applications and tutorials -Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions. +Demo applications with code walk-through can be found in [this github repo](https://github.com/pytorch/android-demo-app). + +## Publishing + +##### Release +Release artifacts are published to jcenter: + +```groovy +repositories { + jcenter() +} + +# lite interpreter build +dependencies { + implementation 'org.pytorch:pytorch_android_lite:1.10.0' + implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0' +} + +# full jit build +dependencies { + implementation 'org.pytorch:pytorch_android:1.10.0' + implementation 'org.pytorch:pytorch_android_torchvision:1.10.0' +} +``` + +##### Nightly + +Nightly(snapshots) builds are published every night from `master` branch to [nexus sonatype snapshots repository](https://oss.sonatype.org/#nexus-search;quick~pytorch_android) + +To use them, the repository must be specified explicitly: +```groovy +repositories { + maven { + url "https://oss.sonatype.org/content/repositories/snapshots" + } +} + +# lite interpreter build +dependencies { + ...
+ implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT' + implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT' + ... +} + +# full jit build +dependencies { + ... + implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT' + implementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT' + ... +} +``` +The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `2.2.0-SNAPSHOT`. + +## Building PyTorch Android from Source + +In some cases you might want to use a local build of pytorch android, for example you may build custom libtorch binary with another set of operators or to make local changes. + +For this you can use `./scripts/build_pytorch_android.sh` script. +```bash +git clone https://github.com/pytorch/pytorch.git +cd pytorch +git submodule update --init --recursive +bash ./scripts/build_pytorch_android.sh +``` + +The workflow contains several steps: + +1\. Build libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64) + +2\. Create symbolic links to the results of those builds: +`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with output libraries +`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build `libpytorch.so` library that will be loaded on android device. + +3\. And finally run `gradle` in `android/pytorch_android` directory with task `assembleRelease` + +Script requires that Android SDK, Android NDK and gradle are installed. +They are specified as environment variables: + +`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html) + +`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
+ +`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/) + + +After successful build you should see the result as aar file: + +```bash +$ find pytorch_android/build/ -type f -name '*aar' +pytorch_android/build/outputs/aar/pytorch_android.aar +pytorch_android_torchvision/build/outputs/aar/pytorch_android_torchvision.aar +``` + +It can be used directly in android projects, as a gradle dependency: +```groovy +allprojects { + repositories { + flatDir { + dirs 'libs' + } + } +} + +dependencies { + implementation(name:'pytorch_android', ext:'aar') + implementation(name:'pytorch_android_torchvision', ext:'aar') + ... + implementation 'com.facebook.soloader:nativeloader:0.10.5' + implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' +} +``` +We also have to add all transitive dependencies of our aars. +As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them. +(In case of using maven dependencies they are added automatically from `pom.xml`). + +You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly. + +## Linking to prebuilt libtorch library from gradle dependency + +In some cases, you may want to use libtorch from your android native build. +You can do it without building libtorch android, using native libraries from PyTorch android gradle dependency. +For that, you will need to add the next lines to your gradle build. +```groovy +android { +... + configurations { + extractForNativeBuild + } +... + compileOptions { + externalNativeBuild { + cmake { + arguments "-DANDROID_STL=c++_shared" + } + } + } +...
+ externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } +} + +dependencies { + extractForNativeBuild('org.pytorch:pytorch_android:1.10.0') +} + +task extractAARForNativeBuild { + doLast { + configurations.extractForNativeBuild.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "headers/**" + include "jni/**" + } + } + } +} + +tasks.whenTaskAdded { task -> + if (task.name.contains('externalNativeBuild')) { + task.dependsOn(extractAARForNativeBuild) + } +} +``` + +pytorch_android aar contains headers to link in `headers` folder and native libraries in `jni/$ANDROID_ABI/`. +As PyTorch native libraries use `ANDROID_STL` - we should use `ANDROID_STL=c++_shared` to have only one loaded binary of STL. + +The added task will unpack them to gradle build directory. + +In your native build you can link to them adding these lines to your CMakeLists.txt: + + +```cmake +# Relative path of gradle build directory to CMakeLists.txt +set(build_DIR ${CMAKE_SOURCE_DIR}/build) + +file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") +file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") + +set(BUILD_SUBDIR ${ANDROID_ABI}) +target_include_directories(${PROJECT_NAME} PRIVATE + ${PYTORCH_INCLUDE_DIRS} +) + +find_library(PYTORCH_LIBRARY pytorch_jni + PATHS ${PYTORCH_LINK_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +find_library(FBJNI_LIBRARY fbjni + PATHS ${PYTORCH_LINK_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +target_link_libraries(${PROJECT_NAME} + ${PYTORCH_LIBRARY} + ${FBJNI_LIBRARY}) + +``` +If your CMakeLists.txt file is located in the same directory as your build.gradle, `set(build_DIR ${CMAKE_SOURCE_DIR}/build)` should work for you. But if you have another location of it, you may need to change it. + +After that, you can use libtorch C++ API from your native code.
+```cpp +#include <string> +#include <torch/csrc/jit/runtime/graph_executor.h> +#include <torch/script.h> +namespace pytorch_testapp_jni { +namespace { + struct JITCallGuard { + c10::InferenceMode guard; + torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false}; + }; +} + +void loadAndForwardModel(const std::string& modelPath) { + JITCallGuard guard; + torch::jit::Module module = torch::jit::load(modelPath); + module.eval(); + torch::Tensor t = torch::randn({1, 3, 224, 224}); + c10::IValue t_out = module.forward({t}); +} +} +``` + +To load a TorchScript model for mobile we need some special setup, which is placed in `struct JITCallGuard` in this example. It may change in future; you can track the latest changes keeping an eye on our [pytorch android jni code](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28) + +[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app) + +## PyTorch Android API Javadoc + +You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
diff --git a/android/build.gradle b/android/build.gradle new file mode 100644 index 000000000000..d58faaff95fd --- /dev/null +++ b/android/build.gradle @@ -0,0 +1,40 @@ +allprojects { + buildscript { + ext { + minSdkVersion = 21 + targetSdkVersion = 28 + compileSdkVersion = 28 + buildToolsVersion = '28.0.3' + + coreVersion = "1.2.0" + extJUnitVersion = "1.1.1" + runnerVersion = "1.2.0" + rulesVersion = "1.2.0" + junitVersion = "4.12" + + fbjniJavaOnlyVersion = "0.2.2" + soLoaderNativeLoaderVersion = "0.10.5" + } + + repositories { + google() + mavenLocal() + mavenCentral() + jcenter() + } + + dependencies { + classpath 'com.android.tools.build:gradle:4.1.2' + classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2' + } + } + + repositories { + google() + jcenter() + } +} + +ext.deps = [ + jsr305: 'com.google.code.findbugs:jsr305:3.0.1', +] diff --git a/android/build_test_app.sh b/android/build_test_app.sh new file mode 100755 index 000000000000..1f552fa44869 --- /dev/null +++ b/android/build_test_app.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -eux + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" +PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android + +echo "PYTORCH_DIR:$PYTORCH_DIR" + +source "$PYTORCH_ANDROID_DIR/common.sh" + +check_android_sdk +check_gradle +parse_abis_list "$@" +build_android + +# To set proxy for gradle add following lines to ./gradle/gradle.properties: +# systemProp.http.proxyHost=... +# systemProp.http.proxyPort=8080 +# systemProp.https.proxyHost=... 
+# systemProp.https.proxyPort=8080 + +if [ "$CUSTOM_ABIS_LIST" = true ]; then + NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug +else + NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug +fi + +find $PYTORCH_ANDROID_DIR -type f -name *apk + +find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r " diff --git a/android/build_test_app_custom.sh b/android/build_test_app_custom.sh new file mode 100755 index 000000000000..8577a344fa4c --- /dev/null +++ b/android/build_test_app_custom.sh @@ -0,0 +1,32 @@ +#!/bin/bash +############################################################################### +# This script tests the custom selective build flow for PyTorch Android, which +# optimizes library size by only including ops used by a specific model. +############################################################################### + +set -eux + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" +PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android" +BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom" + +source "${PYTORCH_ANDROID_DIR}/common.sh" + +prepare_model_and_dump_root_ops() { + cd "${BUILD_ROOT}" + MODEL="${BUILD_ROOT}/MobileNetV2.pt" + ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml" + python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py" + cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt" +} + +# Start building +mkdir -p "${BUILD_ROOT}" +check_android_sdk +check_gradle +parse_abis_list "$@" +prepare_model_and_dump_root_ops +SELECTED_OP_LIST="${ROOT_OPS}" build_android + +# TODO: change this to build test_app instead +$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease diff --git a/android/common.sh b/android/common.sh new file mode 100644 index 000000000000..c7c508e2c90a --- /dev/null +++ b/android/common.sh 
@@ -0,0 +1,74 @@ +#!/bin/bash +set -eux + +############################################################################## +# Common util functions for Android build scripts. +############################################################################## + +if [ -z "$PYTORCH_DIR" ]; then + echo "PYTORCH_DIR not set!" + exit 1 +fi + +retry () { + "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@") +} + +check_android_sdk() { + if [ -z "$ANDROID_HOME" ]; then + echo "ANDROID_HOME not set; please set it to Android sdk directory" + exit 1 + fi + + if [ ! -d "$ANDROID_HOME" ]; then + echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?" + exit 1 + fi + echo "ANDROID_HOME:$ANDROID_HOME" +} + +check_gradle() { + GRADLE_PATH=$PYTORCH_DIR/android/gradlew + echo "GRADLE_PATH:$GRADLE_PATH" +} + +parse_abis_list() { + # sync with https://github.com/pytorch/pytorch/blob/0ca0e02685a9d033ac4f04e2fa5c8ba6dbc5ae50/android/gradle.properties#L1 + ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64" + CUSTOM_ABIS_LIST=false + if [ $# -gt 0 ]; then + ABIS_LIST=$1 + CUSTOM_ABIS_LIST=true + fi + + echo "ABIS_LIST:$ABIS_LIST" + echo "CUSTOM_ABIS_LIST:$CUSTOM_ABIS_LIST" +} + +build_android() { + PYTORCH_ANDROID_DIR="$PYTORCH_DIR/android" + BUILD_ROOT="${BUILD_ROOT:-$PYTORCH_DIR}" + echo "BUILD_ROOT:$BUILD_ROOT" + + LIB_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs" + INCLUDE_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include" + + # These directories only contain symbolic links. 
+ rm -rf "$LIB_DIR" && mkdir -p "$LIB_DIR" + rm -rf "$INCLUDE_DIR" && mkdir -p "$INCLUDE_DIR" + + for abi in $(echo "$ABIS_LIST" | tr ',' '\n') + do + echo "abi:$abi" + ANDROID_BUILD_ROOT="$BUILD_ROOT/build_android_$abi" + ANDROID_ABI="$abi" \ + BUILD_ROOT="$ANDROID_BUILD_ROOT" \ + "$PYTORCH_DIR/scripts/build_android.sh" \ + -DANDROID_CCACHE="$(which ccache)" \ + -DUSE_LITE_INTERPRETER_PROFILER="OFF" + + echo "$abi build output lib,include at $ANDROID_BUILD_ROOT/install" + ln -s "$ANDROID_BUILD_ROOT/install/lib" "$LIB_DIR/$abi" + ln -s "$ANDROID_BUILD_ROOT/install/include" "$INCLUDE_DIR/$abi" + done +} diff --git a/android/gradle.properties b/android/gradle.properties new file mode 100644 index 000000000000..f7f9bdba010c --- /dev/null +++ b/android/gradle.properties @@ -0,0 +1,26 @@ +ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 + +VERSION_NAME=2.2.0-SNAPSHOT +GROUP=org.pytorch +MAVEN_GROUP=org.pytorch +SONATYPE_STAGING_PROFILE=orgpytorch +POM_URL=https://github.com/pytorch/pytorch/tree/master/android +POM_SCM_URL=https://github.com/pytorch/pytorch.git +POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/pytorch +POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/pytorch.git +POM_LICENSE_NAME=BSD 3-Clause +POM_LICENSE_URL=https://github.com/pytorch/pytorch/blob/master/LICENSE +POM_ISSUES_URL=https://github.com/pytorch/pytorch/issues +POM_LICENSE_DIST=repo +POM_DEVELOPER_ID=pytorch +POM_DEVELOPER_NAME=pytorch + +# Gradle internals +org.gradle.internal.repository.max.retries=1 +org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m + +android.useAndroidX=true +android.enableJetifier=true + +nativeLibsDoNotStrip=false +testAppAllVariantsEnabled=false diff --git a/android/gradle/android_tasks.gradle b/android/gradle/android_tasks.gradle new file mode 100644 index 000000000000..6bba126b2f66 --- /dev/null +++ b/android/gradle/android_tasks.gradle @@ -0,0 +1,11 @@ +afterEvaluate { project -> + if (POM_PACKAGING == 'aar') { + task headersJar(type: Jar) { + 
archiveClassifier.set('headers') + from("$rootDir/cxx/") { + include '**/*.h' + } + } + artifacts.add('archives', headersJar) + } +} diff --git a/android/gradle/release.gradle b/android/gradle/release.gradle new file mode 100644 index 000000000000..7a6948349364 --- /dev/null +++ b/android/gradle/release.gradle @@ -0,0 +1,3 @@ +apply from: rootProject.file('gradle/android_tasks.gradle') + +apply plugin: 'com.vanniktech.maven.publish' diff --git a/android/gradle/wrapper/gradle-wrapper.jar b/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000000..94336fcae912 Binary files /dev/null and b/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000000..442d9132ea32 --- /dev/null +++ b/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/android/gradlew b/android/gradlew new file mode 100755 index 000000000000..cccdd3d517fc --- /dev/null +++ b/android/gradlew @@ -0,0 +1,172 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 
+ esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/android/gradlew.bat b/android/gradlew.bat new file mode 100644 index 000000000000..f9553162f122 --- /dev/null +++ b/android/gradlew.bat @@ -0,0 +1,84 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. 
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/android/libs/fbjni b/android/libs/fbjni new file mode 160000 index 000000000000..7e1e1fe3858c --- /dev/null +++ b/android/libs/fbjni @@ -0,0 +1 @@ +Subproject commit 7e1e1fe3858c63c251c637ae41a20de425dde96f diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt new file mode 100644 index 000000000000..653256c256c2 --- /dev/null +++ b/android/pytorch_android/CMakeLists.txt @@ -0,0 +1,184 @@ +cmake_minimum_required(VERSION 3.5) +option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON) +message( + STATUS + "BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}") + +if(BUILD_LITE_INTERPRETER) + project(pytorch_jni_lite CXX) + set(PYTORCH_JNI_TARGET pytorch_jni_lite) +else() + project(pytorch_jni CXX) + set(PYTORCH_JNI_TARGET pytorch_jni) +endif() + +include(GNUInstallDirs) + +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_VERBOSE_MAKEFILE ON) +message(STATUS "ANDROID_STL:${ANDROID_STL}") + +set(TRACE_ENABLED OFF) +if(DEFINED ENV{TRACE_ENABLED}) + if($ENV{TRACE_ENABLED} STREQUAL "1") + message(STATUS "TRACE_ENABLED ON") + set(TRACE_ENABLED ON) + endif() +endif() +if(NOT TRACE_ENABLED) + message(STATUS "TRACE_ENABLED OFF") +endif() + +set(USE_VULKAN OFF) + +set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) + +if(ANDROID_ABI) + set(USE_VULKAN ON) + set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI}) + set(BUILD_SUBDIR ${ANDROID_ABI}) +elseif(BUILD_LIBTORCH_WITH_JNI) + # Don't need LIBTORCH_HOME if we're building from within PyTorch. +else() + # Building against a pre-built libtorch. 
+ if(NOT LIBTORCH_HOME) + message(FATAL_ERROR + "pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.") + endif() + set(libtorch_include_DIR ${LIBTORCH_HOME}/include) + link_directories(${LIBTORCH_HOME}/lib) + set(BUILD_SUBDIR host) +endif() + +message(STATUS "libtorch dir:${libtorch_DIR}") + +configure_file( + ${pytorch_android_DIR}/cmake_macros.h.in + ${pytorch_android_DIR}/cmake_macros.h) + + +if(BUILD_LITE_INTERPRETER) + file(GLOB pytorch_android_SOURCES + ${pytorch_android_DIR}/pytorch_jni_lite.cpp + ${pytorch_android_DIR}/pytorch_jni_common.cpp + ${pytorch_android_DIR}/pytorch_jni_common.h + ) +else() + file(GLOB pytorch_android_SOURCES + ${pytorch_android_DIR}/pytorch_jni_jit.cpp + ${pytorch_android_DIR}/pytorch_jni_common.cpp + ${pytorch_android_DIR}/pytorch_jni_common.h + ) +endif() +add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES}) + +if(APPLE) + # Need to add rpath so dlopen can find dependencies. + add_custom_command(TARGET pytorch_jni + POST_BUILD COMMAND + ${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path" + $) +endif() + +target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE + -fexceptions +) +target_include_directories(${PYTORCH_JNI_TARGET} BEFORE +PUBLIC $) + +set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/) +set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR}) + +add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR}) + +# ---[ Vulkan deps +if(USE_VULKAN) + set(Vulkan_LIBS) + set(Vulkan_INCLUDES) + include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake) +endif() + +if(ANDROID_ABI) + function(import_static_lib name) + add_library(${name} STATIC IMPORTED) + set_property( + TARGET ${name} + PROPERTY IMPORTED_LOCATION + ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a) + endfunction(import_static_lib) + + import_static_lib(libtorch) + import_static_lib(libtorch_cpu) + import_static_lib(libc10) + import_static_lib(libnnpack) + import_static_lib(libXNNPACK) + 
import_static_lib(libmicrokernels-prod) + import_static_lib(libpytorch_qnnpack) + import_static_lib(libpthreadpool) + import_static_lib(libeigen_blas) + import_static_lib(libcpuinfo) + import_static_lib(libclog) + + # Link most things statically on Android. + set(pytorch_jni_LIBS + fbjni + -Wl,--gc-sections + -Wl,--whole-archive + libtorch + libtorch_cpu + -Wl,--no-whole-archive + libc10 + libnnpack + libXNNPACK + libmicrokernels-prod + libpytorch_qnnpack + libpthreadpool + libeigen_blas + libcpuinfo + libclog + ) +else() + # Prefer dynamic linking on the host + set(pytorch_jni_LIBS + fbjni + torch + torch_cpu + c10 + cpuinfo + ) + + if(USE_NNPACK) + list(APPEND pytorch_jni_LIBS nnpack) + endif() + + if(USE_XNNPACK) + list(APPEND pytorch_jni_LIBS XNNPACK) + list(APPEND pytorch_jni_LIBS microkernels-prod) + endif() + + if(USE_SYSTEM_PTHREADPOOL) + list(APPEND pytorch_jni_LIBS pthreadpool) + endif() + + if(USE_PYTORCH_QNNPACK) + list(APPEND pytorch_jni_LIBS pytorch_qnnpack) + list(APPEND pytorch_jni_LIBS clog) + endif() + +endif() + +if(USE_VULKAN) + list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS}) +endif() + +target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS}) + +install(TARGETS ${PYTORCH_JNI_TARGET} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows + +if(MSVC) + install(FILES $ DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL) + install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif() diff --git a/android/pytorch_android/build.gradle b/android/pytorch_android/build.gradle new file mode 100644 index 000000000000..d10f6a305085 --- /dev/null +++ b/android/pytorch_android/build.gradle @@ -0,0 +1,163 @@ +apply plugin: 'com.android.library' +apply plugin: 'maven' + +android { + compileSdkVersion rootProject.compileSdkVersion + buildToolsVersion rootProject.buildToolsVersion + + defaultConfig { + minSdkVersion rootProject.minSdkVersion + 
targetSdkVersion rootProject.targetSdkVersion + versionCode 0 + versionName "0.1" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + ndk { + abiFilters ABI_FILTERS.split(",") + } + externalNativeBuild { + cmake { + if(System.env.BUILD_LITE_INTERPRETER == '0') { + arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF" + } else { + arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF" + } + } + } + } + buildTypes { + debug { + minifyEnabled false + debuggable true + } + release { + minifyEnabled false + } + } + sourceSets { + main { + java { + if(System.env.BUILD_LITE_INTERPRETER == '0') { + println 'Build pytorch_jni' + exclude 'org/pytorch/LiteModuleLoader.java' + exclude 'org/pytorch/LiteNativePeer.java' + } else { + println 'Build pytorch_jni_lite' + } + } + jniLibs.srcDirs = ['src/main/jniLibs'] + manifest.srcFile 'src/main/AndroidManifest.xml' + } + androidTest { + java { + if(System.env.BUILD_LITE_INTERPRETER == '0') { + println 'Build test for full jit (pytorch_jni)' + exclude 'org/pytorch/PytorchHostTests.java' + exclude 'org/pytorch/PytorchLiteInstrumentedTests.java' + exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java' + } else { + println 'Build test for lite interpreter (pytorch_jni_lite)' + exclude 'org/pytorch/PytorchHostTests.java' + exclude 'org/pytorch/PytorchInstrumentedTests.java' + exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java' + } + } + } + } + externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } + + packagingOptions { + if (nativeLibsDoNotStrip.toBoolean()) { + doNotStrip "**/*.so" + logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included') + } + } + + useLibrary 'android.test.runner' + useLibrary 'android.test.base' + useLibrary 'android.test.mock' +} + +dependencies { + implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion + implementation 
'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion + + testImplementation 'junit:junit:' + rootProject.junitVersion + testImplementation 'androidx.test:core:' + rootProject.coreVersion + + androidTestImplementation 'junit:junit:' + rootProject.junitVersion + androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion + androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion + androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion + androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion +} + +apply from: rootProject.file('gradle/release.gradle') + +task sourcesJar(type: Jar) { + from android.sourceSets.main.java.srcDirs + classifier = 'sources' +} + +def getLibtorchHeadersDir() { + def abi = ABI_FILTERS.split(",")[0] + return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi" +} + +afterEvaluate { + if (POM_PACKAGING == 'aar') { + android.libraryVariants.all { variant -> + variant.outputs.each { output -> + File f = output.outputFile + if (f.name.endsWith(".aar")) { + output.assemble.finalizedBy addFolderToAarTask( + "addHeadersToAar" + variant.name, + f.path, + getLibtorchHeadersDir(), + "headers") + } + } + } + } +} + +tasks.whenTaskAdded { task -> + if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) { + doLast { + addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers') + } + } +} + +def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) { + return tasks.register(taskName) { + doLast { + addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) + } + } +} + +def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) { + def tmpDir = file("${buildDir}/${taskName}") + tmpDir.mkdir() + def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}") + tmpDirFolder.mkdir() + copy { + from zipTree(aarPath) + into tmpDir + } + copy { + from fileTree(folderPath) + into 
tmpDirFolder + } + ant.zip(destfile: aarPath) { + fileset(dir: tmpDir.path) + } + delete tmpDir +} + +artifacts.add('archives', sourcesJar) diff --git a/android/pytorch_android/generate_test_asset.cpp b/android/pytorch_android/generate_test_asset.cpp new file mode 100644 index 000000000000..6105a0904196 --- /dev/null +++ b/android/pytorch_android/generate_test_asset.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +#include +#include +#include + +int main(int argc, char* argv[]) { + std::string input_file_path{argv[1]}; + std::string output_file_path{argv[2]}; + + std::ifstream ifs(input_file_path); + std::stringstream buffer; + buffer << ifs.rdbuf(); + torch::jit::Module m("TestModule"); + + m.define(buffer.str()); + m.save(output_file_path); +} diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py new file mode 100644 index 000000000000..55622da89268 --- /dev/null +++ b/android/pytorch_android/generate_test_torchscripts.py @@ -0,0 +1,151 @@ +from typing import Optional + +import torch +from torch import Tensor + + +OUTPUT_DIR = "src/androidTest/assets/" + + +def scriptAndSave(module, fileName): + print("-" * 80) + script_module = torch.jit.script(module) + print(script_module.graph) + outputFileName = OUTPUT_DIR + fileName + # note that the lite interpreter model can also be used in full JIT + script_module._save_for_lite_interpreter(outputFileName) + print("Saved to " + outputFileName) + print("=" * 80) + + +class Test(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + return None + + @torch.jit.script_method + def eqBool(self, input: bool) -> bool: + return input + + @torch.jit.script_method + def eqInt(self, input: int) -> int: + return input + + @torch.jit.script_method + def eqFloat(self, input: float) -> float: + return input + + @torch.jit.script_method + def eqStr(self, input: str) -> str: + return input + + @torch.jit.script_method + def eqTensor(self, 
input: Tensor) -> Tensor: + return input + + @torch.jit.script_method + def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]: + return input + + @torch.jit.script_method + def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]: + return input + + @torch.jit.script_method + def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]: + return input + + @torch.jit.script_method + def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]: + sum = 0 + for x in input: + sum += x + return (input, sum) + + @torch.jit.script_method + def listBoolConjunction(self, input: list[bool]) -> bool: + res = True + for x in input: + res = res and x + return res + + @torch.jit.script_method + def listBoolDisjunction(self, input: list[bool]) -> bool: + res = False + for x in input: + res = res or x + return res + + @torch.jit.script_method + def tupleIntSumReturnTuple( + self, input: tuple[int, int, int] + ) -> tuple[tuple[int, int, int], int]: + sum = 0 + for x in input: + sum += x + return (input, sum) + + @torch.jit.script_method + def optionalIntIsNone(self, input: Optional[int]) -> bool: + return input is None + + @torch.jit.script_method + def intEq0None(self, input: int) -> Optional[int]: + if input == 0: + return None + return input + + @torch.jit.script_method + def str3Concat(self, input: str) -> str: + return input + input + input + + @torch.jit.script_method + def newEmptyShapeWithItem(self, input): + return torch.tensor([int(input.item())])[0] + + @torch.jit.script_method + def testAliasWithOffset(self) -> list[Tensor]: + x = torch.tensor([100, 200]) + a = [x[0], x[1]] + return a + + @torch.jit.script_method + def testNonContiguous(self): + x = torch.tensor([100, 200, 300])[::2] + assert not x.is_contiguous() + assert x[0] == 100 + assert x[1] == 300 + return x + + @torch.jit.script_method + def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.nn.functional.conv2d(x, w) + 
if toChannelsLast: + r = r.contiguous(memory_format=torch.channels_last) + else: + r = r.contiguous() + return r + + @torch.jit.script_method + def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.nn.functional.conv3d(x, w) + if toChannelsLast: + r = r.contiguous(memory_format=torch.channels_last_3d) + else: + r = r.contiguous() + return r + + @torch.jit.script_method + def contiguous(self, x: Tensor) -> Tensor: + return x.contiguous() + + @torch.jit.script_method + def contiguousChannelsLast(self, x: Tensor) -> Tensor: + return x.contiguous(memory_format=torch.channels_last) + + @torch.jit.script_method + def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: + return x.contiguous(memory_format=torch.channels_last_3d) + + +scriptAndSave(Test(), "test.pt") diff --git a/android/pytorch_android/gradle.properties b/android/pytorch_android/gradle.properties new file mode 100644 index 000000000000..bfc0db61bf60 --- /dev/null +++ b/android/pytorch_android/gradle.properties @@ -0,0 +1,4 @@ +POM_NAME=pytorch_android_lite pytorch android api +POM_DESCRIPTION=pytorch_android_lite pytorch android api +POM_ARTIFACT_ID=pytorch_android_lite +POM_PACKAGING=aar diff --git a/android/pytorch_android/host/build.gradle b/android/pytorch_android/host/build.gradle new file mode 100644 index 000000000000..0cd076bb94fa --- /dev/null +++ b/android/pytorch_android/host/build.gradle @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the Apache-2 license found in the +// LICENSE file in the root directory of this source tree. 
+ +plugins { + id 'java-library' +} + +repositories { + mavenLocal() + jcenter() +} + +sourceSets { + main { + java { + srcDir '../src/main/java' + exclude 'org/pytorch/PyTorchAndroid.java' + exclude 'org/pytorch/LitePyTorchAndroid.java' + exclude 'org/pytorch/LiteModuleLoader.java' + exclude 'org/pytorch/LiteNativePeer.java' + } + } + test { + java { + srcDir '../src/androidTest/java' + exclude '**/PytorchInstrumented*' + exclude '**/PytorchLiteInstrumented*' + } + resources.srcDirs = ["../src/androidTest/assets"] + } +} + +dependencies { + compileOnly 'com.google.code.findbugs:jsr305:3.0.1' + implementation 'com.facebook.soloader:nativeloader:0.10.1' + implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' + testImplementation 'junit:junit:4.12' +} + +apply from: rootProject.file('gradle/release.gradle') diff --git a/android/pytorch_android/host/gradle.properties b/android/pytorch_android/host/gradle.properties new file mode 100644 index 000000000000..19fcffcdace5 --- /dev/null +++ b/android/pytorch_android/host/gradle.properties @@ -0,0 +1,4 @@ +POM_NAME=pytorch_java_only pytorch java api +POM_DESCRIPTION=pytorch_java_only pytorch java api +POM_ARTIFACT_ID=pytorch_java_only +POM_PACKAGING=jar diff --git a/android/pytorch_android/src/androidTest/assets/activation_ops.ptl b/android/pytorch_android/src/androidTest/assets/activation_ops.ptl new file mode 100644 index 000000000000..179f426ae7cd Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/activation_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/android_api_module.ptl b/android/pytorch_android/src/androidTest/assets/android_api_module.ptl new file mode 100644 index 000000000000..9adfb84bf855 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/android_api_module.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/blas_lapack_ops.ptl b/android/pytorch_android/src/androidTest/assets/blas_lapack_ops.ptl new file mode 
100644 index 000000000000..fea933ee644f Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/blas_lapack_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/comparison_ops.ptl b/android/pytorch_android/src/androidTest/assets/comparison_ops.ptl new file mode 100644 index 000000000000..01b1c153e751 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/comparison_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/convolution_ops.ptl b/android/pytorch_android/src/androidTest/assets/convolution_ops.ptl new file mode 100644 index 000000000000..db253a207a33 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/convolution_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/distance_function_ops.ptl b/android/pytorch_android/src/androidTest/assets/distance_function_ops.ptl new file mode 100644 index 000000000000..cc4d994f440a Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/distance_function_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/dropout_ops.ptl b/android/pytorch_android/src/androidTest/assets/dropout_ops.ptl new file mode 100644 index 000000000000..422c2f60e6be Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/dropout_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/dynamic_quant_ops.ptl b/android/pytorch_android/src/androidTest/assets/dynamic_quant_ops.ptl new file mode 100644 index 000000000000..0bbbce9671c3 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/dynamic_quant_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/fused_quant_ops.ptl b/android/pytorch_android/src/androidTest/assets/fused_quant_ops.ptl new file mode 100644 index 000000000000..9d2b3f9dde1a Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/fused_quant_ops.ptl differ diff --git 
a/android/pytorch_android/src/androidTest/assets/general_quant_ops.ptl b/android/pytorch_android/src/androidTest/assets/general_quant_ops.ptl new file mode 100644 index 000000000000..7d4888e0bc81 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/general_quant_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/linear_ops.ptl b/android/pytorch_android/src/androidTest/assets/linear_ops.ptl new file mode 100644 index 000000000000..ca9066c03dc4 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/linear_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/loss_function_ops.ptl b/android/pytorch_android/src/androidTest/assets/loss_function_ops.ptl new file mode 100644 index 000000000000..4c0592e5485a Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/loss_function_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl b/android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl new file mode 100644 index 000000000000..9b8297a250d3 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl b/android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl new file mode 100644 index 000000000000..5d008eab03b9 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/normalization_ops.ptl b/android/pytorch_android/src/androidTest/assets/normalization_ops.ptl new file mode 100644 index 000000000000..d85bd06c763b Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/normalization_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/other_math_ops.ptl b/android/pytorch_android/src/androidTest/assets/other_math_ops.ptl new file mode 100644 index 
000000000000..7209c3b3bd1f Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/other_math_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/padding_ops.ptl b/android/pytorch_android/src/androidTest/assets/padding_ops.ptl new file mode 100644 index 000000000000..02e57ba20712 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/padding_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl b/android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl new file mode 100644 index 000000000000..948ed4832660 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/pooling_ops.ptl b/android/pytorch_android/src/androidTest/assets/pooling_ops.ptl new file mode 100644 index 000000000000..df051163413f Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/pooling_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl b/android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl new file mode 100644 index 000000000000..245ceb454d53 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/reduction_ops.ptl b/android/pytorch_android/src/androidTest/assets/reduction_ops.ptl new file mode 100644 index 000000000000..13771302c668 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/reduction_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/sampling_ops.ptl b/android/pytorch_android/src/androidTest/assets/sampling_ops.ptl new file mode 100644 index 000000000000..416be7cb1279 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/sampling_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl 
b/android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl new file mode 100644 index 000000000000..5e5520118764 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/sparse_ops.ptl b/android/pytorch_android/src/androidTest/assets/sparse_ops.ptl new file mode 100644 index 000000000000..a16f68f8f95f Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/sparse_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/spectral_ops.ptl b/android/pytorch_android/src/androidTest/assets/spectral_ops.ptl new file mode 100644 index 000000000000..9828dd2ba901 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/spectral_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/static_quant_ops.ptl b/android/pytorch_android/src/androidTest/assets/static_quant_ops.ptl new file mode 100644 index 000000000000..d0a0a254d1ef Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/static_quant_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/tensor_creation_ops.ptl b/android/pytorch_android/src/androidTest/assets/tensor_creation_ops.ptl new file mode 100644 index 000000000000..d897b43cd36c Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/tensor_creation_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/tensor_general_ops.ptl b/android/pytorch_android/src/androidTest/assets/tensor_general_ops.ptl new file mode 100644 index 000000000000..6f2855ea83ea Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/tensor_general_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/tensor_indexing_ops.ptl b/android/pytorch_android/src/androidTest/assets/tensor_indexing_ops.ptl new file mode 100644 index 000000000000..ac9cb8c4b94a Binary files /dev/null and 
b/android/pytorch_android/src/androidTest/assets/tensor_indexing_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/tensor_typing_ops.ptl b/android/pytorch_android/src/androidTest/assets/tensor_typing_ops.ptl new file mode 100644 index 000000000000..3e2f4d8cc689 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/tensor_typing_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/tensor_view_ops.ptl b/android/pytorch_android/src/androidTest/assets/tensor_view_ops.ptl new file mode 100644 index 000000000000..5e2dc8294842 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/tensor_view_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/test.pt b/android/pytorch_android/src/androidTest/assets/test.pt new file mode 100644 index 000000000000..016b6d666a2a Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/test.pt differ diff --git a/android/pytorch_android/src/androidTest/assets/torchscript_builtin_ops.ptl b/android/pytorch_android/src/androidTest/assets/torchscript_builtin_ops.ptl new file mode 100644 index 000000000000..2d2532df2fd2 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/torchscript_builtin_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/torchscript_collection_ops.ptl b/android/pytorch_android/src/androidTest/assets/torchscript_collection_ops.ptl new file mode 100644 index 000000000000..ce434b3b4210 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/torchscript_collection_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/assets/transformer_ops.ptl b/android/pytorch_android/src/androidTest/assets/transformer_ops.ptl new file mode 100644 index 000000000000..ebb2bd693604 Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/transformer_ops.ptl differ diff --git 
a/android/pytorch_android/src/androidTest/assets/vision_function_ops.ptl b/android/pytorch_android/src/androidTest/assets/vision_function_ops.ptl new file mode 100644 index 000000000000..c9c45655e2bc Binary files /dev/null and b/android/pytorch_android/src/androidTest/assets/vision_function_ops.ptl differ diff --git a/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp b/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp new file mode 100644 index 000000000000..eab3d8d1c8eb --- /dev/null +++ b/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include +#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h" + +using namespace ::testing; + +TEST(pytorch_jni_common_test, newJIValueFromAtIValue) { + auto dict = c10::impl::GenericDict( + c10::dynT(), c10::dynT()); + auto dictCallback = [](auto&&) { + return facebook::jni::local_ref{}; + }; + EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue( + dict, dictCallback, dictCallback)); +} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java new file mode 100644 index 000000000000..afdde74c5bde --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java @@ -0,0 +1,25 @@ +package org.pytorch; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Objects; + +public class PytorchHostTests extends PytorchTestBase { + + @Override + protected Module loadModel(String path) throws IOException { + return Module.load(assetFilePath(path)); + } + + 
private String assetFilePath(String assetName) throws IOException { + Path tempFile = Files.createTempFile("test", ".pt"); + try (InputStream resource = + Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) { + Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING); + } + return tempFile.toAbsolutePath().toString(); + } +} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java new file mode 100644 index 000000000000..20c30d1587c8 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java @@ -0,0 +1,42 @@ +package org.pytorch; + +import android.content.Context; +import androidx.test.InstrumentationRegistry; +import androidx.test.runner.AndroidJUnit4; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public class PytorchInstrumentedTests extends PytorchTestBase { + + @Override + protected Module loadModel(String path) throws IOException { + return Module.load(assetFilePath(path)); + } + + private String assetFilePath(String assetName) throws IOException { + final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + File file = new File(appContext.getFilesDir(), assetName); + if (file.exists() && file.length() > 0) { + return file.getAbsolutePath(); + } + + try (InputStream is = appContext.getAssets().open(assetName)) { + try (OutputStream os = new FileOutputStream(file)) { + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + return file.getAbsolutePath(); + } catch (IOException e) { + throw e; + } + } +} diff --git 
a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchLiteInstrumentedTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchLiteInstrumentedTests.java new file mode 100644 index 000000000000..bc62270a6fa8 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchLiteInstrumentedTests.java @@ -0,0 +1,42 @@ +package org.pytorch; + +import android.content.Context; +import androidx.test.InstrumentationRegistry; +import androidx.test.runner.AndroidJUnit4; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public class PytorchLiteInstrumentedTests extends PytorchTestBase { + + @Override + protected Module loadModel(String path) throws IOException { + return LiteModuleLoader.load(assetFilePath(path)); + } + + private String assetFilePath(String assetName) throws IOException { + final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + File file = new File(appContext.getFilesDir(), assetName); + if (file.exists() && file.length() > 0) { + return file.getAbsolutePath(); + } + + try (InputStream is = appContext.getAssets().open(assetName)) { + try (OutputStream os = new FileOutputStream(file)) { + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + return file.getAbsolutePath(); + } catch (IOException e) { + throw e; + } + } +} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java new file mode 100644 index 000000000000..7980a34c0434 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java @@ -0,0 +1,694 @@ +package org.pytorch; + +import static 
org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.Ignore; + +public abstract class PytorchTestBase { + private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl"; + + @Test + public void testForwardNull() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1})); + assertTrue(input.isTensor()); + final IValue output = module.forward(input); + assertTrue(output.isNull()); + } + + @Test + public void testEqBool() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + for (boolean value : new boolean[] {false, true}) { + final IValue input = IValue.from(value); + assertTrue(input.isBool()); + assertTrue(value == input.toBool()); + final IValue output = module.runMethod("eqBool", input); + assertTrue(output.isBool()); + assertTrue(value == output.toBool()); + } + } + + @Test + public void testEqInt() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) { + final IValue input = IValue.from(value); + assertTrue(input.isLong()); + assertTrue(value == input.toLong()); + final IValue output = module.runMethod("eqInt", input); + assertTrue(output.isLong()); + assertTrue(value == output.toLong()); + } + } + + @Test + public void testEqFloat() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + double[] values = + new double[] { + -Double.MAX_VALUE, + Double.MAX_VALUE, + -Double.MIN_VALUE, + Double.MIN_VALUE, + -Math.exp(1.d), + -Math.sqrt(2.d), + -3.1415f, + 3.1415f, + -1, + 0, + 1, + 
}; + for (double value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isDouble()); + assertTrue(value == input.toDouble()); + final IValue output = module.runMethod("eqFloat", input); + assertTrue(output.isDouble()); + assertTrue(value == output.toDouble()); + } + } + + @Test + public void testEqTensor() throws IOException { + final long[] inputTensorShape = new long[] {1, 3, 224, 224}; + final long numElements = Tensor.numel(inputTensorShape); + final float[] inputTensorData = new float[(int) numElements]; + for (int i = 0; i < numElements; ++i) { + inputTensorData[i] = i; + } + final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape); + + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue input = IValue.from(inputTensor); + assertTrue(input.isTensor()); + assertTrue(inputTensor == input.toTensor()); + final IValue output = module.runMethod("eqTensor", input); + assertTrue(output.isTensor()); + final Tensor outputTensor = output.toTensor(); + assertNotNull(outputTensor); + assertArrayEquals(inputTensorShape, outputTensor.shape()); + float[] outputData = outputTensor.getDataAsFloatArray(); + for (int i = 0; i < numElements; i++) { + assertTrue(inputTensorData[i] == outputData[i]); + } + } + + @Test + public void testEqDictIntKeyIntValue() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final Map inputMap = new HashMap<>(); + + inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE)); + inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE)); + inputMap.put(0l, IValue.from(0l)); + inputMap.put(1l, IValue.from(-1l)); + inputMap.put(-1l, IValue.from(1l)); + + final IValue input = IValue.dictLongKeyFrom(inputMap); + assertTrue(input.isDictLongKey()); + + final IValue output = module.runMethod("eqDictIntKeyIntValue", input); + assertTrue(output.isDictLongKey()); + + final Map outputMap = output.toDictLongKey(); + assertTrue(inputMap.size() == outputMap.size()); + 
for (Map.Entry entry : inputMap.entrySet()) { + assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); + } + } + + @Test + public void testEqDictStrKeyIntValue() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final Map inputMap = new HashMap<>(); + + inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE)); + inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE)); + inputMap.put("long_0", IValue.from(0l)); + inputMap.put("long_1", IValue.from(1l)); + inputMap.put("long_-1", IValue.from(-1l)); + + final IValue input = IValue.dictStringKeyFrom(inputMap); + assertTrue(input.isDictStringKey()); + + final IValue output = module.runMethod("eqDictStrKeyIntValue", input); + assertTrue(output.isDictStringKey()); + + final Map outputMap = output.toDictStringKey(); + assertTrue(inputMap.size() == outputMap.size()); + for (Map.Entry entry : inputMap.entrySet()) { + assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); + } + } + + @Test + public void testListIntSumReturnTuple() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + for (int n : new int[] {0, 1, 128}) { + long[] a = new long[n]; + long sum = 0; + for (int i = 0; i < n; i++) { + a[i] = i; + sum += a[i]; + } + final IValue input = IValue.listFrom(a); + assertTrue(input.isLongList()); + + final IValue output = module.runMethod("listIntSumReturnTuple", input); + + assertTrue(output.isTuple()); + assertTrue(2 == output.toTuple().length); + + IValue output0 = output.toTuple()[0]; + IValue output1 = output.toTuple()[1]; + + assertArrayEquals(a, output0.toLongList()); + assertTrue(sum == output1.toLong()); + } + } + + @Test + public void testOptionalIntIsNone() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool()); + assertTrue(module.runMethod("optionalIntIsNone", 
IValue.optionalNull()).toBool()); + } + + @Test + public void testIntEq0None() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull()); + assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l); + } + + @Test(expected = IllegalArgumentException.class) + public void testRunUndefinedMethod() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + module.runMethod("test_undefined_method_throws_exception"); + } + + @Test + public void testTensorMethods() { + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); + int[] ints = new int[numel]; + float[] floats = new float[numel]; + + byte[] bytes = new byte[numel]; + for (int i = 0; i < numel; i++) { + bytes[i] = (byte) ((i % 255) - 128); + ints[i] = i; + floats[i] = i / 1000.f; + } + + Tensor tensorBytes = Tensor.fromBlob(bytes, shape); + assertTrue(tensorBytes.dtype() == DType.INT8); + assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); + + Tensor tensorInts = Tensor.fromBlob(ints, shape); + assertTrue(tensorInts.dtype() == DType.INT32); + assertArrayEquals(ints, tensorInts.getDataAsIntArray()); + + Tensor tensorFloats = Tensor.fromBlob(floats, shape); + assertTrue(tensorFloats.dtype() == DType.FLOAT32); + float[] floatsOut = tensorFloats.getDataAsFloatArray(); + assertTrue(floatsOut.length == numel); + for (int i = 0; i < numel; i++) { + assertTrue(floats[i] == floatsOut[i]); + } + } + + @Test(expected = IllegalStateException.class) + public void testTensorIllegalStateOnWrongType() { + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); + float[] floats = new float[numel]; + Tensor tensorFloats = Tensor.fromBlob(floats, shape); + assertTrue(tensorFloats.dtype() == DType.FLOAT32); + tensorFloats.getDataAsByteArray(); + } + + @Test + public void testEqString() throws IOException { + final Module 
module = loadModel(TEST_MODULE_ASSET_NAME); + String[] values = + new String[] { + "smoketest", + "проверка не латинских символов", // not latin symbols check + "#@$!@#)($*!@#$)(!@*#$" + }; + for (String value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isString()); + assertTrue(value.equals(input.toStr())); + final IValue output = module.runMethod("eqStr", input); + assertTrue(output.isString()); + assertTrue(value.equals(output.toStr())); + } + } + + @Test + public void testStr3Concat() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + String[] values = + new String[] { + "smoketest", + "проверка не латинских символов", // not latin symbols check + "#@$!@#)($*!@#$)(!@*#$" + }; + for (String value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isString()); + assertTrue(value.equals(input.toStr())); + final IValue output = module.runMethod("str3Concat", input); + assertTrue(output.isString()); + String expectedOutput = + new StringBuilder().append(value).append(value).append(value).toString(); + assertTrue(expectedOutput.equals(output.toStr())); + } + } + + @Test + public void testEmptyShape() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final long someNumber = 43; + final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {})); + final IValue output = module.runMethod("newEmptyShapeWithItem", input); + assertTrue(output.isTensor()); + Tensor value = output.toTensor(); + assertArrayEquals(new long[] {}, value.shape()); + assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray()); + } + + @Test + public void testAliasWithOffset() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue output = module.runMethod("testAliasWithOffset"); + assertTrue(output.isTensorList()); + Tensor[] tensors = output.toTensorList(); + assertEquals(100, 
tensors[0].getDataAsLongArray()[0]); + assertEquals(200, tensors[1].getDataAsLongArray()[0]); + } + + @Test + public void testNonContiguous() throws IOException { + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue output = module.runMethod("testNonContiguous"); + assertTrue(output.isTensor()); + Tensor value = output.toTensor(); + assertArrayEquals(new long[] {2}, value.shape()); + assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray()); + } + + @Test + public void testChannelsLast() throws IOException { + long[] inputShape = new long[] {1, 3, 2, 2}; + long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104}; + Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST); + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC)); + assertIValueTensor( + outputNCHW, + MemoryFormat.CONTIGUOUS, + new long[] {1, 3, 2, 2}, + new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104}); + final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC)); + assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data); + } + + @Test + public void testChannelsLast3d() throws IOException { + long[] shape = new long[] {1, 2, 2, 2, 2}; + long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16}; + + Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D); + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC)); + assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD); + + Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS); + final IValue outputNHWDC = + module.runMethod("contiguousChannelsLast3d", 
IValue.from(inputNCHWD)); + assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC); + } + + @Test + public void testChannelsLastConv2d() throws IOException { + long[] inputShape = new long[] {1, 3, 2, 2}; + long[] dataNCHW = new long[] { + 111, 112, + 121, 122, + + 211, 212, + 221, 222, + + 311, 312, + 321, 322}; + Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS); + long[] dataNHWC = new long[] { + 111, 211, 311, 112, 212, 312, + + 121, 221, 321, 122, 222, 322}; + Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST); + long[] weightShape = new long[] {3, 3, 1, 1}; + long[] dataWeightOIHW = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1}; + Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS); + long[] dataWeightOHWI = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1}; + + Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST); + + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + final IValue outputNCHW = + module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false)); + assertIValueTensor( + outputNCHW, + MemoryFormat.CONTIGUOUS, + new long[] {1, 3, 2, 2}, + new long[] { + 2*111, 2*112, + 2*121, 2*122, + + 211, 212, + 221, 222, + + -311, -312, + -321, -322}); + + final IValue outputNHWC = + module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true)); + assertIValueTensor( + outputNHWC, + MemoryFormat.CHANNELS_LAST, + new long[] {1, 3, 2, 2}, + new long[] { + 2*111, 211, -311, 2*112, 212, -312, + 2*121, 221, -321, 2*122, 222, -322}); + } + + @Test + public void testChannelsLastConv3d() throws IOException { + long[] inputShape = new long[] {1, 3, 2, 2, 2}; + long[] dataNCDHW = new long[] { + 1111, 1112, + 1121, 1122, + 1211, 1212, + 1221, 1222, + + 2111, 2112, + 2121, 2122, + 2211, 2212, + 2221, 2222, + + 3111, 3112, + 3121, 3122, + 3211, 3212, + 
3221, 3222}; + Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS); + long[] dataNDHWC = new long[] { + 1111, 2111, 3111, + 1112, 2112, 3112, + + 1121, 2121, 3121, + 1122, 2122, 3122, + + 1211, 2211, 3211, + 1212, 2212, 3212, + + 1221, 2221, 3221, + 1222, 2222, 3222}; + + Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D); + + long[] weightShape = new long[] {3, 3, 1, 1, 1}; + long[] dataWeightOIDHW = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1, + }; + Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS); + long[] dataWeightODHWI = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1, + }; + Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D); + + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + final IValue outputNCDHW = + module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false)); + assertIValueTensor( + outputNCDHW, + MemoryFormat.CONTIGUOUS, + new long[] {1, 3, 2, 2, 2}, + new long[] { + 2*1111, 2*1112, 2*1121, 2*1122, + 2*1211, 2*1212, 2*1221, 2*1222, + + 2111, 2112, 2121, 2122, + 2211, 2212, 2221, 2222, + + -3111, -3112, -3121, -3122, + -3211, -3212, -3221, -3222}); + + final IValue outputNDHWC = + module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true)); + assertIValueTensor( + outputNDHWC, + MemoryFormat.CHANNELS_LAST_3D, + new long[] {1, 3, 2, 2, 2}, + new long[] { + 2*1111, 2111, -3111, 2*1112, 2112, -3112, + 2*1121, 2121, -3121, 2*1122, 2122, -3122, + + 2*1211, 2211, -3211, 2*1212, 2212, -3212, + 2*1221, 2221, -3221, 2*1222, 2222, -3222}); + } + + @Test + public void testMobileNetV2() throws IOException { + try { + final Module module = loadModel("mobilenet_v2.ptl"); + final IValue inputs = module.runMethod("get_all_bundled_inputs"); + assertTrue(inputs.isList()); + final IValue input = inputs.toList()[0]; + assertTrue(input.isTuple()); 
+ module.forward(input.toTuple()[0]); + assertTrue(true); + } catch (Exception ex) { + assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false); + } + } + + @Test + public void testPointwiseOps() throws IOException { + runModel("pointwise_ops"); + } + + @Test + public void testReductionOps() throws IOException { + runModel("reduction_ops"); + } + + @Test + public void testComparisonOps() throws IOException { + runModel("comparison_ops"); + } + + @Test + public void testOtherMathOps() throws IOException { + runModel("other_math_ops"); + } + + @Test + @Ignore + public void testSpectralOps() throws IOException { + // NB: This model fails without lite interpreter. The error is as follows: + // RuntimeError: stft requires the return_complex parameter be given for real inputs + runModel("spectral_ops"); + } + + @Test + public void testBlasLapackOps() throws IOException { + runModel("blas_lapack_ops"); + } + + @Test + public void testSamplingOps() throws IOException { + runModel("sampling_ops"); + } + + @Test + public void testTensorOps() throws IOException { + runModel("tensor_general_ops"); + } + + @Test + public void testTensorCreationOps() throws IOException { + runModel("tensor_creation_ops"); + } + + @Test + public void testTensorIndexingOps() throws IOException { + runModel("tensor_indexing_ops"); + } + + @Test + public void testTensorTypingOps() throws IOException { + runModel("tensor_typing_ops"); + } + + @Test + public void testTensorViewOps() throws IOException { + runModel("tensor_view_ops"); + } + + @Test + public void testConvolutionOps() throws IOException { + runModel("convolution_ops"); + } + + @Test + public void testPoolingOps() throws IOException { + runModel("pooling_ops"); + } + + @Test + public void testPaddingOps() throws IOException { + runModel("padding_ops"); + } + + @Test + public void testActivationOps() throws IOException { + runModel("activation_ops"); + } + + @Test + public void testNormalizationOps() throws IOException { + 
runModel("normalization_ops"); + } + + @Test + public void testRecurrentOps() throws IOException { + runModel("recurrent_ops"); + } + + @Test + public void testTransformerOps() throws IOException { + runModel("transformer_ops"); + } + + @Test + public void testLinearOps() throws IOException { + runModel("linear_ops"); + } + + @Test + public void testDropoutOps() throws IOException { + runModel("dropout_ops"); + } + + @Test + public void testSparseOps() throws IOException { + runModel("sparse_ops"); + } + + @Test + public void testDistanceFunctionOps() throws IOException { + runModel("distance_function_ops"); + } + + @Test + public void testLossFunctionOps() throws IOException { + runModel("loss_function_ops"); + } + + @Test + public void testVisionFunctionOps() throws IOException { + runModel("vision_function_ops"); + } + + @Test + public void testShuffleOps() throws IOException { + runModel("shuffle_ops"); + } + + @Test + public void testNNUtilsOps() throws IOException { + runModel("nn_utils_ops"); + } + + @Test + public void testQuantOps() throws IOException { + runModel("general_quant_ops"); + } + + @Test + public void testDynamicQuantOps() throws IOException { + runModel("dynamic_quant_ops"); + } + + @Test + public void testStaticQuantOps() throws IOException { + runModel("static_quant_ops"); + } + + @Test + public void testFusedQuantOps() throws IOException { + runModel("fused_quant_ops"); + } + + @Test + public void testTorchScriptBuiltinQuantOps() throws IOException { + runModel("torchscript_builtin_ops"); + } + + @Test + public void testTorchScriptCollectionQuantOps() throws IOException { + runModel("torchscript_collection_ops"); + } + + static void assertIValueTensor( + final IValue ivalue, + final MemoryFormat memoryFormat, + final long[] expectedShape, + final long[] expectedData) { + assertTrue(ivalue.isTensor()); + Tensor t = ivalue.toTensor(); + assertEquals(memoryFormat, t.memoryFormat()); + assertArrayEquals(expectedShape, t.shape()); + 
assertArrayEquals(expectedData, t.getDataAsLongArray()); + } + + void runModel(final String name) throws IOException { + final Module storage_module = loadModel(name + ".ptl"); + storage_module.forward(); + + // TODO enable this once the on-the-fly script is ready + // final Module on_the_fly_module = loadModel(name + "_temp.ptl"); + // on_the_fly_module.forward(); + assertTrue(true); + } + + protected abstract Module loadModel(String assetName) throws IOException; +} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchInstrumentedTestSuite.java b/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchInstrumentedTestSuite.java new file mode 100644 index 000000000000..8bbfa0d26937 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchInstrumentedTestSuite.java @@ -0,0 +1,9 @@ +package org.pytorch.suite; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.pytorch.PytorchInstrumentedTests; + +@RunWith(Suite.class) +@Suite.SuiteClasses({PytorchInstrumentedTests.class}) +public class PytorchInstrumentedTestSuite {} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java b/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java new file mode 100644 index 000000000000..a494ffc663ff --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java @@ -0,0 +1,9 @@ +package org.pytorch.suite; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.pytorch.PytorchLiteInstrumentedTests; + +@RunWith(Suite.class) +@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class}) +public class PytorchLiteInstrumentedTestSuite {} diff --git a/android/pytorch_android/src/main/AndroidManifest.xml b/android/pytorch_android/src/main/AndroidManifest.xml new file mode 100644 index 
000000000000..763c601fb700 --- /dev/null +++ b/android/pytorch_android/src/main/AndroidManifest.xml @@ -0,0 +1 @@ + diff --git a/android/pytorch_android/src/main/cpp/cmake_macros.h b/android/pytorch_android/src/main/cpp/cmake_macros.h new file mode 100644 index 000000000000..06d25a052b15 --- /dev/null +++ b/android/pytorch_android/src/main/cpp/cmake_macros.h @@ -0,0 +1,3 @@ +#pragma once + +/* #undef TRACE_ENABLED */ diff --git a/android/pytorch_android/src/main/cpp/cmake_macros.h.in b/android/pytorch_android/src/main/cpp/cmake_macros.h.in new file mode 100644 index 000000000000..70986864c1fb --- /dev/null +++ b/android/pytorch_android/src/main/cpp/cmake_macros.h.in @@ -0,0 +1,3 @@ +#pragma once + +#cmakedefine TRACE_ENABLED diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp new file mode 100644 index 000000000000..704d8b38805c --- /dev/null +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp @@ -0,0 +1,697 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "pytorch_jni_common.h" +#if defined(__ANDROID__) +#ifndef USE_PTHREADPOOL +#define USE_PTHREADPOOL +#endif /* USE_PTHREADPOOL */ +#include +#endif + +namespace pytorch_jni { + +c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) { + if (deviceJniCode == kDeviceCPU) { + return at::kCPU; + } else if (deviceJniCode == kDeviceVulkan) { + return at::kVulkan; + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, "Unknown device"); +} + +bool Trace::is_initialized_ = false; + +#if defined(TRACE_ENABLED) && defined(__ANDROID__) +Trace::fp_ATrace_beginSection Trace::ATrace_beginSection; +Trace::fp_ATrace_endSection Trace::ATrace_endSection; +#endif + +void Trace::init() { +#if defined(TRACE_ENABLED) && defined(__ANDROID__) + void* lib = dlopen("libandroid.so", RTLD_NOW || RTLD_LOCAL); + if (lib != NULL) { + 
Trace::ATrace_beginSection = reinterpret_cast( + dlsym(lib, "ATrace_beginSection")); + Trace::ATrace_endSection = + reinterpret_cast(dlsym(lib, "ATrace_endSection")); + } +#endif +} + +// NOTE: Codes must be kept in sync with DType.java. +// NOTE: Never serialize these, because they can change between releases. +constexpr static int kTensorDTypeUInt8 = 1; +constexpr static int kTensorDTypeInt8 = 2; +constexpr static int kTensorDTypeInt32 = 3; +constexpr static int kTensorDTypeFloat32 = 4; +constexpr static int kTensorDTypeInt64 = 5; +constexpr static int kTensorDTypeFloat64 = 6; + +constexpr static int kTensorMemoryFormatContiguous = 1; +constexpr static int kTensorMemoryFormatChannelsLast = 2; +constexpr static int kTensorMemoryFormatChannelsLast3d = 3; + +template +struct JHashMap + : facebook::jni::JavaClass, facebook::jni::JMap> { + constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;"; + + using Super = + facebook::jni::JavaClass, facebook::jni::JMap>; + + static facebook::jni::local_ref> create() { + return Super::newInstance(); + } + + void put( + facebook::jni::alias_ref key, + facebook::jni::alias_ref value) { + static auto putMethod = + Super::javaClassStatic() + ->template getMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref)>( + "put"); + putMethod(Super::self(), key, value); + } +}; + +static at::Tensor newAtTensor( + facebook::jni::alias_ref jbuffer, + facebook::jni::alias_ref jshape, + jint jdtype, + jint jmemoryFormat) { + const auto rank = jshape->size(); + const auto shapeArr = jshape->getRegion(0, rank); + std::vector shapeVec{}; + shapeVec.reserve(rank); + auto numel = 1; + for (const auto i : c10::irange(rank)) { + shapeVec.push_back(shapeArr[i]); + numel *= shapeArr[i]; + } + JNIEnv* jni = facebook::jni::Environment::current(); + caffe2::TypeMeta typeMeta{}; + int dataElementSizeBytes = 0; + if (kTensorDTypeFloat32 == jdtype) { + dataElementSizeBytes = 4; + typeMeta = caffe2::TypeMeta::Make(); + } else if 
(kTensorDTypeInt32 == jdtype) { + dataElementSizeBytes = 4; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeInt8 == jdtype) { + dataElementSizeBytes = 1; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeUInt8 == jdtype) { + dataElementSizeBytes = 1; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeFloat64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeInt64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); + } else { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unknown Tensor jdtype %d", + jdtype); + } + const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); + if (dataCapacity != numel) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Tensor dimensions(elements number:%d, element byte size:%d, total " + "bytes:%d) inconsistent with buffer capacity(%d)", + numel, + dataElementSizeBytes, + numel * dataElementSizeBytes, + dataCapacity); + } + + if (jmemoryFormat == kTensorMemoryFormatChannelsLast) { + auto sizes = torch::IntArrayRef(shapeVec); + return torch::from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), + sizes, + torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)), + at::TensorOptions(typeMeta).memory_format( + at::MemoryFormat::ChannelsLast)); + } else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) { + auto sizes = torch::IntArrayRef(shapeVec); + return torch::from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), + sizes, + torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)), + at::TensorOptions(typeMeta).memory_format( + at::MemoryFormat::ChannelsLast3d)); + } + return torch::from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), + torch::IntArrayRef(shapeVec), + at::TensorOptions(typeMeta)); +} + +class TensorHybrid : public facebook::jni::HybridClass { + public: + constexpr 
static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;"; + + explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {} + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref jTensorThis) { + static auto cls = TensorHybrid::javaClassStatic(); + static const auto jMethodDTypeCode = cls->getMethod("dtypeJniCode"); + static const auto jMethodMemoryFormatCode = + cls->getMethod("memoryFormatJniCode"); + static const auto jFieldShape = cls->getField("shape"); + static const auto jMethodGetDataBuffer = cls->getMethod< + facebook::jni::local_ref()>( + "getRawDataBuffer"); + + at::Tensor tensor = newAtTensor( + jMethodGetDataBuffer(jTensorThis), + jTensorThis->getFieldValue(jFieldShape), + jMethodDTypeCode(jTensorThis), + jMethodMemoryFormatCode(jTensorThis)); + return makeCxxInstance(std::move(tensor)); + } + + static facebook::jni::local_ref + newJTensorFromAtTensor(const at::Tensor& input_tensor) { + // Java wrapper currently only supports contiguous tensors. + + int jmemoryFormat = 0; + at::Tensor tensor{}; + if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) { + tensor = input_tensor; + jmemoryFormat = kTensorMemoryFormatChannelsLast; + } else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) { + tensor = input_tensor; + jmemoryFormat = kTensorMemoryFormatChannelsLast3d; + } else { + tensor = input_tensor.contiguous(); + jmemoryFormat = kTensorMemoryFormatContiguous; + } + + const auto scalarType = tensor.scalar_type(); + int jdtype = 0; + if (at::kFloat == scalarType) { + jdtype = kTensorDTypeFloat32; + } else if (at::kInt == scalarType) { + jdtype = kTensorDTypeInt32; + } else if (at::kByte == scalarType) { + jdtype = kTensorDTypeUInt8; + } else if (at::kChar == scalarType) { + jdtype = kTensorDTypeInt8; + } else if (at::kLong == scalarType) { + jdtype = kTensorDTypeInt64; + } else if (at::kDouble == scalarType) { + jdtype = kTensorDTypeFloat64; + } else { + facebook::jni::throwNewJavaException( + 
facebook::jni::gJavaLangIllegalArgumentException, + "at::Tensor scalar type %s is not supported on java side", + c10::toString(scalarType)); + } + + const auto& tensorShape = tensor.sizes(); + std::vector tensorShapeVec; + for (const auto& s : tensorShape) { + tensorShapeVec.push_back(s); + } + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensorShapeVec.size()); + jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data()); + + static auto cls = TensorHybrid::javaClassStatic(); + facebook::jni::local_ref jTensorBuffer = + facebook::jni::JByteBuffer::wrapBytes( + (uint8_t*)tensor.data_ptr(), tensor.nbytes()); + jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); + + static const auto jMethodNewTensor = + cls->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref, + jint, + jint, + facebook::jni::alias_ref)>("nativeNewTensor"); + return jMethodNewTensor( + cls, + jTensorBuffer, + jTensorShape, + jdtype, + jmemoryFormat, + makeCxxInstance(tensor)); + } + + static at::Tensor newAtTensorFromJTensor( + facebook::jni::alias_ref jtensor) { + static auto cls = TensorHybrid::javaClassStatic(); + static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); + jint jdtype = dtypeMethod(jtensor); + + static const auto memoryFormatMethod = + cls->getMethod("memoryFormatJniCode"); + jint jmemoryFormat = memoryFormatMethod(jtensor); + + static const auto shapeField = cls->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); + + static auto dataBufferMethod = cls->getMethod< + facebook::jni::local_ref()>( + "getRawDataBuffer"); + facebook::jni::local_ref jbuffer = + dataBufferMethod(jtensor); + return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat); + } + + at::Tensor tensor() const { + return tensor_; + } + + private: + friend HybridBase; + at::Tensor tensor_; +}; + +facebook::jni::local_ref JIValue::newJIValueFromStringDict( + c10::Dict dict) { + static auto jMethodDictStringKey = + 
JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref>>)>( + "dictStringKeyFrom"); + + auto jmap = JHashMap< + facebook::jni::alias_ref, + facebook::jni::alias_ref>::create(); + for (auto& pair : dict) { + jmap->put( + facebook::jni::make_jstring(pair.key().toStringRef()), + JIValue::newJIValueFromAtIValue(pair.value())); + } + return jMethodDictStringKey(JIValue::javaClassStatic(), jmap); +} + +facebook::jni::local_ref JIValue::newJIValueFromIntDict( + c10::Dict dict) { + static auto jMethodDictLongKey = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref>>)>( + "dictLongKeyFrom"); + auto jmap = JHashMap< + facebook::jni::alias_ref, + facebook::jni::alias_ref>::create(); + for (auto& pair : dict) { + jmap->put( + facebook::jni::JLong::valueOf(pair.key().toInt()), + JIValue::newJIValueFromAtIValue(pair.value())); + } + return jMethodDictLongKey(JIValue::javaClassStatic(), jmap); +} + +facebook::jni::local_ref JIValue::newJIValueFromAtIValue( + const at::IValue& ivalue, + DictCallback stringDictCallback, + DictCallback intDictCallback) { + Trace _s{"jni::JIValue::newJIValueFromAtIValue"}; + if (ivalue.isNone()) { + static auto jMethodOptionalNull = + JIValue::javaClassStatic() + ->getStaticMethod()>( + "optionalNull"); + return jMethodOptionalNull(JIValue::javaClassStatic()); + } else if (ivalue.isTensor()) { + static auto jMethodTensor = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + const auto& tensor = ivalue.toTensor(); + return jMethodTensor( + JIValue::javaClassStatic(), + TensorHybrid::newJTensorFromAtTensor(tensor)); + } else if (ivalue.isBool()) { + static auto jMethodBool = + JIValue::javaClassStatic() + ->getStaticMethod(jboolean)>( + "from"); + return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool()); + } else if (ivalue.isInt()) { + static auto jMethodInt = + JIValue::javaClassStatic() + 
->getStaticMethod(jlong)>("from"); + return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt()); + } else if (ivalue.isDouble()) { + static auto jMethodDouble = + JIValue::javaClassStatic() + ->getStaticMethod(jdouble)>( + "from"); + return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble()); + } else if (ivalue.isString()) { + static auto jMethodString = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref)>( + "from"); + return jMethodString( + JIValue::javaClassStatic(), + facebook::jni::make_jstring(ivalue.toStringRef())); + } else if (ivalue.isTuple()) { + auto elementsVec = ivalue.toTupleRef().elements(); + static auto jMethodTupleArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref::javaobject>)>("tupleFrom"); + auto jElementsArray = + facebook::jni::JArrayClass::newArray( + elementsVec.size()); + auto index = 0; + for (const auto& e : elementsVec) { + (*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e); + } + return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray); + } else if (ivalue.isBoolList()) { + auto list = ivalue.toBoolList(); + static auto jMethodBoolListArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref)>("listFrom"); + size_t n = list.size(); + auto jArray = facebook::jni::make_boolean_array(n); + auto jArrayPinned = jArray->pin(); + auto index = 0; + for (const auto& e : list) { + jArrayPinned[index++] = e; + } + return jMethodBoolListArr(JIValue::javaClassStatic(), jArray); + } else if (ivalue.isIntList()) { + auto list = ivalue.toIntList(); + static auto jMethodLongListArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref)>("listFrom"); + size_t n = list.size(); + auto jArray = facebook::jni::make_long_array(n); + auto jArrayPinned = jArray->pin(); + auto index = 0; + for (const auto& e : list) { + jArrayPinned[index++] = e; + } + return jMethodLongListArr(JIValue::javaClassStatic(), jArray); + } 
else if (ivalue.isDoubleList()) { + auto list = ivalue.toDoubleList(); + static auto jMethoDoubleListArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref)>("listFrom"); + size_t n = list.size(); + auto jArray = facebook::jni::make_double_array(n); + auto jArrayPinned = jArray->pin(); + auto index = 0; + for (const auto& e : list) { + jArrayPinned[index++] = e; + } + return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray); + } else if (ivalue.isTensorList()) { + auto list = ivalue.toTensorList(); + static auto jMethodTensorListArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref::javaobject>)>("listFrom"); + auto jArray = + facebook::jni::JArrayClass::newArray( + list.size()); + auto index = 0; + for (const auto& e : list) { + (*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e); + } + return jMethodTensorListArr(JIValue::javaClassStatic(), jArray); + } else if (ivalue.isList()) { + auto list = ivalue.toList(); + static auto jMethodListArr = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref::javaobject>)>("listFrom"); + auto jArray = + facebook::jni::JArrayClass::newArray(list.size()); + auto index = 0; + for (const auto& e : list) { + (*jArray)[index++] = JIValue::newJIValueFromAtIValue(e); + } + return jMethodListArr(JIValue::javaClassStatic(), jArray); + } else if (ivalue.isGenericDict()) { + auto dict = ivalue.toGenericDict(); + const auto keyType = dict.keyType(); + + if (!keyType) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unknown IValue-Dict key type"); + } + + if (*keyType == *c10::StringType::get()) { + return stringDictCallback(std::move(dict)); + } else if (*keyType == *c10::IntType::get()) { + return intDictCallback(std::move(dict)); + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unsupported IValue-Dict key type: %s", + keyType->str().c_str()); + } 
+ + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unsupported IValue type %s", + ivalue.tagKind().c_str()); +} + +at::IValue JIValue::JIValueToAtIValue( + facebook::jni::alias_ref jivalue) { + Trace _s{"jni::JIValue::JIValueToAtIValue"}; + static const auto typeCodeField = + JIValue::javaClassStatic()->getField("mTypeCode"); + const auto typeCode = jivalue->getFieldValue(typeCodeField); + if (JIValue::kTypeCodeNull == typeCode) { + return at::IValue{}; + } else if (JIValue::kTypeCodeTensor == typeCode) { + static const auto jMethodGetTensor = + JIValue::javaClassStatic() + ->getMethod()>( + "toTensor"); + return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue)); + } else if (JIValue::kTypeCodeBool == typeCode) { + static const auto jMethodGetBool = + JIValue::javaClassStatic()->getMethod("toBool"); + // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor + // for int will be called for jboolean + bool b = jMethodGetBool(jivalue); + return at::IValue{b}; + } else if (JIValue::kTypeCodeLong == typeCode) { + static const auto jMethodGetLong = + JIValue::javaClassStatic()->getMethod("toLong"); + return at::IValue{(int64_t)jMethodGetLong(jivalue)}; + } else if (JIValue::kTypeCodeDouble == typeCode) { + static const auto jMethodGetDouble = + JIValue::javaClassStatic()->getMethod("toDouble"); + return at::IValue{jMethodGetDouble(jivalue)}; + } else if (JIValue::kTypeCodeString == typeCode) { + static const auto jMethodGetString = + JIValue::javaClassStatic()->getMethod("toStr"); + return at::IValue{jMethodGetString(jivalue)->toStdString()}; + } else if (JIValue::kTypeCodeTuple == typeCode) { + static const auto jMethodGetTuple = + JIValue::javaClassStatic() + ->getMethod< + facebook::jni::JArrayClass::javaobject()>( + "toTuple"); + auto jarray = jMethodGetTuple(jivalue); + size_t n = jarray->size(); + + std::vector elements; + elements.reserve(n); + for (const auto i : c10::irange(n)) { + auto 
jivalue_element = jarray->getElement(i); + auto element = JIValue::JIValueToAtIValue(jivalue_element); + elements.push_back(std::move(element)); + } + return c10::ivalue::Tuple::create(std::move(elements)); + } else if (JIValue::kTypeCodeBoolList == typeCode) { + static const auto jMethodGetBoolList = + JIValue::javaClassStatic()->getMethod("toBoolList"); + auto jArray = jMethodGetBoolList(jivalue); + auto jArrayPinned = jArray->pin(); + size_t n = jArrayPinned.size(); + c10::List list{}; + list.reserve(n); + for (const auto i : c10::irange(n)) { + list.push_back(jArrayPinned[i]); + } + return at::IValue{std::move(list)}; + } else if (JIValue::kTypeCodeLongList == typeCode) { + static const auto jMethodGetLongList = + JIValue::javaClassStatic()->getMethod("toLongList"); + auto jArray = jMethodGetLongList(jivalue); + auto jArrayPinned = jArray->pin(); + size_t n = jArrayPinned.size(); + c10::List list{}; + list.reserve(n); + for (const auto i : c10::irange(n)) { + list.push_back(jArrayPinned[i]); + } + return at::IValue{std::move(list)}; + } else if (JIValue::kTypeCodeDoubleList == typeCode) { + static const auto jMethodGetDoubleList = + JIValue::javaClassStatic()->getMethod("toDoubleList"); + auto jArray = jMethodGetDoubleList(jivalue); + auto jArrayPinned = jArray->pin(); + size_t n = jArrayPinned.size(); + c10::List list{}; + list.reserve(n); + for (const auto i : c10::irange(n)) { + list.push_back(jArrayPinned[i]); + } + return at::IValue{std::move(list)}; + } else if (JIValue::kTypeCodeTensorList == typeCode) { + static const auto jMethodGetTensorList = + JIValue::javaClassStatic() + ->getMethod::javaobject()>("toTensorList"); + auto jArray = jMethodGetTensorList(jivalue); + size_t n = jArray->size(); + c10::List list{}; + list.reserve(n); + for (const auto i : c10::irange(n)) { + list.push_back( + TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i))); + } + return at::IValue{std::move(list)}; + } else if (JIValue::kTypeCodeList == typeCode) { + static 
const auto jMethodGetList = + JIValue::javaClassStatic() + ->getMethod< + facebook::jni::JArrayClass::javaobject()>( + "toList"); + auto jarray = jMethodGetList(jivalue); + size_t n = jarray->size(); + if (n == 0) { + return at::IValue{c10::impl::GenericList(c10::TensorType::get())}; + } + + auto jivalue_first_element = jarray->getElement(0); + auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element); + c10::impl::GenericList list{c10::unshapedType(first_element.type())}; + list.reserve(n); + list.push_back(first_element); + for (const auto i : c10::irange(1, n)) { + auto jivalue_element = jarray->getElement(i); + auto element = JIValue::JIValueToAtIValue(jivalue_element); + list.push_back(element); + } + return at::IValue{list}; + } else if (JIValue::kTypeCodeDictStringKey == typeCode) { + static const auto jMethodGetDictStringKey = + JIValue::javaClassStatic() + ->getMethod:: + javaobject()>("toDictStringKey"); + auto jmap = jMethodGetDictStringKey(jivalue); + auto it = jmap->begin(); + if (it == jmap->end()) { + return at::IValue{c10::impl::GenericDict( + c10::StringType::get(), c10::TensorType::get())}; + } + + auto firstEntryValue = JIValue::JIValueToAtIValue(it->second); + c10::impl::GenericDict dict{ + c10::StringType::get(), c10::unshapedType(firstEntryValue.type())}; + dict.insert(it->first->toStdString(), firstEntryValue); + it++; + for (; it != jmap->end(); it++) { + dict.insert( + it->first->toStdString(), JIValue::JIValueToAtIValue(it->second)); + } + return at::IValue{dict}; + } else if (JIValue::kTypeCodeDictLongKey == typeCode) { + static const auto jMethodGetDictLongKey = + JIValue::javaClassStatic() + ->getMethod::javaobject()>("toDictLongKey"); + auto jmap = jMethodGetDictLongKey(jivalue); + auto it = jmap->begin(); + if (it == jmap->end()) { + return at::IValue{ + c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())}; + } + + auto firstEntryValue = JIValue::JIValueToAtIValue(it->second); + c10::impl::GenericDict 
dict{ + c10::IntType::get(), c10::unshapedType(firstEntryValue.type())}; + dict.insert((int64_t)it->first->longValue(), firstEntryValue); + it++; + for (; it != jmap->end(); it++) { + dict.insert( + (int64_t)it->first->longValue(), + JIValue::JIValueToAtIValue(it->second)); + } + return at::IValue{dict}; + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unknown IValue typeCode %d", + typeCode); +} + +#if defined(__ANDROID__) +class PyTorchAndroidJni : public facebook::jni::JavaClass { + public: + constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;"; + + static void registerNatives() { + javaClassStatic()->registerNatives({ + makeNativeMethod( + "nativeSetNumThreads", PyTorchAndroidJni::setNumThreads), + }); + } + + static void setNumThreads(facebook::jni::alias_ref, jint numThreads) { + caffe2::pthreadpool()->set_thread_count(numThreads); + } +}; +#endif + +void common_registerNatives() { + static const int once = []() { +#if defined(__ANDROID__) + pytorch_jni::PyTorchAndroidJni::registerNatives(); +#endif + return 0; + }(); + ((void)once); +} + +} // namespace pytorch_jni diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.h b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h new file mode 100644 index 000000000000..5020a1cff960 --- /dev/null +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include "caffe2/serialize/read_adapter_interface.h" + +#include "cmake_macros.h" + +#ifdef __ANDROID__ +#include +#define ALOGI(...) \ + __android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__) +#define ALOGE(...) 
\ + __android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__) +#endif + +#if defined(TRACE_ENABLED) && defined(__ANDROID__) +#include +#include +#endif + +namespace pytorch_jni { + +constexpr static int kDeviceCPU = 1; +constexpr static int kDeviceVulkan = 2; + +c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode); + +class Trace { + public: +#if defined(TRACE_ENABLED) && defined(__ANDROID__) + typedef void* (*fp_ATrace_beginSection)(const char* sectionName); + typedef void* (*fp_ATrace_endSection)(void); + + static fp_ATrace_beginSection ATrace_beginSection; + static fp_ATrace_endSection ATrace_endSection; +#endif + + static void ensureInit() { + if (!Trace::is_initialized_) { + init(); + Trace::is_initialized_ = true; + } + } + + static void beginSection(const char* name) { + Trace::ensureInit(); +#if defined(TRACE_ENABLED) && defined(__ANDROID__) + ATrace_beginSection(name); +#endif + } + + static void endSection() { +#if defined(TRACE_ENABLED) && defined(__ANDROID__) + ATrace_endSection(); +#endif + } + + Trace(const char* name) { + ensureInit(); + beginSection(name); + } + + ~Trace() { + endSection(); + } + + private: + static void init(); + static bool is_initialized_; +}; + +class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { + public: + explicit MemoryReadAdapter(const void* data, off_t size) + : data_(data), size_(size){}; + + size_t size() const override { + return size_; + } + + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override { + memcpy(buf, (int8_t*)(data_) + pos, n); + return n; + } + + ~MemoryReadAdapter() {} + + private: + const void* data_; + off_t size_; +}; + +class JIValue : public facebook::jni::JavaClass { + using DictCallback = c10::function_ref( + c10::Dict)>; + + public: + constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;"; + + constexpr static int kTypeCodeNull = 1; + + constexpr static int kTypeCodeTensor = 2; + constexpr static int 
kTypeCodeBool = 3; + constexpr static int kTypeCodeLong = 4; + constexpr static int kTypeCodeDouble = 5; + constexpr static int kTypeCodeString = 6; + + constexpr static int kTypeCodeTuple = 7; + constexpr static int kTypeCodeBoolList = 8; + constexpr static int kTypeCodeLongList = 9; + constexpr static int kTypeCodeDoubleList = 10; + constexpr static int kTypeCodeTensorList = 11; + constexpr static int kTypeCodeList = 12; + + constexpr static int kTypeCodeDictStringKey = 13; + constexpr static int kTypeCodeDictLongKey = 14; + + static facebook::jni::local_ref newJIValueFromAtIValue( + const at::IValue& ivalue, + DictCallback stringDictCallback = newJIValueFromStringDict, + DictCallback intDictCallback = newJIValueFromIntDict); + + static at::IValue JIValueToAtIValue( + facebook::jni::alias_ref jivalue); + + private: + static facebook::jni::local_ref newJIValueFromStringDict( + c10::Dict); + static facebook::jni::local_ref newJIValueFromIntDict( + c10::Dict); +}; + +void common_registerNatives(); +} // namespace pytorch_jni diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp new file mode 100644 index 000000000000..c7a8a0dfd63e --- /dev/null +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -0,0 +1,245 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include "caffe2/serialize/read_adapter_interface.h" + +#include "pytorch_jni_common.h" + +#ifdef __ANDROID__ +#include +#include +#include +#endif + +namespace pytorch_jni { + +namespace { + +struct JITCallGuard { + // Inference only workload. + c10::InferenceMode guard; + // Disable graph optimizer to ensure list of unused ops are not changed for + // custom mobile build. 
+ torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false}; +}; + +} // namespace + +class PytorchJni : public facebook::jni::HybridClass { + private: + friend HybridBase; + torch::jit::Module module_; + c10::DeviceType deviceType_; + + public: + constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref< + facebook::jni::JMap> + extraFiles, + jint device) { + return makeCxxInstance(modelPath, extraFiles, device); + } + +#ifdef __ANDROID__ + static facebook::jni::local_ref initHybridAndroidAsset( + facebook::jni::alias_ref, + facebook::jni::alias_ref assetName, + facebook::jni::alias_ref assetManager, + jint device) { + return makeCxxInstance(assetName, assetManager, device); + } +#endif + +#ifdef TRACE_ENABLED + static std::unique_ptr onFunctionEnter( + const at::RecordFunction& fn) { + Trace::beginSection(fn.name().str()); + return nullptr; + } + + static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) { + Trace::endSection(); + } +#endif + + static void preModuleLoadSetupOnce() { + auto qengines = at::globalContext().supportedQEngines(); + if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != + qengines.end()) { + at::globalContext().setQEngine(at::QEngine::QNNPACK); + } + +#ifdef __ANDROID__ + torch::jit::setPrintHandler([](const std::string& s) { + __android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str()); + }); +#endif + +#ifdef TRACE_ENABLED + at::addGlobalCallback( + at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit) + .scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE})); +#endif + } + + void preModuleLoadSetup() { + static const int once = []() { + preModuleLoadSetupOnce(); + return 0; + }(); + ((void)once); + } + + PytorchJni( + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref< + facebook::jni::JMap> + extraFiles, + 
jint device) { + preModuleLoadSetup(); + JITCallGuard guard; + std::unordered_map extra_files; + const auto has_extra = extraFiles && extraFiles->size() > 0; + if (has_extra) { + for (const auto& e : *extraFiles) { + extra_files[e.first->toStdString()] = ""; + } + } + deviceType_ = deviceJniCodeToDeviceType(device); + module_ = torch::jit::load( + std::move(modelPath->toStdString()), std::nullopt, extra_files); + if (has_extra) { + static auto putMethod = + facebook::jni::JMap:: + javaClassStatic() + ->template getMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref)>("put"); + for (const auto& ef : extra_files) { + putMethod( + extraFiles, + facebook::jni::make_jstring(ef.first), + facebook::jni::make_jstring(ef.second)); + } + } + + module_.eval(); + } + +#ifdef __ANDROID__ + PytorchJni( + facebook::jni::alias_ref assetName, + facebook::jni::alias_ref assetManager, + jint device) { + preModuleLoadSetup(); + JNIEnv* env = facebook::jni::Environment::current(); + AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get()); + if (!mgr) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unable to get asset manager"); + } + AAsset* asset = AAssetManager_open( + mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER); + if (!asset) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Failed to open asset '%s'", + assetName->toStdString().c_str()); + } + auto assetBuffer = AAsset_getBuffer(asset); + if (!assetBuffer) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Could not get buffer for asset '%s'", + assetName->toStdString().c_str()); + } + JITCallGuard guard; + module_ = torch::jit::load(std::make_unique( + assetBuffer, AAsset_getLength(asset))); + AAsset_close(asset); + module_.eval(); + deviceType_ = deviceJniCodeToDeviceType(device); + } +#endif + + static void registerNatives() { + registerHybrid({ + 
makeNativeMethod("initHybrid", PytorchJni::initHybrid), +#ifdef __ANDROID__ + makeNativeMethod( + "initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset), +#endif + makeNativeMethod("forward", PytorchJni::forward), + makeNativeMethod("runMethod", PytorchJni::runMethod), + }); + } + + facebook::jni::local_ref forward( + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + Trace _s{"jni::Module::forward"}; + std::vector inputs{}; + size_t n = jinputs->size(); + inputs.reserve(n); + for (const auto i : c10::irange(n)) { + at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); + inputs.push_back(std::move(atIValue)); + } + auto output = [&]() { + JITCallGuard guard; + return module_.forward(std::move(inputs)); + }(); + return JIValue::newJIValueFromAtIValue(output); + } + + facebook::jni::local_ref runMethod( + facebook::jni::alias_ref jmethodName, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + std::string methodName = jmethodName->toStdString(); + + std::vector inputs{}; + size_t n = jinputs->size(); + inputs.reserve(n); + for (const auto i : c10::irange(n)) { + at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); + inputs.push_back(std::move(atIValue)); + } + if (auto method = module_.find_method(methodName)) { + auto output = [&]() { + JITCallGuard guard; + return (*method)(std::move(inputs)); + }(); + return JIValue::newJIValueFromAtIValue(output); + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Undefined method %s", + methodName.c_str()); + } +}; + +} // namespace pytorch_jni + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { + return facebook::jni::initialize(vm, [] { + pytorch_jni::common_registerNatives(); + pytorch_jni::PytorchJni::registerNatives(); + }); +} diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp 
new file mode 100644 index 000000000000..7fa4de90037b --- /dev/null +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp @@ -0,0 +1,209 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include "caffe2/serialize/read_adapter_interface.h" + +#include "pytorch_jni_common.h" + +#ifdef __ANDROID__ +#include +#include +#include +#endif + +namespace pytorch_jni { + +namespace { + +struct LiteJITCallGuard { + // VariableType dispatch is not included in default mobile build. We need set + // this guard globally to avoid dispatch error (only for dynamic dispatch). + // Thanks to the unification of Variable class and Tensor class it's no longer + // required to toggle the NonVariableTypeMode per op - so it doesn't hurt to + // always set NonVariableTypeMode for inference only use case. + // TODO: Ideally AutoNonVariableTypeMode in this file should be changed to + // InferenceMode but it's blocked due to typeahead application on Oculus + // (D27943428). To unblock, we need to find out which op is making inplace + // update to an inference tensor outside InferenceMode and properly guard it. 
+ torch::AutoNonVariableTypeMode non_var_guard; +}; + +} // namespace + +class PytorchJni : public facebook::jni::HybridClass { + private: + friend HybridBase; + torch::jit::mobile::Module module_; + c10::DeviceType deviceType_; + + public: + constexpr static auto kJavaDescriptor = "Lorg/pytorch/LiteNativePeer;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref< + facebook::jni::JMap> + extraFiles, + jint device) { + return makeCxxInstance(modelPath, extraFiles, device); + } + +#ifdef __ANDROID__ + static facebook::jni::local_ref initHybridAndroidAsset( + facebook::jni::alias_ref, + facebook::jni::alias_ref assetName, + facebook::jni::alias_ref assetManager, + jint device) { + return makeCxxInstance(assetName, assetManager, device); + } +#endif + + PytorchJni( + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref< + facebook::jni::JMap> + extraFiles, + jint device) { + LiteJITCallGuard guard; + std::unordered_map extra_files; + const auto has_extra = extraFiles && extraFiles->size() > 0; + if (has_extra) { + for (const auto& e : *extraFiles) { + extra_files[e.first->toStdString()] = ""; + } + } + deviceType_ = deviceJniCodeToDeviceType(device); + module_ = torch::jit::_load_for_mobile( + std::move(modelPath->toStdString()), std::nullopt, extra_files); + torch::jit::_load_extra_only_for_mobile( + std::move(modelPath->toStdString()), std::nullopt, extra_files); + if (has_extra) { + static auto putMethod = + facebook::jni::JMap:: + javaClassStatic() + ->template getMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref)>("put"); + for (const auto& ef : extra_files) { + putMethod( + extraFiles, + facebook::jni::make_jstring(ef.first), + facebook::jni::make_jstring(ef.second)); + } + } + } + +#ifdef __ANDROID__ + PytorchJni( + facebook::jni::alias_ref assetName, + facebook::jni::alias_ref assetManager, + jint device) { + JNIEnv* env = 
facebook::jni::Environment::current(); + AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get()); + if (!mgr) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unable to get asset manager"); + } + AAsset* asset = AAssetManager_open( + mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER); + if (!asset) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Failed to open asset '%s'", + assetName->toStdString().c_str()); + } + auto assetBuffer = AAsset_getBuffer(asset); + if (!assetBuffer) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Could not get buffer for asset '%s'", + assetName->toStdString().c_str()); + } + LiteJITCallGuard guard; + module_ = + torch::jit::_load_for_mobile(std::make_unique( + assetBuffer, AAsset_getLength(asset))); + AAsset_close(asset); + deviceType_ = deviceJniCodeToDeviceType(device); + } +#endif + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", PytorchJni::initHybrid), +#ifdef __ANDROID__ + makeNativeMethod( + "initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset), +#endif + makeNativeMethod("forward", PytorchJni::forward), + makeNativeMethod("runMethod", PytorchJni::runMethod), + }); + } + + facebook::jni::local_ref forward( + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + std::vector inputs{}; + size_t n = jinputs->size(); + inputs.reserve(n); + for (const auto i : c10::irange(n)) { + at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); + inputs.push_back(std::move(atIValue)); + } + + auto output = [&]() { + LiteJITCallGuard guard; + return module_.forward(inputs); + }(); + return JIValue::newJIValueFromAtIValue(output); + } + + facebook::jni::local_ref runMethod( + facebook::jni::alias_ref jmethodName, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + 
jinputs) { + std::string methodName = jmethodName->toStdString(); + + std::vector inputs{}; + size_t n = jinputs->size(); + inputs.reserve(n); + for (const auto i : c10::irange(n)) { + at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); + inputs.push_back(std::move(atIValue)); + } + if (auto method = module_.find_method(methodName)) { + auto output = [&]() { + LiteJITCallGuard guard; + return module_.get_method(methodName)(inputs); + }(); + return JIValue::newJIValueFromAtIValue(output); + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Undefined method %s", + methodName.c_str()); + } +}; + +} // namespace pytorch_jni + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { + return facebook::jni::initialize(vm, [] { + pytorch_jni::common_registerNatives(); + pytorch_jni::PytorchJni::registerNatives(); + }); +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/DType.java b/android/pytorch_android/src/main/java/org/pytorch/DType.java new file mode 100644 index 000000000000..da319bbe72e4 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/DType.java @@ -0,0 +1,27 @@ +package org.pytorch; + +/** Codes representing tensor data types. */ +public enum DType { + // NOTE: "jniCode" must be kept in sync with pytorch_jni_common.cpp. + // NOTE: Never serialize "jniCode", because it can change between releases. + + /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ + UINT8(1), + /** Code for dtype torch.int8. {@link Tensor#dtype()} */ + INT8(2), + /** Code for dtype torch.int32. {@link Tensor#dtype()} */ + INT32(3), + /** Code for dtype torch.float32. {@link Tensor#dtype()} */ + FLOAT32(4), + /** Code for dtype torch.int64. {@link Tensor#dtype()} */ + INT64(5), + /** Code for dtype torch.float64. 
{@link Tensor#dtype()} */ + FLOAT64(6), + ; + + final int jniCode; + + DType(int jniCode) { + this.jniCode = jniCode; + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/Device.java b/android/pytorch_android/src/main/java/org/pytorch/Device.java new file mode 100644 index 000000000000..ec3b8a303805 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/Device.java @@ -0,0 +1,15 @@ +package org.pytorch; + +public enum Device { + // Must be in sync with kDeviceCPU, kDeviceVulkan in + // pytorch_android/src/main/cpp/pytorch_jni_lite.cpp + CPU(1), + VULKAN(2), + ; + + final int jniCode; + + Device(int jniCode) { + this.jniCode = jniCode; + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/INativePeer.java b/android/pytorch_android/src/main/java/org/pytorch/INativePeer.java new file mode 100644 index 000000000000..8f70619c3538 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/INativePeer.java @@ -0,0 +1,9 @@ +package org.pytorch; + +interface INativePeer { + void resetNative(); + + IValue forward(IValue... inputs); + + IValue runMethod(String methodName, IValue... inputs); +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/IValue.java b/android/pytorch_android/src/main/java/org/pytorch/IValue.java new file mode 100644 index 000000000000..28ebf7a9011d --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/IValue.java @@ -0,0 +1,343 @@ +package org.pytorch; + +import com.facebook.jni.annotations.DoNotStrip; +import java.util.Locale; +import java.util.Map; + +/** + * Java representation of a TorchScript value, which is implemented as tagged union that can be one + * of the supported types: https://pytorch.org/docs/stable/jit.html#types . + * + *

Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. + * + *

{@code IValue} objects are constructed with {@code IValue.from(value)}, {@code + * IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, or one of + * the {@code dict} methods, depending on the key type. + * + *

Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that {@code + * str}-type IValues must be extracted with {@link #toStr()}, rather than {@link #toString()}. + * + *

{@code IValue} objects may retain references to objects passed into their constructors, and + * may return references to their internal state from {@code toX()}. + */ +@DoNotStrip +public class IValue { + private static final int TYPE_CODE_NULL = 1; + + private static final int TYPE_CODE_TENSOR = 2; + private static final int TYPE_CODE_BOOL = 3; + private static final int TYPE_CODE_LONG = 4; + private static final int TYPE_CODE_DOUBLE = 5; + private static final int TYPE_CODE_STRING = 6; + + private static final int TYPE_CODE_TUPLE = 7; + private static final int TYPE_CODE_BOOL_LIST = 8; + private static final int TYPE_CODE_LONG_LIST = 9; + private static final int TYPE_CODE_DOUBLE_LIST = 10; + private static final int TYPE_CODE_TENSOR_LIST = 11; + private static final int TYPE_CODE_LIST = 12; + + private static final int TYPE_CODE_DICT_STRING_KEY = 13; + private static final int TYPE_CODE_DICT_LONG_KEY = 14; + + private String[] TYPE_NAMES = { + "Unknown", + "Null", + "Tensor", + "Bool", + "Long", + "Double", + "String", + "Tuple", + "BoolList", + "LongList", + "DoubleList", + "TensorList", + "GenericList", + "DictStringKey", + "DictLongKey", + }; + + @DoNotStrip private final int mTypeCode; + @DoNotStrip private Object mData; + + @DoNotStrip + private IValue(int typeCode) { + this.mTypeCode = typeCode; + } + + @DoNotStrip + public boolean isNull() { + return TYPE_CODE_NULL == this.mTypeCode; + } + + @DoNotStrip + public boolean isTensor() { + return TYPE_CODE_TENSOR == this.mTypeCode; + } + + @DoNotStrip + public boolean isBool() { + return TYPE_CODE_BOOL == this.mTypeCode; + } + + @DoNotStrip + public boolean isLong() { + return TYPE_CODE_LONG == this.mTypeCode; + } + + @DoNotStrip + public boolean isDouble() { + return TYPE_CODE_DOUBLE == this.mTypeCode; + } + + @DoNotStrip + public boolean isString() { + return TYPE_CODE_STRING == this.mTypeCode; + } + + @DoNotStrip + public boolean isTuple() { + return TYPE_CODE_TUPLE == this.mTypeCode; + } + + @DoNotStrip + 
public boolean isBoolList() { + return TYPE_CODE_BOOL_LIST == this.mTypeCode; + } + + @DoNotStrip + public boolean isLongList() { + return TYPE_CODE_LONG_LIST == this.mTypeCode; + } + + @DoNotStrip + public boolean isDoubleList() { + return TYPE_CODE_DOUBLE_LIST == this.mTypeCode; + } + + @DoNotStrip + public boolean isTensorList() { + return TYPE_CODE_TENSOR_LIST == this.mTypeCode; + } + + @DoNotStrip + public boolean isList() { + return TYPE_CODE_LIST == this.mTypeCode; + } + + @DoNotStrip + public boolean isDictStringKey() { + return TYPE_CODE_DICT_STRING_KEY == this.mTypeCode; + } + + @DoNotStrip + public boolean isDictLongKey() { + return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode; + } + + /** Creates a new {@code IValue} of type {@code Optional} that contains no value. */ + @DoNotStrip + public static IValue optionalNull() { + return new IValue(TYPE_CODE_NULL); + } + /** Creates a new {@code IValue} of type {@code Tensor}. */ + @DoNotStrip + public static IValue from(Tensor tensor) { + final IValue iv = new IValue(TYPE_CODE_TENSOR); + iv.mData = tensor; + return iv; + } + /** Creates a new {@code IValue} of type {@code bool}. */ + @DoNotStrip + public static IValue from(boolean value) { + final IValue iv = new IValue(TYPE_CODE_BOOL); + iv.mData = value; + return iv; + } + + /** Creates a new {@code IValue} of type {@code int}. */ + @DoNotStrip + public static IValue from(long value) { + final IValue iv = new IValue(TYPE_CODE_LONG); + iv.mData = value; + return iv; + } + /** Creates a new {@code IValue} of type {@code float}. */ + @DoNotStrip + public static IValue from(double value) { + final IValue iv = new IValue(TYPE_CODE_DOUBLE); + iv.mData = value; + return iv; + } + /** Creates a new {@code IValue} of type {@code str}. */ + @DoNotStrip + public static IValue from(String value) { + final IValue iv = new IValue(TYPE_CODE_STRING); + iv.mData = value; + return iv; + } + + /** Creates a new {@code IValue} of type {@code List[bool]}. 
*/ + @DoNotStrip + public static IValue listFrom(boolean... list) { + final IValue iv = new IValue(TYPE_CODE_BOOL_LIST); + iv.mData = list; + return iv; + } + /** Creates a new {@code IValue} of type {@code List[int]}. */ + @DoNotStrip + public static IValue listFrom(long... list) { + final IValue iv = new IValue(TYPE_CODE_LONG_LIST); + iv.mData = list; + return iv; + } + /** Creates a new {@code IValue} of type {@code List[float]}. */ + @DoNotStrip + public static IValue listFrom(double... list) { + final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST); + iv.mData = list; + return iv; + } + + /** Creates a new {@code IValue} of type {@code List[Tensor]}. */ + @DoNotStrip + public static IValue listFrom(Tensor... list) { + final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST); + iv.mData = list; + return iv; + } + + /** Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type. */ + @DoNotStrip + public static IValue listFrom(IValue... array) { + final int size = array.length; + if (size > 0) { + final int typeCode0 = array[0].mTypeCode; + for (int i = 1; i < size; i++) { + if (typeCode0 != array[i].mTypeCode) { + throw new IllegalArgumentException("List must contain items of the same type"); + } + } + } + + final IValue iv = new IValue(TYPE_CODE_LIST); + iv.mData = array; + return iv; + } + /** Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. */ + @DoNotStrip + public static IValue tupleFrom(IValue... array) { + final IValue iv = new IValue(TYPE_CODE_TUPLE); + iv.mData = array; + return iv; + } + + /** Creates a new {@code IValue} of type {@code Dict[str, V]}. */ + @DoNotStrip + public static IValue dictStringKeyFrom(Map map) { + final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY); + iv.mData = map; + return iv; + } + /** Creates a new {@code IValue} of type {@code Dict[int, V]}. 
*/ + @DoNotStrip + public static IValue dictLongKeyFrom(Map map) { + final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY); + iv.mData = map; + return iv; + } + + @DoNotStrip + public Tensor toTensor() { + preconditionType(TYPE_CODE_TENSOR, mTypeCode); + return (Tensor) mData; + } + + @DoNotStrip + public boolean toBool() { + preconditionType(TYPE_CODE_BOOL, mTypeCode); + return (boolean) mData; + } + + @DoNotStrip + public long toLong() { + preconditionType(TYPE_CODE_LONG, mTypeCode); + return (long) mData; + } + + @DoNotStrip + public double toDouble() { + preconditionType(TYPE_CODE_DOUBLE, mTypeCode); + return (double) mData; + } + + @DoNotStrip + public String toStr() { + preconditionType(TYPE_CODE_STRING, mTypeCode); + return (String) mData; + } + + @DoNotStrip + public boolean[] toBoolList() { + preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode); + return (boolean[]) mData; + } + + @DoNotStrip + public long[] toLongList() { + preconditionType(TYPE_CODE_LONG_LIST, mTypeCode); + return (long[]) mData; + } + + @DoNotStrip + public double[] toDoubleList() { + preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode); + return (double[]) mData; + } + + @DoNotStrip + public Tensor[] toTensorList() { + preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode); + return (Tensor[]) mData; + } + + @DoNotStrip + public IValue[] toList() { + preconditionType(TYPE_CODE_LIST, mTypeCode); + return (IValue[]) mData; + } + + @DoNotStrip + public IValue[] toTuple() { + preconditionType(TYPE_CODE_TUPLE, mTypeCode); + return (IValue[]) mData; + } + + @DoNotStrip + public Map toDictStringKey() { + preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode); + return (Map) mData; + } + + @DoNotStrip + public Map toDictLongKey() { + preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode); + return (Map) mData; + } + + private void preconditionType(int typeCodeExpected, int typeCode) { + if (typeCode != typeCodeExpected) { + throw new IllegalStateException( + String.format( + Locale.US, + "Expected IValue 
type %s, actual type %s", + getTypeName(typeCodeExpected), + getTypeName(typeCode))); + } + } + + private String getTypeName(int typeCode) { + return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/LiteModuleLoader.java b/android/pytorch_android/src/main/java/org/pytorch/LiteModuleLoader.java new file mode 100644 index 000000000000..2640e2b21342 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/LiteModuleLoader.java @@ -0,0 +1,49 @@ +package org.pytorch; + +import android.content.res.AssetManager; +import java.util.Map; + +public class LiteModuleLoader { + + /** + * Loads a serialized TorchScript module from the specified path on the disk to run on specified + * device. The model should be generated from this api _save_for_lite_interpreter(). + * + * @param modelPath path to file that contains the serialized TorchScript module. + * @param extraFiles map with extra files names as keys, content of them will be loaded to values. + * @param device {@link org.pytorch.Device} to use for running specified module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module. + */ + public static Module load( + final String modelPath, final Map extraFiles, final Device device) { + return new Module(new LiteNativePeer(modelPath, extraFiles, device)); + } + + /** + * Loads a serialized TorchScript module from the specified path on the disk to run on CPU. The + * model should be generated from this api _save_for_lite_interpreter(). + * + * @param modelPath path to file that contains the serialized TorchScript module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module. 
+ */ + public static Module load(final String modelPath) { + return new Module(new LiteNativePeer(modelPath, null, Device.CPU)); + } + + /** + * Attention: This is not a recommended way of loading production modules, as prepackaged assets + * increase apk size etc. For production usage consider loading from file on the disk {@link + * org.pytorch.Module#load(String)}. + + *

This method is meant to be used in tests and demos. + */ + public static Module loadModuleFromAsset( + final AssetManager assetManager, final String assetName, final Device device) { + return new Module(new LiteNativePeer(assetName, assetManager, device)); + } + + public static Module loadModuleFromAsset( + final AssetManager assetManager, final String assetName) { + return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU)); + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/LiteNativePeer.java b/android/pytorch_android/src/main/java/org/pytorch/LiteNativePeer.java new file mode 100644 index 000000000000..c47c6a8d062b --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/LiteNativePeer.java @@ -0,0 +1,64 @@ +package org.pytorch; + +import com.facebook.jni.HybridData; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.Map; + +class LiteNativePeer implements INativePeer { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_jni_lite"); + PyTorchCodegenLoader.loadNativeLibs(); + } + + private final HybridData mHybridData; + + private static native HybridData initHybrid( + String moduleAbsolutePath, Map extraFiles, int deviceJniCode); + + private static native HybridData initHybridAndroidAsset( + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + int deviceJniCode); + + LiteNativePeer(String moduleAbsolutePath, Map extraFiles, Device device) { + mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode); + } + + LiteNativePeer( + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + Device device) { + mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode); + } + + /** + * Explicitly destroys the native torch::jit::mobile::Module.
Calling this method is not required, + * as the native object will be destroyed when this object is garbage-collected. However, the + * timing of garbage collection is not guaranteed, so proactively calling {@code resetNative} can + * free memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + */ + public void resetNative() { + mHybridData.resetNative(); + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the TorchScript module's 'forward' method. + * @return return value from the 'forward' method. + */ + public native IValue forward(IValue... inputs); + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the TorchScript method to run. + * @param inputs arguments that will be passed to TorchScript method. + * @return return value from the method. + */ + public native IValue runMethod(String methodName, IValue... inputs); +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/MemoryFormat.java b/android/pytorch_android/src/main/java/org/pytorch/MemoryFormat.java new file mode 100644 index 000000000000..248b175d5a2d --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/MemoryFormat.java @@ -0,0 +1,14 @@ +package org.pytorch; + +public enum MemoryFormat { + CONTIGUOUS(1), + CHANNELS_LAST(2), + CHANNELS_LAST_3D(3), + ; + + final int jniCode; + + MemoryFormat(int jniCode) { + this.jniCode = jniCode; + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/Module.java b/android/pytorch_android/src/main/java/org/pytorch/Module.java new file mode 100644 index 000000000000..1d330c1fcf7a --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/Module.java @@ -0,0 +1,75 @@ +// Copyright 2004-present Facebook. All Rights Reserved. 
+ +package org.pytorch; + +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.Map; + +/** Java wrapper for torch::jit::Module. */ +public class Module { + + private INativePeer mNativePeer; + + /** + * Loads a serialized TorchScript module from the specified path on the disk to run on specified + * device. + * + * @param modelPath path to file that contains the serialized TorchScript module. + * @param extraFiles map with extra files names as keys, content of them will be loaded to values. + * @param device {@link org.pytorch.Device} to use for running specified module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::Module. + */ + public static Module load( + final String modelPath, final Map extraFiles, final Device device) { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + return new Module(new NativePeer(modelPath, extraFiles, device)); + } + + /** + * Loads a serialized TorchScript module from the specified path on the disk to run on CPU. + * + * @param modelPath path to file that contains the serialized TorchScript module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::Module. + */ + public static Module load(final String modelPath) { + return load(modelPath, null, Device.CPU); + } + + Module(INativePeer nativePeer) { + this.mNativePeer = nativePeer; + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the TorchScript module's 'forward' method. + * @return return value from the 'forward' method. + */ + public IValue forward(IValue... inputs) { + return mNativePeer.forward(inputs); + } + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the TorchScript method to run. + * @param inputs arguments that will be passed to TorchScript method. 
+ * @return return value from the method. + */ + public IValue runMethod(String methodName, IValue... inputs) { + return mNativePeer.runMethod(methodName, inputs); + } + + /** + * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the + * native object will be destroyed when this object is garbage-collected. However, the timing of + * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory + * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + */ + public void destroy() { + mNativePeer.resetNative(); + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java b/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java new file mode 100644 index 000000000000..0c73c40a3ab5 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java @@ -0,0 +1,46 @@ +package org.pytorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import java.util.Map; + +class NativePeer implements INativePeer { + static { + NativeLoader.loadLibrary("pytorch_jni"); + PyTorchCodegenLoader.loadNativeLibs(); + } + + private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + String moduleAbsolutePath, Map extraFiles, int deviceJniCode); + + @DoNotStrip + private static native HybridData initHybridAndroidAsset( + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + int deviceJniCode); + + NativePeer(String moduleAbsolutePath, Map extraFiles, Device device) { + mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode); + } + + NativePeer( + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + Device device) { + mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode); + } + + public void 
resetNative() { + mHybridData.resetNative(); + } + + @DoNotStrip + public native IValue forward(IValue... inputs); + + @DoNotStrip + public native IValue runMethod(String methodName, IValue... inputs); +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java b/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java new file mode 100644 index 000000000000..becfbfbd0ba6 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java @@ -0,0 +1,50 @@ +package org.pytorch; + +import android.content.res.AssetManager; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; + +public final class PyTorchAndroid { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_jni_lite"); + PyTorchCodegenLoader.loadNativeLibs(); + } + + /** + * Attention: This is not recommended way of loading production modules, as prepackaged assets + * increase apk size etc. For production usage consider using loading from file on the disk {@link + * org.pytorch.Module#load(String)}. + * + *

This method is meant to be used in tests and demos. + */ + public static Module loadModuleFromAsset( + final AssetManager assetManager, final String assetName, final Device device) { + return new Module(new NativePeer(assetName, assetManager, device)); + } + + public static Module loadModuleFromAsset( + final AssetManager assetManager, final String assetName) { + return new Module(new NativePeer(assetName, assetManager, Device.CPU)); + } + + /** + * Globally sets the number of threads used on native side. Attention: Has global effect, all + * modules use one thread pool with specified number of threads. + * + * @param numThreads number of threads, must be a positive number. + */ + public static void setNumThreads(int numThreads) { + if (numThreads < 1) { + throw new IllegalArgumentException("Number of threads cannot be less than 1"); + } + + nativeSetNumThreads(numThreads); + } + + @DoNotStrip + private static native void nativeSetNumThreads(int numThreads); +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/PyTorchCodegenLoader.java b/android/pytorch_android/src/main/java/org/pytorch/PyTorchCodegenLoader.java new file mode 100644 index 000000000000..2485caac3e4f --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/PyTorchCodegenLoader.java @@ -0,0 +1,16 @@ +package org.pytorch; + +import com.facebook.soloader.nativeloader.NativeLoader; + +public class PyTorchCodegenLoader { + + public static void loadNativeLibs() { + try { + NativeLoader.loadLibrary("torch-code-gen"); + } catch (Throwable t) { + // Loading the codegen lib is best-effort since it's only there for query based builds.
+ } + } + + private PyTorchCodegenLoader() {} +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java new file mode 100644 index 000000000000..83a7c021bf6a --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java @@ -0,0 +1,735 @@ +package org.pytorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; +import java.util.Locale; + +/** + * Representation of a Tensor. Behavior is similar to PyTorch's tensor objects. + * + *

Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data} + * can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided + * to allocate buffers properly. + * + *

To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*} + * methods. + * + *

When constructing {@code Tensor} objects with {@code data} as an array, it is not specified + * whether this data is copied or retained as a reference so it is recommended not to modify it + * after constructing. {@code data} passed as a {@link Buffer} is not copied, so it can be modified + * between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects + * may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape} + * is always copied. + */ +public abstract class Tensor { + private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; + private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; + private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; + private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER = + "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = + "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; + + @DoNotStrip final long[] shape; + final MemoryFormat memoryFormat; + + private static final int INT_SIZE_BYTES = 4; + private static final int FLOAT_SIZE_BYTES = 4; + private static final int LONG_SIZE_BYTES = 8; + private static final int DOUBLE_SIZE_BYTES = 8; + + /** + * Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link + * Tensor#fromBlobUnsigned(ByteBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. 
+ */ + public static ByteBuffer allocateByteBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + } + + /** + * Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static IntBuffer allocateIntBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + } + + /** + * Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static FloatBuffer allocateFloatBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + } + + /** + * Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static LongBuffer allocateLongBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + } + + /** + * Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. 
+ */ + public static DoubleBuffer allocateDoubleBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(byte[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_uint8(byteBuffer, shape, memoryFormat); + } + + public static Tensor fromBlobUnsigned(byte[] data, long[] shape) { + return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(byte[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_int8(byteBuffer, shape, memoryFormat); + } + + public static Tensor fromBlob(byte[] data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of + * ints. 
+ * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(int[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape)); + intBuffer.put(data); + return new Tensor_int32(intBuffer, shape, memoryFormat); + } + + public static Tensor fromBlob(int[] data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array + * of floats. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(float[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape)); + floatBuffer.put(data); + return new Tensor_float32(floatBuffer, shape, memoryFormat); + } + + public static Tensor fromBlob(float[] data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of + * longs. 
+ * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(long[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape)); + longBuffer.put(data); + return new Tensor_int64(longBuffer, shape, memoryFormat); + } + + public static Tensor fromBlob(long[] data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array + * of doubles. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor fromBlob(double[] data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape)); + doubleBuffer.put(data); + return new Tensor_float64(doubleBuffer, shape, memoryFormat); + } + + public static Tensor fromBlob(double[] data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_uint8(data, shape, memoryFormat); + } + + public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) { + return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int8(data, shape, memoryFormat); + } + + public static Tensor fromBlob(ByteBuffer data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. 
The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(IntBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int32(data, shape, memoryFormat); + } + + public static Tensor fromBlob(IntBuffer data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float32(data, shape, memoryFormat); + } + + public static Tensor fromBlob(FloatBuffer data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. 
+ * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(LongBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int64(data, shape, memoryFormat); + } + + public static Tensor fromBlob(LongBuffer data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlob(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float64(data, shape, memoryFormat); + } + + public static Tensor fromBlob(DoubleBuffer data, long[] shape) { + return fromBlob(data, shape, MemoryFormat.CONTIGUOUS); + } + + @DoNotStrip private HybridData mHybridData; + + private Tensor(long[] shape, MemoryFormat memoryFormat) { + checkShape(shape); + this.shape = Arrays.copyOf(shape, shape.length); + this.memoryFormat = memoryFormat; + } + + /** Returns the number of elements in this tensor. */ + public long numel() { + return numel(this.shape); + } + + /** Calculates the number of elements in a tensor with the specified shape. */ + public static long numel(long[] shape) { + checkShape(shape); + int result = 1; + for (long s : shape) { + result *= s; + } + return result; + } + + /** Returns the shape of this tensor. (The array is a fresh copy.) */ + public long[] shape() { + return Arrays.copyOf(shape, shape.length); + } + + /** Returns the memory format of this tensor. */ + public MemoryFormat memoryFormat() { + return memoryFormat; + } + + /** @return data type of this tensor. */ + public abstract DType dtype(); + + // Called from native + @DoNotStrip + int dtypeJniCode() { + return dtype().jniCode; + } + + // Called from native + @DoNotStrip + int memoryFormatJniCode() { + return memoryFormat.jniCode; + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int8 tensor. 
+ */ + public byte[] getDataAsByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-uint8 tensor. + */ + public byte[] getDataAsUnsignedByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * @return a Java int array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int32 tensor. + */ + public int[] getDataAsIntArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array."); + } + + /** + * @return a Java float array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float32 tensor. + */ + public float[] getDataAsFloatArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); + } + + /** + * @return a Java long array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int64 tensor. + */ + public long[] getDataAsLongArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as long array."); + } + + /** + * @return a Java double array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float64 tensor. 
+ */ + public double[] getDataAsDoubleArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); + } + + @DoNotStrip + Buffer getRawDataBuffer() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); + } + + static class Tensor_uint8 extends Tensor { + private final ByteBuffer data; + + private Tensor_uint8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public DType dtype() { + return DType.UINT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape)); + } + } + + static class Tensor_int8 extends Tensor { + private final ByteBuffer data; + + private Tensor_int8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape)); + } + } + + static class Tensor_int32 extends Tensor { + private final IntBuffer data; + + private Tensor_int32(IntBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public int[] 
getDataAsIntArray() { + data.rewind(); + int[] arr = new int[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape)); + } + } + + static class Tensor_float32 extends Tensor { + private final FloatBuffer data; + + Tensor_float32(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public float[] getDataAsFloatArray() { + data.rewind(); + float[] arr = new float[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public DType dtype() { + return DType.FLOAT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape)); + } + } + + static class Tensor_int64 extends Tensor { + private final LongBuffer data; + + private Tensor_int64(LongBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT64; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public long[] getDataAsLongArray() { + data.rewind(); + long[] arr = new long[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape)); + } + } + + static class Tensor_float64 extends Tensor { + private final DoubleBuffer data; + + private Tensor_float64(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) { + super(shape, memoryFormat); + this.data = data; + } + + @Override + public DType dtype() { + return DType.FLOAT64; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public double[] getDataAsDoubleArray() { + data.rewind(); + double[] arr = new double[data.remaining()]; + 
data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape)); + } + } + + // region checks + private static void checkArgument(boolean expression, String errorMessage, Object... args) { + if (!expression) { + throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args)); + } + } + + private static void checkShape(long[] shape) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + for (int i = 0; i < shape.length; i++) { + checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE); + } + } + + private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) { + final long numel = numel(shape); + checkArgument( + numel == dataCapacity, + "Inconsistent data capacity:%d and shape number elements:%d shape:%s", + dataCapacity, + numel, + Arrays.toString(shape)); + } + // endregion checks + + // Called from native + @DoNotStrip + private static Tensor nativeNewTensor( + ByteBuffer data, long[] shape, int dtype, int memoryFormatCode, HybridData hybridData) { + Tensor tensor = null; + + MemoryFormat memoryFormat = MemoryFormat.CONTIGUOUS; + if (MemoryFormat.CHANNELS_LAST.jniCode == memoryFormatCode) { + memoryFormat = MemoryFormat.CHANNELS_LAST; + } else if (MemoryFormat.CHANNELS_LAST_3D.jniCode == memoryFormatCode) { + memoryFormat = MemoryFormat.CHANNELS_LAST_3D; + } + + if (DType.FLOAT32.jniCode == dtype) { + tensor = new Tensor_float32(data.asFloatBuffer(), shape, memoryFormat); + } else if (DType.INT32.jniCode == dtype) { + tensor = new Tensor_int32(data.asIntBuffer(), shape, memoryFormat); + } else if (DType.INT64.jniCode == dtype) { + tensor = new Tensor_int64(data.asLongBuffer(), shape, memoryFormat); + } else if (DType.FLOAT64.jniCode == dtype) { + tensor = new Tensor_float64(data.asDoubleBuffer(), shape, memoryFormat); + } else if (DType.UINT8.jniCode == dtype) { + tensor = new Tensor_uint8(data, shape, memoryFormat); 
+ } else if (DType.INT8.jniCode == dtype) { + tensor = new Tensor_int8(data, shape, memoryFormat); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + tensor.mHybridData = hybridData; + return tensor; + } +} diff --git a/android/pytorch_android/src/main/res/values/strings.xml b/android/pytorch_android/src/main/res/values/strings.xml new file mode 100644 index 000000000000..67753dcbddf7 --- /dev/null +++ b/android/pytorch_android/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ +<resources> + <string name="app_name">pytorch_android</string> +</resources> diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit new file mode 100644 index 000000000000..8605ab13d555 --- /dev/null +++ b/android/pytorch_android/test_asset.jit @@ -0,0 +1,105 @@ +def forward(self, input): + return None + +def eqBool(self, input: bool) -> bool: + return input + +def eqInt(self, input: int) -> int: + return input + +def eqFloat(self, input: float) -> float: + return input + +def eqStr(self, input: str) -> str: + return input + +def eqTensor(self, input: Tensor) -> Tensor: + return input + +def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: + return input + +def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: + return input + +def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: + return input + +def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: + sum = 0 + for x in input: + sum += x + return (input, sum) + +def listBoolConjunction(self, input: List[bool]) -> bool: + res = True + for x in input: + res = res and x + return res + +def listBoolDisjunction(self, input: List[bool]) -> bool: + res = False + for x in input: + res = res or x + return res + +def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: + sum = 0 + for x in input: + sum += x + return (input, sum) + +def optionalIntIsNone(self, input: Optional[int]) -> bool: + return input is None + +def 
intEq0None(self, input: int) -> Optional[int]: + if input == 0: + return None + return input + +def str3Concat(self, input: str) -> str: + return input + input + input + +def newEmptyShapeWithItem(self, input): + return torch.tensor([int(input.item())])[0] + +def testAliasWithOffset(self) -> List[Tensor]: + x = torch.tensor([100, 200]) + a = [x[0], x[1]] + return a + +def testNonContiguous(self): + x = torch.tensor([100, 200, 300])[::2] + assert not x.is_contiguous() + assert x[0] == 100 + assert x[1] == 300 + return x + +def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.conv2d(x, w) + if (toChannelsLast): + # memory_format=torch.channels_last + r = r.contiguous(memory_format=2) + else: + r = r.contiguous() + return r + +def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.conv3d(x, w) + if (toChannelsLast): + # memory_format=torch.channels_last_3d + r = r.contiguous(memory_format=3) + else: + r = r.contiguous() + return r + +def contiguous(self, x: Tensor) -> Tensor: + return x.contiguous() + +def contiguousChannelsLast(self, x: Tensor) -> Tensor: + # memory_format=torch.channels_last + return x.contiguous(memory_format=2) + +def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: + # memory_format=torch.channels_last_3d + return x.contiguous(memory_format=3) diff --git a/android/pytorch_android_torchvision/CMakeLists.txt b/android/pytorch_android_torchvision/CMakeLists.txt new file mode 100644 index 000000000000..3e23fdc1a1e5 --- /dev/null +++ b/android/pytorch_android_torchvision/CMakeLists.txt @@ -0,0 +1,22 @@ +cmake_minimum_required(VERSION 3.5) +project(pytorch_vision_jni CXX) +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) + +file(GLOB pytorch_vision_SOURCES + ${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp +) + 
+add_library(pytorch_vision_jni SHARED + ${pytorch_vision_SOURCES} +) + +target_compile_options(pytorch_vision_jni PRIVATE + -fexceptions +) + +set(BUILD_SUBDIR ${ANDROID_ABI}) + +target_link_libraries(pytorch_vision_jni) diff --git a/android/pytorch_android_torchvision/build.gradle b/android/pytorch_android_torchvision/build.gradle new file mode 100644 index 000000000000..06d8d4db264f --- /dev/null +++ b/android/pytorch_android_torchvision/build.gradle @@ -0,0 +1,64 @@ +apply plugin: 'com.android.library' +apply plugin: 'maven' + +android { + compileSdkVersion rootProject.compileSdkVersion + buildToolsVersion rootProject.buildToolsVersion + + + defaultConfig { + minSdkVersion rootProject.minSdkVersion + targetSdkVersion rootProject.targetSdkVersion + versionCode 0 + versionName "0.1" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + ndk { + abiFilters ABI_FILTERS.split(",") + } + } + + buildTypes { + debug { + minifyEnabled false + debuggable true + } + release { + minifyEnabled false + } + } + + externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } + + useLibrary 'android.test.runner' + useLibrary 'android.test.base' + useLibrary 'android.test.mock' +} + +dependencies { + implementation project(':pytorch_android') + + implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion + + testImplementation 'junit:junit:' + rootProject.junitVersion + testImplementation 'androidx.test:core:' + rootProject.coreVersion + + androidTestImplementation 'junit:junit:' + rootProject.junitVersion + androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion + androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion + androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion + androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion +} + +apply from: rootProject.file('gradle/release.gradle') + +task sourcesJar(type: Jar) { + from 
android.sourceSets.main.java.srcDirs + classifier = 'sources' +} + +artifacts.add('archives', sourcesJar) diff --git a/android/pytorch_android_torchvision/gradle.properties b/android/pytorch_android_torchvision/gradle.properties new file mode 100644 index 000000000000..6ff3a6ce3071 --- /dev/null +++ b/android/pytorch_android_torchvision/gradle.properties @@ -0,0 +1,4 @@ +POM_NAME=pytorch_android_torchvision_lite +POM_DESCRIPTION=pytorch_android_torchvision_lite +POM_ARTIFACT_ID=pytorch_android_torchvision_lite +POM_PACKAGING=aar diff --git a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java new file mode 100644 index 000000000000..f3a43980a349 --- /dev/null +++ b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java @@ -0,0 +1,24 @@ +package org.pytorch.torchvision; + +import static org.junit.Assert.assertArrayEquals; + +import android.graphics.Bitmap; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.pytorch.Tensor; + +@RunWith(AndroidJUnit4.class) +public class TorchVisionInstrumentedTests { + + @Test + public void smokeTest() { + Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888); + Tensor tensor = + TensorImageUtils.bitmapToFloat32Tensor( + bitmap, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB); + assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape()); + } +} diff --git a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/suite/TorchVisionInstrumentedTestSuite.java b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/suite/TorchVisionInstrumentedTestSuite.java new file mode 100644 index 000000000000..831127396a67 
--- /dev/null +++ b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/suite/TorchVisionInstrumentedTestSuite.java @@ -0,0 +1,9 @@ +package org.pytorch.torchvision.suite; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.pytorch.torchvision.TorchVisionInstrumentedTests; + +@RunWith(Suite.class) +@Suite.SuiteClasses({TorchVisionInstrumentedTests.class}) +public class TorchVisionInstrumentedTestSuite {} diff --git a/android/pytorch_android_torchvision/src/main/AndroidManifest.xml b/android/pytorch_android_torchvision/src/main/AndroidManifest.xml new file mode 100644 index 000000000000..9cd2f3cf31d7 --- /dev/null +++ b/android/pytorch_android_torchvision/src/main/AndroidManifest.xml @@ -0,0 +1 @@ + diff --git a/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp new file mode 100644 index 000000000000..41fc7d5ed337 --- /dev/null +++ b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp @@ -0,0 +1,187 @@ +#include +#include +#include + +#include "jni.h" + +namespace pytorch_vision_jni { + +static void imageYUV420CenterCropToFloatBuffer( + JNIEnv* jniEnv, + jclass, + jobject yBuffer, + jint yRowStride, + jint yPixelStride, + jobject uBuffer, + jobject vBuffer, + jint uRowStride, + jint uvPixelStride, + jint imageWidth, + jint imageHeight, + jint rotateCWDegrees, + jint tensorWidth, + jint tensorHeight, + jfloatArray jnormMeanRGB, + jfloatArray jnormStdRGB, + jobject outBuffer, + jint outOffset, + jint memoryFormatCode) { + constexpr static int32_t kMemoryFormatContiguous = 1; + constexpr static int32_t kMemoryFormatChannelsLast = 2; + + float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer); + + jfloat normMeanRGB[3]; + jfloat normStdRGB[3]; + jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB); + jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB); + int widthAfterRtn = 
imageWidth; + int heightAfterRtn = imageHeight; + bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270; + if (oddRotation) { + widthAfterRtn = imageHeight; + heightAfterRtn = imageWidth; + } + + int cropWidthAfterRtn = widthAfterRtn; + int cropHeightAfterRtn = heightAfterRtn; + + if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) { + cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight; + } else { + cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth; + } + + int cropWidthBeforeRtn = cropWidthAfterRtn; + int cropHeightBeforeRtn = cropHeightAfterRtn; + if (oddRotation) { + cropWidthBeforeRtn = cropHeightAfterRtn; + cropHeightBeforeRtn = cropWidthAfterRtn; + } + + const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f; + const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f; + + const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer); + const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer); + const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer); + + float scale = cropWidthAfterRtn / (float)tensorWidth; + int uvRowStride = uRowStride; + int cropXMult = 1; + int cropYMult = 1; + int cropXAdd = offsetX; + int cropYAdd = offsetY; + if (rotateCWDegrees == 90) { + cropYMult = -1; + cropYAdd = offsetY + (cropHeightBeforeRtn - 1); + } else if (rotateCWDegrees == 180) { + cropXMult = -1; + cropXAdd = offsetX + (cropWidthBeforeRtn - 1); + cropYMult = -1; + cropYAdd = offsetY + (cropHeightBeforeRtn - 1); + } else if (rotateCWDegrees == 270) { + cropXMult = -1; + cropXAdd = offsetX + (cropWidthBeforeRtn - 1); + } + + float normMeanRm255 = 255 * normMeanRGB[0]; + float normMeanGm255 = 255 * normMeanRGB[1]; + float normMeanBm255 = 255 * normMeanRGB[2]; + float normStdRm255 = 255 * normStdRGB[0]; + float normStdGm255 = 255 * normStdRGB[1]; + float normStdBm255 = 255 * normStdRGB[2]; + + int xBeforeRtn, yBeforeRtn; + int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi; + int 
channelSize = tensorWidth * tensorHeight; + // A bit of code duplication to avoid branching in the cycles + if (memoryFormatCode == kMemoryFormatContiguous) { + int wr = outOffset; + int wg = wr + channelSize; + int wb = wg + channelSize; + for (int y = 0; y < tensorHeight; y++) { + for (int x = 0; x < tensorWidth; x++) { + xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale); + yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale); + yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride; + uvIdx = + (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride; + ui = uData[uvIdx]; + vi = vData[uvIdx]; + yi = yData[yIdx]; + yi = (yi - 16) < 0 ? 0 : (yi - 16); + ui -= 128; + vi -= 128; + a0 = 1192 * yi; + ri = (a0 + 1634 * vi) >> 10; + gi = (a0 - 833 * vi - 400 * ui) >> 10; + bi = (a0 + 2066 * ui) >> 10; + ri = ri > 255 ? 255 : ri < 0 ? 0 : ri; + gi = gi > 255 ? 255 : gi < 0 ? 0 : gi; + bi = bi > 255 ? 255 : bi < 0 ? 0 : bi; + outData[wr++] = (ri - normMeanRm255) / normStdRm255; + outData[wg++] = (gi - normMeanGm255) / normStdGm255; + outData[wb++] = (bi - normMeanBm255) / normStdBm255; + } + } + } else if (memoryFormatCode == kMemoryFormatChannelsLast) { + int wc = outOffset; + for (int y = 0; y < tensorHeight; y++) { + for (int x = 0; x < tensorWidth; x++) { + xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale); + yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale); + yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride; + uvIdx = + (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride; + ui = uData[uvIdx]; + vi = vData[uvIdx]; + yi = yData[yIdx]; + yi = (yi - 16) < 0 ? 0 : (yi - 16); + ui -= 128; + vi -= 128; + a0 = 1192 * yi; + ri = (a0 + 1634 * vi) >> 10; + gi = (a0 - 833 * vi - 400 * ui) >> 10; + bi = (a0 + 2066 * ui) >> 10; + ri = ri > 255 ? 255 : ri < 0 ? 0 : ri; + gi = gi > 255 ? 255 : gi < 0 ? 0 : gi; + bi = bi > 255 ? 255 : bi < 0 ? 
0 : bi; + outData[wc++] = (ri - normMeanRm255) / normStdRm255; + outData[wc++] = (gi - normMeanGm255) / normStdGm255; + outData[wc++] = (bi - normMeanBm255) / normStdBm255; + } + } + } else { + jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException"); + jniEnv->ThrowNew(Exception, "Illegal memory format code"); + } +} +} // namespace pytorch_vision_jni + +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + jclass c = + env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer"); + if (c == nullptr) { + return JNI_ERR; + } + + static const JNINativeMethod methods[] = { + {"imageYUV420CenterCropToFloatBuffer", + "(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;II)V", + (void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer}, + }; + int rc = env->RegisterNatives( + c, methods, sizeof(methods) / sizeof(JNINativeMethod)); + + if (rc != JNI_OK) { + return rc; + } + + return JNI_VERSION_1_6; +} diff --git a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java new file mode 100644 index 000000000000..405a474e03a5 --- /dev/null +++ b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java @@ -0,0 +1,398 @@ +package org.pytorch.torchvision; + +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.media.Image; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.Locale; +import org.pytorch.MemoryFormat; +import org.pytorch.Tensor; + +/** + * Contains utility functions for {@link org.pytorch.Tensor} creation from {@link + 
* android.graphics.Bitmap} or {@link android.media.Image} source. + */ +public final class TensorImageUtils { + + public static float[] TORCHVISION_NORM_MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f}; + public static float[] TORCHVISION_NORM_STD_RGB = new float[] {0.229f, 0.224f, 0.225f}; + + /** + * Creates new {@link org.pytorch.Tensor} from full {@link android.graphics.Bitmap}, normalized + * with specified in parameters mean and std. + * + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, + final float[] normMeanRGB, + final float normStdRGB[], + final MemoryFormat memoryFormat) { + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + + return bitmapToFloat32Tensor( + bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, memoryFormat); + } + + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) { + return bitmapToFloat32Tensor( + bitmap, + 0, + 0, + bitmap.getWidth(), + bitmap.getHeight(), + normMeanRGB, + normStdRGB, + MemoryFormat.CONTIGUOUS); + } + + /** + * Writes tensor content from specified {@link android.graphics.Bitmap}, normalized with specified + * in parameters mean and std to specified {@link java.nio.FloatBuffer} with specified offset. 
+ * + * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static void bitmapToFloatBuffer( + final Bitmap bitmap, + final int x, + final int y, + final int width, + final int height, + final float[] normMeanRGB, + final float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset, + final MemoryFormat memoryFormat) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height); + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + if (memoryFormat != MemoryFormat.CONTIGUOUS && memoryFormat != MemoryFormat.CHANNELS_LAST) { + throw new IllegalArgumentException("Unsupported memory format " + memoryFormat); + } + + final int pixelsCount = height * width; + final int[] pixels = new int[pixelsCount]; + bitmap.getPixels(pixels, 0, width, x, y, width, height); + if (MemoryFormat.CONTIGUOUS == memoryFormat) { + final int offset_g = pixelsCount; + final int offset_b = 2 * pixelsCount; + for (int i = 0; i < pixelsCount; i++) { + final int c = pixels[i]; + float r = ((c >> 16) & 0xff) / 255.0f; + float g = ((c >> 8) & 0xff) / 255.0f; + float b = ((c) & 0xff) / 255.0f; + outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]); + outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]); + outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]); + } + } else { + for (int i = 0; i < pixelsCount; i++) { + final int c = pixels[i]; + float r = ((c >> 16) & 0xff) / 255.0f; + float g = ((c >> 8) & 0xff) / 255.0f; + float b = ((c) & 
0xff) / 255.0f; + outBuffer.put(outBufferOffset + 3 * i + 0, (r - normMeanRGB[0]) / normStdRGB[0]); + outBuffer.put(outBufferOffset + 3 * i + 1, (g - normMeanRGB[1]) / normStdRGB[1]); + outBuffer.put(outBufferOffset + 3 * i + 2, (b - normMeanRGB[2]) / normStdRGB[2]); + } + } + } + + public static void bitmapToFloatBuffer( + final Bitmap bitmap, + final int x, + final int y, + final int width, + final int height, + final float[] normMeanRGB, + final float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset) { + bitmapToFloatBuffer( + bitmap, + x, + y, + width, + height, + normMeanRGB, + normStdRGB, + outBuffer, + outBufferOffset, + MemoryFormat.CONTIGUOUS); + } + + /** + * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap}, + * normalized with specified in parameters mean and std. + * + * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, + int x, + int y, + int width, + int height, + float[] normMeanRGB, + float[] normStdRGB, + MemoryFormat memoryFormat) { + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + + final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height); + bitmapToFloatBuffer( + bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0, memoryFormat); + return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width}, memoryFormat); + } + + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, + int x, + int y, + int width, + 
int height, + float[] normMeanRGB, + float[] normStdRGB) { + return bitmapToFloat32Tensor( + bitmap, x, y, width, height, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS); + } + + /** + * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.media.Image}, + * doing optional rotation, scaling (nearest) and center cropping. + * + * @param image {@link android.media.Image} as a source for Tensor data + * @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be + * upright. Range of valid values: 0, 90, 180, 270 + * @param tensorWidth return tensor width, must be positive + * @param tensorHeight return tensor height, must be positive + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static Tensor imageYUV420CenterCropToFloat32Tensor( + final Image image, + int rotateCWDegrees, + final int tensorWidth, + final int tensorHeight, + float[] normMeanRGB, + float[] normStdRGB, + MemoryFormat memoryFormat) { + if (image.getFormat() != ImageFormat.YUV_420_888) { + throw new IllegalArgumentException( + String.format( + Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat())); + } + + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + checkRotateCWDegrees(rotateCWDegrees); + checkTensorSize(tensorWidth, tensorHeight); + + final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight); + imageYUV420CenterCropToFloatBuffer( + image, + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, + normStdRGB, + floatBuffer, + 0, + memoryFormat); + return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth}, memoryFormat); + } + + public static Tensor imageYUV420CenterCropToFloat32Tensor( + final Image image, + int rotateCWDegrees, + final int tensorWidth, + final int 
tensorHeight, + float[] normMeanRGB, + float[] normStdRGB) { + return imageYUV420CenterCropToFloat32Tensor( + image, + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, + normStdRGB, + MemoryFormat.CONTIGUOUS); + } + + /** + * Writes tensor content from specified {@link android.media.Image}, doing optional rotation, + * scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified + * offset. + * + * @param image {@link android.media.Image} as a source for Tensor data + * @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be + * upright. Range of valid values: 0, 90, 180, 270 + * @param tensorWidth return tensor width, must be positive + * @param tensorHeight return tensor height, must be positive + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + * @param outBuffer Output buffer, where tensor content will be written + * @param outBufferOffset Output buffer offset with which tensor content will be written + */ + public static void imageYUV420CenterCropToFloatBuffer( + final Image image, + int rotateCWDegrees, + final int tensorWidth, + final int tensorHeight, + float[] normMeanRGB, + float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset, + final MemoryFormat memoryFormat) { + checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight); + + if (image.getFormat() != ImageFormat.YUV_420_888) { + throw new IllegalArgumentException( + String.format( + Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat())); + } + + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + checkRotateCWDegrees(rotateCWDegrees); + checkTensorSize(tensorWidth, tensorHeight); + + Image.Plane[] planes = image.getPlanes(); + Image.Plane Y = planes[0]; + Image.Plane U = planes[1]; + 
Image.Plane V = planes[2]; + + int memoryFormatJniCode = 0; + if (MemoryFormat.CONTIGUOUS == memoryFormat) { + memoryFormatJniCode = 1; + } else if (MemoryFormat.CHANNELS_LAST == memoryFormat) { + memoryFormatJniCode = 2; + } + + NativePeer.imageYUV420CenterCropToFloatBuffer( + Y.getBuffer(), + Y.getRowStride(), + Y.getPixelStride(), + U.getBuffer(), + V.getBuffer(), + U.getRowStride(), + U.getPixelStride(), + image.getWidth(), + image.getHeight(), + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, + normStdRGB, + outBuffer, + outBufferOffset, + memoryFormatJniCode); + } + + public static void imageYUV420CenterCropToFloatBuffer( + final Image image, + int rotateCWDegrees, + final int tensorWidth, + final int tensorHeight, + float[] normMeanRGB, + float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset) { + imageYUV420CenterCropToFloatBuffer( + image, + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, + normStdRGB, + outBuffer, + outBufferOffset, + MemoryFormat.CONTIGUOUS); + } + + private static class NativePeer { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_vision_jni"); + } + + private static native void imageYUV420CenterCropToFloatBuffer( + ByteBuffer yBuffer, + int yRowStride, + int yPixelStride, + ByteBuffer uBuffer, + ByteBuffer vBuffer, + int uvRowStride, + int uvPixelStride, + int imageWidth, + int imageHeight, + int rotateCWDegrees, + int tensorWidth, + int tensorHeight, + float[] normMeanRgb, + float[] normStdRgb, + Buffer outBuffer, + int outBufferOffset, + int memoryFormatJniCode); + } + + private static void checkOutBufferCapacity( + FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) { + if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) { + throw new IllegalStateException("Buffer underflow"); + } + } + + private static void checkTensorSize(int tensorWidth, int 
tensorHeight) { + if (tensorHeight <= 0 || tensorWidth <= 0) { + throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive"); + } + } + + private static void checkRotateCWDegrees(int rotateCWDegrees) { + if (rotateCWDegrees != 0 + && rotateCWDegrees != 90 + && rotateCWDegrees != 180 + && rotateCWDegrees != 270) { + throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270"); + } + } + + private static void checkNormStdArg(float[] normStdRGB) { + if (normStdRGB.length != 3) { + throw new IllegalArgumentException("normStdRGB length must be 3"); + } + } + + private static void checkNormMeanArg(float[] normMeanRGB) { + if (normMeanRGB.length != 3) { + throw new IllegalArgumentException("normMeanRGB length must be 3"); + } + } +} diff --git a/android/pytorch_android_torchvision/src/main/res/values/strings.xml b/android/pytorch_android_torchvision/src/main/res/values/strings.xml new file mode 100644 index 000000000000..8a9d24ca3536 --- /dev/null +++ b/android/pytorch_android_torchvision/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + pytorch_android_torchvision + diff --git a/android/run_tests.sh b/android/run_tests.sh new file mode 100755 index 000000000000..c8c897761493 --- /dev/null +++ b/android/run_tests.sh @@ -0,0 +1,57 @@ +#!/bin/bash +set -eux + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" +PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android + +source "$PYTORCH_ANDROID_DIR/common.sh" + +check_android_sdk +check_gradle + +# Run android instrumented tests on x86 emulator + +ADB_PATH=$ANDROID_HOME/platform-tools/adb + +echo "Expecting running emulator" +$ADB_PATH devices + +DEVICES_COUNT=$($ADB_PATH devices | awk 'NF' | wc -l) +echo "DEVICES_COUNT:$DEVICES_COUNT" + +if [ "$DEVICES_COUNT" -eq 1 ]; then + echo "Unable to found connected android emulators" +cat <<- EOF + To start android emulator: + 1. 
Install android sdkmanager packages + $ANDROID_HOME/tools/bin/sdkmanager "system-images;android-25;google_apis;x86" + + to specify proxy add params: --proxy=http --proxy_host=fwdproxy --proxy_port=8080 + + 2. Create android virtual device + $ANDROID_HOME/tools/bin/avdmanager create avd --name "x86_android25" --package "system-images;android-25;google_apis;x86" + + 3. Start emulator in headless mode without audio + $ANDROID_HOME/tools/emulator -avd x86_android25 -no-audio -no-window + + 4. Check that emulator is running + $ANDROID_HOME/platform-tools/adb devices + + If everything is ok the output will be: + + List of devices attached + emulator-5554 device +EOF + exit 1 +fi + +echo "Waiting for emulator boot completed" +$ADB_PATH wait-for-device shell 'while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done;' + +{ + # The test currently takes about 10 minutes + retry $GRADLE_PATH -PABI_FILTERS=x86 -p $PYTORCH_ANDROID_DIR connectedAndroidTest +} || { + echo "::error::Check https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test to see how to fix the failed mobile test" + exit 1 +} diff --git a/android/settings.gradle b/android/settings.gradle new file mode 100644 index 000000000000..743f388b6507 --- /dev/null +++ b/android/settings.gradle @@ -0,0 +1,6 @@ +include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_host', ':test_app' + +project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision') + +project(':pytorch_host').projectDir = file('pytorch_android/host') +project(':test_app').projectDir = file('test_app/app') diff --git a/android/test_app/.gitignore b/android/test_app/.gitignore new file mode 100644 index 000000000000..b90a745d805c --- /dev/null +++ b/android/test_app/.gitignore @@ -0,0 +1,9 @@ +local.properties +**/*.iml +.gradle +gradlew* +gradle/wrapper +.idea/* +.DS_Store +build +.externalNativeBuild diff --git a/android/test_app/app/CMakeLists.txt b/android/test_app/app/CMakeLists.txt 
new file mode 100644 index 000000000000..c54f49f53a1d --- /dev/null +++ b/android/test_app/app/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.5) +set(PROJECT_NAME pytorch_testapp_jni) +project(${PROJECT_NAME} CXX) +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(build_DIR ${CMAKE_SOURCE_DIR}/build) + +set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) +message(STATUS "ANDROID_STL:${ANDROID_STL}") +file(GLOB pytorch_testapp_SOURCES + ${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp +) + +add_library(${PROJECT_NAME} SHARED + ${pytorch_testapp_SOURCES} +) + +file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") +file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") + +target_compile_options(${PROJECT_NAME} PRIVATE + -fexceptions +) + +set(BUILD_SUBDIR ${ANDROID_ABI}) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${PYTORCH_INCLUDE_DIRS} +) + +find_library(PYTORCH_LIBRARY pytorch_jni + PATHS ${PYTORCH_LINK_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +target_link_libraries(${PROJECT_NAME} + ${PYTORCH_LIBRARY} + log) diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle new file mode 100644 index 000000000000..963ff1a7a426 --- /dev/null +++ b/android/test_app/app/build.gradle @@ -0,0 +1,190 @@ +apply plugin: 'com.android.application' + +repositories { + jcenter() + maven { + url "https://oss.sonatype.org/content/repositories/snapshots" + } + flatDir { + dirs 'aars' + } +} + +android { + configurations { + extractForNativeBuild + } + compileOptions { + sourceCompatibility 1.8 + targetCompatibility 1.8 + } + compileSdkVersion rootProject.compileSdkVersion + buildToolsVersion rootProject.buildToolsVersion + defaultConfig { + applicationId "org.pytorch.testapp" + minSdkVersion rootProject.minSdkVersion + targetSdkVersion rootProject.targetSdkVersion + 
versionCode 1 + versionName "1.0" + ndk { + abiFilters ABI_FILTERS.split(",") + } + // Commented due to dependency on local copy of pytorch_android aar to aars folder + //externalNativeBuild { + // cmake { + // abiFilters ABI_FILTERS.split(",") + // arguments "-DANDROID_STL=c++_shared" + // } + //} + buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"") + buildConfigField("String", "LOGCAT_TAG", "@string/app_name") + buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}") + buildConfigField("boolean", "NATIVE_BUILD", 'false') + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false') + buildConfigField( + "int", + "BUILD_LITE_INTERPRETER", + System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1" + ) + addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"]) + } + buildTypes { + debug { + minifyEnabled false + debuggable true + } + release { + minifyEnabled false + } + } + // Commented due to dependency on local copy of pytorch_android aar to aars folder + //externalNativeBuild { + // cmake { + // path "CMakeLists.txt" + // } + //} + flavorDimensions "model", "build", "activity" + productFlavors { + mnet { + dimension "model" + applicationIdSuffix ".mnet" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"") + addManifestPlaceholders([APP_NAME: "MNET"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"") + } + // NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966 + mnetVulkan { + dimension "model" + applicationIdSuffix ".mnet_vulkan" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"") + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true') + addManifestPlaceholders([APP_NAME: "MNET_VULKAN"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"") + } + resnet18 { + dimension "model" + applicationIdSuffix ".resnet18" + buildConfigField("String", 
"MODULE_ASSET_NAME", "\"resnet18.ptl\"") + addManifestPlaceholders([APP_NAME: "RN18"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"") + } + local { + dimension "build" + } + nightly { + dimension "build" + } + aar { + dimension "build" + } + // Commented due to dependency on local copy of pytorch_android aar to aars folder + //nativeBuild { + // dimension "build" + // buildConfigField("boolean", "NATIVE_BUILD", "true") + //} + camera { + dimension "activity" + addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"]) + } + base { + dimension "activity" + sourceSets { + main { + java { + exclude 'org/pytorch/testapp/CameraActivity.java' + } + } + } + } + } + packagingOptions { + doNotStrip '**.so' + } + + // Filtering for CI + if (!testAppAllVariantsEnabled.toBoolean()) { + variantFilter { variant -> + def names = variant.flavors*.name + if (names.contains("nightly") + || names.contains("camera") + || names.contains("aar") + || names.contains("nativeBuild")) { + setIgnore(true) + } + } + } +} + +tasks.all { task -> + // Disable externalNativeBuild for all but nativeBuild variant + if (task.name.startsWith('externalNativeBuild') + && !task.name.contains('NativeBuild')) { + task.enabled = false + } +} + +dependencies { + implementation 'com.android.support:appcompat-v7:28.0.0' + implementation 'com.facebook.soloader:nativeloader:0.10.5' + + localImplementation project(':pytorch_android') + localImplementation project(':pytorch_android_torchvision') + + // Commented due to dependency on local copy of pytorch_android aar to aars folder + //nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar') + //nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar') + //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar') + + nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT' + nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT' + + 
aarImplementation(name:'pytorch_android', ext:'aar') + aarImplementation(name:'pytorch_android_torchvision', ext:'aar') + aarImplementation 'com.facebook.soloader:nativeloader:0.10.5' + aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' + + def camerax_version = "1.0.0-alpha05" + cameraImplementation "androidx.camera:camera-core:$camerax_version" + cameraImplementation "androidx.camera:camera-camera2:$camerax_version" + cameraImplementation 'com.google.android.material:material:1.0.0-beta01' +} + +task extractAARForNativeBuild { + doLast { + configurations.extractForNativeBuild.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "headers/**" + include "jni/**" + } + } + } +} + +tasks.whenTaskAdded { task -> + if (task.name.contains('externalNativeBuild')) { + task.dependsOn(extractAARForNativeBuild) + } +} diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml new file mode 100644 index 000000000000..abdd9a8d986a --- /dev/null +++ b/android/test_app/app/src/main/AndroidManifest.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/android/test_app/app/src/main/assets/.gitignore b/android/test_app/app/src/main/assets/.gitignore new file mode 100644 index 000000000000..94548af5beba --- /dev/null +++ b/android/test_app/app/src/main/assets/.gitignore @@ -0,0 +1,3 @@ +* +*/ +!.gitignore diff --git a/android/test_app/app/src/main/cpp/pytorch_testapp_jni.cpp b/android/test_app/app/src/main/cpp/pytorch_testapp_jni.cpp new file mode 100644 index 000000000000..d282099a283a --- /dev/null +++ b/android/test_app/app/src/main/cpp/pytorch_testapp_jni.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include +#include +#include +#define ALOGI(...) \ + __android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__) +#define ALOGE(...) 
\ + __android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__) + +#include "jni.h" + +#include + +namespace pytorch_testapp_jni { +namespace { + +template +void log(const char* m, T t) { + std::ostringstream os; + os << t << std::endl; + ALOGI("%s %s", m, os.str().c_str()); +} + +struct JITCallGuard { + c10::InferenceMode guard; + torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false}; +}; +} // namespace + +static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) { + const char* modelPath = env->GetStringUTFChars(jModelPath, 0); + assert(modelPath); + + // To load torchscript model for mobile we need set these guards, + // because mobile build doesn't support features like autograd for smaller + // build size which is placed in `struct JITCallGuard` in this example. It may + // change in future, you can track the latest changes keeping an eye in + // android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp + JITCallGuard guard; + torch::jit::Module module = torch::jit::load(modelPath); + module.eval(); + torch::Tensor t = torch::randn({1, 3, 224, 224}); + log("input tensor:", t); + c10::IValue t_out = module.forward({t}); + log("output tensor:", t_out); + env->ReleaseStringUTFChars(jModelPath, modelPath); +} +} // namespace pytorch_testapp_jni + +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + jclass c = + env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer"); + if (c == nullptr) { + return JNI_ERR; + } + + static const JNINativeMethod methods[] = { + {"loadAndForwardModel", + "(Ljava/lang/String;)V", + (void*)pytorch_testapp_jni::loadAndForwardModel}, + }; + int rc = env->RegisterNatives( + c, methods, sizeof(methods) / sizeof(JNINativeMethod)); + + if (rc != JNI_OK) { + return rc; + } + + return JNI_VERSION_1_6; +} diff --git 
a/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java new file mode 100644 index 000000000000..da3b05b7feb1 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/CameraActivity.java @@ -0,0 +1,214 @@ +package org.pytorch.testapp; + +import android.Manifest; +import android.content.pm.PackageManager; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.SystemClock; +import android.util.Log; +import android.util.Size; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.TextView; +import android.widget.Toast; +import androidx.annotation.Nullable; +import androidx.annotation.UiThread; +import androidx.annotation.WorkerThread; +import androidx.appcompat.app.AppCompatActivity; +import androidx.camera.core.CameraX; +import androidx.camera.core.ImageAnalysis; +import androidx.camera.core.ImageAnalysisConfig; +import androidx.camera.core.ImageProxy; +import androidx.camera.core.Preview; +import androidx.camera.core.PreviewConfig; +import androidx.core.app.ActivityCompat; +import java.nio.FloatBuffer; +import org.pytorch.IValue; +import org.pytorch.MemoryFormat; +import org.pytorch.Module; +import org.pytorch.PyTorchAndroid; +import org.pytorch.Tensor; +import org.pytorch.torchvision.TensorImageUtils; + +public class CameraActivity extends AppCompatActivity { + private static final String TAG = BuildConfig.LOGCAT_TAG; + private static final int TEXT_TRIM_SIZE = 4096; + + private static final int REQUEST_CODE_CAMERA_PERMISSION = 200; + private static final String[] PERMISSIONS = {Manifest.permission.CAMERA}; + + private long mLastAnalysisResultTime; + + protected HandlerThread mBackgroundThread; + protected Handler mBackgroundHandler; + protected Handler mUIHandler; + + private TextView mTextView; + private StringBuilder mTextViewStringBuilder = new StringBuilder(); + 
+ @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + mTextView = findViewById(R.id.text); + mUIHandler = new Handler(getMainLooper()); + startBackgroundThread(); + + if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION); + } else { + setupCameraX(); + } + } + + @Override + protected void onPostCreate(@Nullable Bundle savedInstanceState) { + super.onPostCreate(savedInstanceState); + startBackgroundThread(); + } + + protected void startBackgroundThread() { + mBackgroundThread = new HandlerThread("ModuleActivity"); + mBackgroundThread.start(); + mBackgroundHandler = new Handler(mBackgroundThread.getLooper()); + } + + @Override + protected void onDestroy() { + stopBackgroundThread(); + super.onDestroy(); + } + + protected void stopBackgroundThread() { + mBackgroundThread.quitSafely(); + try { + mBackgroundThread.join(); + mBackgroundThread = null; + mBackgroundHandler = null; + } catch (InterruptedException e) { + Log.e(TAG, "Error on stopping background thread", e); + } + } + + @Override + public void onRequestPermissionsResult( + int requestCode, String[] permissions, int[] grantResults) { + if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) { + if (grantResults[0] == PackageManager.PERMISSION_DENIED) { + Toast.makeText( + this, + "You can't use image classification example without granting CAMERA permission", + Toast.LENGTH_LONG) + .show(); + finish(); + } else { + setupCameraX(); + } + } + } + + private static final int TENSOR_WIDTH = 224; + private static final int TENSOR_HEIGHT = 224; + + private void setupCameraX() { + final TextureView textureView = + ((ViewStub) findViewById(R.id.camera_texture_view_stub)) + .inflate() + .findViewById(R.id.texture_view); + final PreviewConfig previewConfig = new 
PreviewConfig.Builder().build(); + final Preview preview = new Preview(previewConfig); + preview.setOnPreviewOutputUpdateListener( + new Preview.OnPreviewOutputUpdateListener() { + @Override + public void onUpdated(Preview.PreviewOutput output) { + textureView.setSurfaceTexture(output.getSurfaceTexture()); + } + }); + + final ImageAnalysisConfig imageAnalysisConfig = + new ImageAnalysisConfig.Builder() + .setTargetResolution(new Size(TENSOR_WIDTH, TENSOR_HEIGHT)) + .setCallbackHandler(mBackgroundHandler) + .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE) + .build(); + final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig); + imageAnalysis.setAnalyzer( + new ImageAnalysis.Analyzer() { + @Override + public void analyze(ImageProxy image, int rotationDegrees) { + if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) { + return; + } + + final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees); + + if (result != null) { + mLastAnalysisResultTime = SystemClock.elapsedRealtime(); + CameraActivity.this.runOnUiThread( + new Runnable() { + @Override + public void run() { + CameraActivity.this.handleResult(result); + } + }); + } + } + }); + + CameraX.bindToLifecycle(this, preview, imageAnalysis); + } + + private Module mModule; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; + + @WorkerThread + @Nullable + protected Result analyzeImage(ImageProxy image, int rotationDegrees) { + Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees)); + if (mModule == null) { + Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'"); + mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); + mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT); + mInputTensor = + Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT}); + } + + final long startTime = 
SystemClock.elapsedRealtime(); + TensorImageUtils.imageYUV420CenterCropToFloatBuffer( + image.getImage(), + rotationDegrees, + TENSOR_WIDTH, + TENSOR_HEIGHT, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB, + mInputTensorBuffer, + 0, + MemoryFormat.CHANNELS_LAST); + final long moduleForwardStartTime = SystemClock.elapsedRealtime(); + final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor(); + final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; + + final float[] scores = outputTensor.getDataAsFloatArray(); + final long analysisDuration = SystemClock.elapsedRealtime() - startTime; + + return new Result(scores, moduleForwardDuration, analysisDuration); + } + + @UiThread + protected void handleResult(Result result) { + int ixs[] = Utils.topK(result.scores, 1); + String message = + String.format( + "forwardDuration:%d class:%s", + result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]); + Log.i(TAG, message); + mTextViewStringBuilder.insert(0, '\n').insert(0, message); + if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) { + mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length()); + } + mTextView.setText(mTextViewStringBuilder.toString()); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/Constants.java b/android/test_app/app/src/main/java/org/pytorch/testapp/Constants.java new file mode 100644 index 000000000000..db60f0b02845 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/Constants.java @@ -0,0 +1,1009 @@ +package org.pytorch.testapp; + +public class Constants { + public static final String TAG = "PyTorchDemo"; + + public static String[] IMAGENET_CLASSES = + new String[] { + "tench, Tinca tinca", + "goldfish, Carassius auratus", + "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "tiger shark, Galeocerdo cuvieri", + "hammerhead, hammerhead 
shark", + "electric ray, crampfish, numbfish, torpedo", + "stingray", + "cock", + "hen", + "ostrich, Struthio camelus", + "brambling, Fringilla montifringilla", + "goldfinch, Carduelis carduelis", + "house finch, linnet, Carpodacus mexicanus", + "junco, snowbird", + "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "robin, American robin, Turdus migratorius", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel, dipper", + "kite", + "bald eagle, American eagle, Haliaeetus leucocephalus", + "vulture", + "great grey owl, great gray owl, Strix nebulosa", + "European fire salamander, Salamandra salamandra", + "common newt, Triturus vulgaris", + "eft", + "spotted salamander, Ambystoma maculatum", + "axolotl, mud puppy, Ambystoma mexicanum", + "bullfrog, Rana catesbeiana", + "tree frog, tree-frog", + "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "loggerhead, loggerhead turtle, Caretta caretta", + "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "mud turtle", + "terrapin", + "box turtle, box tortoise", + "banded gecko", + "common iguana, iguana, Iguana iguana", + "American chameleon, anole, Anolis carolinensis", + "whiptail, whiptail lizard", + "agama", + "frilled lizard, Chlamydosaurus kingi", + "alligator lizard", + "Gila monster, Heloderma suspectum", + "green lizard, Lacerta viridis", + "African chameleon, Chamaeleo chamaeleon", + "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "African crocodile, Nile crocodile, Crocodylus niloticus", + "American alligator, Alligator mississipiensis", + "triceratops", + "thunder snake, worm snake, Carphophis amoenus", + "ringneck snake, ring-necked snake, ring snake", + "hognose snake, puff adder, sand viper", + "green snake, grass snake", + "king snake, kingsnake", + "garter snake, grass snake", + "water snake", + "vine snake", + "night snake, Hypsiglena torquata", + "boa constrictor, Constrictor constrictor", + "rock 
python, rock snake, Python sebae", + "Indian cobra, Naja naja", + "green mamba", + "sea snake", + "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "sidewinder, horned rattlesnake, Crotalus cerastes", + "trilobite", + "harvestman, daddy longlegs, Phalangium opilio", + "scorpion", + "black and gold garden spider, Argiope aurantia", + "barn spider, Araneus cavaticus", + "garden spider, Aranea diademata", + "black widow, Latrodectus mactans", + "tarantula", + "wolf spider, hunting spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse, partridge, Bonasa umbellus", + "prairie chicken, prairie grouse, prairie fowl", + "peacock", + "quail", + "partridge", + "African grey, African gray, Psittacus erithacus", + "macaw", + "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser, Mergus serrator", + "goose", + "black swan, Cygnus atratus", + "tusker", + "echidna, spiny anteater, anteater", + "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "wallaby, brush kangaroo", + "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "wombat", + "jellyfish", + "sea anemone, anemone", + "brain coral", + "flatworm, platyhelminth", + "nematode, nematode worm, roundworm", + "conch", + "snail", + "slug", + "sea slug, nudibranch", + "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "chambered nautilus, pearly nautilus, nautilus", + "Dungeness crab, Cancer magister", + "rock crab, Cancer irroratus", + "fiddler crab", + "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "crayfish, 
crawfish, crawdad, crawdaddy", + "hermit crab", + "isopod", + "white stork, Ciconia ciconia", + "black stork, Ciconia nigra", + "spoonbill", + "flamingo", + "little blue heron, Egretta caerulea", + "American egret, great white heron, Egretta albus", + "bittern", + "crane", + "limpkin, Aramus pictus", + "European gallinule, Porphyrio porphyrio", + "American coot, marsh hen, mud hen, water hen, Fulica americana", + "bustard", + "ruddy turnstone, Arenaria interpres", + "red-backed sandpiper, dunlin, Erolia alpina", + "redshank, Tringa totanus", + "dowitcher", + "oystercatcher, oyster catcher", + "pelican", + "king penguin, Aptenodytes patagonica", + "albatross, mollymawk", + "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "dugong, Dugong dugon", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog, Maltese terrier, Maltese", + "Pekinese, Pekingese, Peke", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound, Afghan", + "basset, basset hound", + "beagle", + "bloodhound, sleuthhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound, Walker foxhound", + "English foxhound", + "redbone", + "borzoi, Russian wolfhound", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound, Ibizan Podenco", + "Norwegian elkhound, elkhound", + "otterhound, otter hound", + "Saluki, gazelle hound", + "Scottish deerhound, deerhound", + "Weimaraner", + "Staffordshire bullterrier, Staffordshire bull terrier", + "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier, Sealyham", + "Airedale, Airedale terrier", + "cairn, cairn terrier", + 
"Australian terrier", + "Dandie Dinmont, Dandie Dinmont terrier", + "Boston bull, Boston terrier", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier, Scottish terrier, Scottie", + "Tibetan terrier, chrysanthemum dog", + "silky terrier, Sydney silky", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa, Lhasa apso", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla, Hungarian pointer", + "English setter", + "Irish setter, red setter", + "Gordon setter", + "Brittany spaniel", + "clumber, clumber spaniel", + "English springer, English springer spaniel", + "Welsh springer spaniel", + "cocker spaniel, English cocker spaniel, cocker", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog, bobtail", + "Shetland sheepdog, Shetland sheep dog, Shetland", + "collie", + "Border collie", + "Bouvier des Flandres, Bouviers des Flandres", + "Rottweiler", + "German shepherd, German shepherd dog, German police dog, alsatian", + "Doberman, Doberman pinscher", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard, St Bernard", + "Eskimo dog, husky", + "malamute, malemute, Alaskan malamute", + "Siberian husky", + "dalmatian, coach dog, carriage dog", + "affenpinscher, monkey pinscher, monkey dog", + "basenji", + "pug, pug-dog", + "Leonberg", + "Newfoundland, Newfoundland dog", + "Great Pyrenees", + "Samoyed, Samoyede", + "Pomeranian", + "chow, chow chow", + "keeshond", + "Brabancon griffon", + "Pembroke, Pembroke Welsh corgi", + "Cardigan, Cardigan Welsh corgi", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican 
hairless", + "timber wolf, grey wolf, gray wolf, Canis lupus", + "white wolf, Arctic wolf, Canis lupus tundrarum", + "red wolf, maned wolf, Canis rufus, Canis niger", + "coyote, prairie wolf, brush wolf, Canis latrans", + "dingo, warrigal, warragal, Canis dingo", + "dhole, Cuon alpinus", + "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "hyena, hyaena", + "red fox, Vulpes vulpes", + "kit fox, Vulpes macrotis", + "Arctic fox, white fox, Alopex lagopus", + "grey fox, gray fox, Urocyon cinereoargenteus", + "tabby, tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat, Siamese", + "Egyptian cat", + "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "lynx, catamount", + "leopard, Panthera pardus", + "snow leopard, ounce, Panthera uncia", + "jaguar, panther, Panthera onca, Felis onca", + "lion, king of beasts, Panthera leo", + "tiger, Panthera tigris", + "cheetah, chetah, Acinonyx jubatus", + "brown bear, bruin, Ursus arctos", + "American black bear, black bear, Ursus americanus, Euarctos americanus", + "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "sloth bear, Melursus ursinus, Ursus ursinus", + "mongoose", + "meerkat, mierkat", + "tiger beetle", + "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "ground beetle, carabid beetle", + "long-horned beetle, longicorn, longicorn beetle", + "leaf beetle, chrysomelid", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant, emmet, pismire", + "grasshopper, hopper", + "cricket", + "walking stick, walkingstick, stick insect", + "cockroach, roach", + "mantis, mantid", + "cicada, cicala", + "leafhopper", + "lacewing, lacewing fly", + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "damselfly", + "admiral", + "ringlet, ringlet butterfly", + "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "cabbage butterfly", + "sulphur 
butterfly, sulfur butterfly", + "lycaenid, lycaenid butterfly", + "starfish, sea star", + "sea urchin", + "sea cucumber, holothurian", + "wood rabbit, cottontail, cottontail rabbit", + "hare", + "Angora, Angora rabbit", + "hamster", + "porcupine, hedgehog", + "fox squirrel, eastern fox squirrel, Sciurus niger", + "marmot", + "beaver", + "guinea pig, Cavia cobaya", + "sorrel", + "zebra", + "hog, pig, grunter, squealer, Sus scrofa", + "wild boar, boar, Sus scrofa", + "warthog", + "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "ox", + "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "bison", + "ram, tup", + "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "ibex, Capra ibex", + "hartebeest", + "impala, Aepyceros melampus", + "gazelle", + "Arabian camel, dromedary, Camelus dromedarius", + "llama", + "weasel", + "mink", + "polecat, fitch, foulmart, foumart, Mustela putorius", + "black-footed ferret, ferret, Mustela nigripes", + "otter", + "skunk, polecat, wood pussy", + "badger", + "armadillo", + "three-toed sloth, ai, Bradypus tridactylus", + "orangutan, orang, orangutang, Pongo pygmaeus", + "gorilla, Gorilla gorilla", + "chimpanzee, chimp, Pan troglodytes", + "gibbon, Hylobates lar", + "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "guenon, guenon monkey", + "patas, hussar monkey, Erythrocebus patas", + "baboon", + "macaque", + "langur", + "colobus, colobus monkey", + "proboscis monkey, Nasalis larvatus", + "marmoset", + "capuchin, ringtail, Cebus capucinus", + "howler monkey, howler", + "titi, titi monkey", + "spider monkey, Ateles geoffroyi", + "squirrel monkey, Saimiri sciureus", + "Madagascar cat, ring-tailed lemur, Lemur catta", + "indri, indris, Indri indri, Indri brevicaudatus", + "Indian elephant, Elephas maximus", + "African elephant, Loxodonta africana", + "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "giant panda, panda, panda bear, coon 
bear, Ailuropoda melanoleuca", + "barracouta, snoek", + "eel", + "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "rock beauty, Holocanthus tricolor", + "anemone fish", + "sturgeon", + "gar, garfish, garpike, billfish, Lepisosteus osseus", + "lionfish", + "puffer, pufferfish, blowfish, globefish", + "abacus", + "abaya", + "academic gown, academic robe, judge's robe", + "accordion, piano accordion, squeeze box", + "acoustic guitar", + "aircraft carrier, carrier, flattop, attack aircraft carrier", + "airliner", + "airship, dirigible", + "altar", + "ambulance", + "amphibian, amphibious vehicle", + "analog clock", + "apiary, bee house", + "apron", + "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "assault rifle, assault gun", + "backpack, back pack, knapsack, packsack, rucksack, haversack", + "bakery, bakeshop, bakehouse", + "balance beam, beam", + "balloon", + "ballpoint, ballpoint pen, ballpen, Biro", + "Band Aid", + "banjo", + "bannister, banister, balustrade, balusters, handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel, cask", + "barrow, garden cart, lawn cart, wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap, swimming cap", + "bath towel", + "bathtub, bathing tub, bath, tub", + "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "beacon, lighthouse, beacon light, pharos", + "beaker", + "bearskin, busby, shako", + "beer bottle", + "beer glass", + "bell cote, bell cot", + "bib", + "bicycle-built-for-two, tandem bicycle, tandem", + "bikini, two-piece", + "binder, ring-binder", + "binoculars, field glasses, opera glasses", + "birdhouse", + "boathouse", + "bobsled, bobsleigh, bob", + "bolo tie, bolo, bola tie, bola", + "bonnet, poke bonnet", + "bookcase", + "bookshop, bookstore, bookstall", + "bottlecap", + "bow", + "bow tie, bow-tie, bowtie", + "brass, memorial 
tablet, plaque", + "brassiere, bra, bandeau", + "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "breastplate, aegis, egis", + "broom", + "bucket, pail", + "buckle", + "bulletproof vest", + "bullet train, bullet", + "butcher shop, meat market", + "cab, hack, taxi, taxicab", + "caldron, cauldron", + "candle, taper, wax light", + "cannon", + "canoe", + "can opener, tin opener", + "cardigan", + "car mirror", + "carousel, carrousel, merry-go-round, roundabout, whirligig", + "carpenter's kit, tool kit", + "carton", + "car wheel", + "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello, violoncello", + "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "chain", + "chainlink fence", + "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "chain saw, chainsaw", + "chest", + "chiffonier, commode", + "chime, bell, gong", + "china cabinet, china closet", + "Christmas stocking", + "church, church building", + "cinema, movie theater, movie theatre, movie house, picture palace", + "cleaver, meat cleaver, chopper", + "cliff dwelling", + "cloak", + "clog, geta, patten, sabot", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil, spiral, volute, whorl, helix", + "combination lock", + "computer keyboard, keypad", + "confectionery, confectionary, candy store", + "container ship, containership, container vessel", + "convertible", + "corkscrew, bottle screw", + "cornet, horn, trumpet, trump", + "cowboy boot", + "cowboy hat, ten-gallon hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib, cot", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam, dike, dyke", + "desk", + "desktop computer", + "dial telephone, dial phone", + "diaper, nappy, napkin", + "digital clock", + "digital watch", + "dining table, board", + "dishrag, dishcloth", + 
"dishwasher, dish washer, dishwashing machine", + "disk brake, disc brake", + "dock, dockage, docking facility", + "dogsled, dog sled, dog sleigh", + "dome", + "doormat, welcome mat", + "drilling platform, offshore rig", + "drum, membranophone, tympan", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan, blower", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa, boa", + "file, file cabinet, filing cabinet", + "fireboat", + "fire engine, fire truck", + "fire screen, fireguard", + "flagpole, flagstaff", + "flute, transverse flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn, horn", + "frying pan, frypan, skillet", + "fur coat", + "garbage truck, dustcart", + "gasmask, respirator, gas helmet", + "gas pump, gasoline pump, petrol pump, island dispenser", + "goblet", + "go-kart", + "golf ball", + "golfcart, golf cart", + "gondola", + "gong, tam-tam", + "gown", + "grand piano, grand", + "greenhouse, nursery, glasshouse", + "grille, radiator grille", + "grocery store, grocery, food market, market", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "hand-held computer, hand-held microcomputer", + "handkerchief, hankie, hanky, hankey", + "hard disc, hard disk, fixed disk", + "harmonica, mouth organ, harp, mouth harp", + "harp", + "harvester, reaper", + "hatchet", + "holster", + "home theater, home theatre", + "honeycomb", + "hook, claw", + "hoopskirt, crinoline", + "horizontal bar, high bar", + "horse cart, horse-cart", + "hourglass", + "iPod", + "iron, smoothing iron", + "jack-o'-lantern", + "jean, blue jean, denim", + "jeep, landrover", + "jersey, T-shirt, tee shirt", + "jigsaw puzzle", + "jinrikisha, ricksha, rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat, 
laboratory coat", + "ladle", + "lampshade, lamp shade", + "laptop, laptop computer", + "lawn mower, mower", + "lens cap, lens cover", + "letter opener, paper knife, paperknife", + "library", + "lifeboat", + "lighter, light, igniter, ignitor", + "limousine, limo", + "liner, ocean liner", + "lipstick, lip rouge", + "Loafer", + "lotion", + "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "loupe, jeweler's loupe", + "lumbermill, sawmill", + "magnetic compass", + "mailbag, postbag", + "mailbox, letter box", + "maillot", + "maillot, tank suit", + "manhole cover", + "maraca", + "marimba, xylophone", + "mask", + "matchstick", + "maypole", + "maze, labyrinth", + "measuring cup", + "medicine chest, medicine cabinet", + "megalith, megalithic structure", + "microphone, mike", + "microwave, microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt, mini", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home, manufactured home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter, scooter", + "mountain bike, all-terrain bike, off-roader", + "mountain tent", + "mouse, computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook, notebook computer", + "obelisk", + "oboe, hautboy, hautbois", + "ocarina, sweet potato", + "odometer, hodometer, mileometer, milometer", + "oil filter", + "organ, pipe organ", + "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle, boat paddle", + "paddlewheel, paddle wheel", + "padlock", + "paintbrush", + "pajama, pyjama, pj's, jammies", + "palace", + "panpipe, pandean pipe, syrinx", + "paper towel", + "parachute, chute", + "parallel bars, bars", + "park bench", + "parking meter", + "passenger car, coach, carriage", + "patio, terrace", + "pay-phone, pay-station", + "pedestal, 
plinth, footstall", + "pencil box, pencil case", + "pencil sharpener", + "perfume, essence", + "Petri dish", + "photocopier", + "pick, plectrum, plectron", + "pickelhaube", + "picket fence, paling", + "pickup, pickup truck", + "pier", + "piggy bank, penny bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate, pirate ship", + "pitcher, ewer", + "plane, carpenter's plane, woodworking plane", + "planetarium", + "plastic bag", + "plate rack", + "plow, plough", + "plunger, plumber's helper", + "Polaroid camera, Polaroid Land camera", + "pole", + "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "poncho", + "pool table, billiard table, snooker table", + "pop bottle, soda bottle", + "pot, flowerpot", + "potter's wheel", + "power drill", + "prayer rug, prayer mat", + "printer", + "prison, prison house", + "projectile, missile", + "projector", + "puck, hockey puck", + "punching bag, punch bag, punching ball, punchball", + "purse", + "quill, quill pen", + "quilt, comforter, comfort, puff", + "racer, race car, racing car", + "racket, racquet", + "radiator", + "radio, wireless", + "radio telescope, radio reflector", + "rain barrel", + "recreational vehicle, RV, R.V.", + "reel", + "reflex camera", + "refrigerator, icebox", + "remote control, remote", + "restaurant, eating house, eating place, eatery", + "revolver, six-gun, six-shooter", + "rifle", + "rocking chair, rocker", + "rotisserie", + "rubber eraser, rubber, pencil eraser", + "rugby ball", + "rule, ruler", + "running shoe", + "safe", + "safety pin", + "saltshaker, salt shaker", + "sandal", + "sarong", + "sax, saxophone", + "scabbard", + "scale, weighing machine", + "school bus", + "schooner", + "scoreboard", + "screen, CRT screen", + "screw", + "screwdriver", + "seat belt, seatbelt", + "sewing machine", + "shield, buckler", + "shoe shop, shoe-shop, shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + 
"ski mask", + "sleeping bag", + "slide rule, slipstick", + "sliding door", + "slot, one-armed bandit", + "snorkel", + "snowmobile", + "snowplow, snowplough", + "soap dispenser", + "soccer ball", + "sock", + "solar dish, solar collector, solar furnace", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web, spider's web", + "spindle", + "sports car, sport car", + "spotlight, spot", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch, stop watch", + "stove", + "strainer", + "streetcar, tram, tramcar, trolley, trolley car", + "stretcher", + "studio couch, day bed", + "stupa, tope", + "submarine, pigboat, sub, U-boat", + "suit, suit of clothes", + "sundial", + "sunglass", + "sunglasses, dark glasses, shades", + "sunscreen, sunblock, sun blocker", + "suspension bridge", + "swab, swob, mop", + "sweatshirt", + "swimming trunks, bathing trunks", + "swing", + "switch, electric switch, electrical switch", + "syringe", + "table lamp", + "tank, army tank, armored combat vehicle, armoured combat vehicle", + "tape player", + "teapot", + "teddy, teddy bear", + "television, television system", + "tennis ball", + "thatch, thatched roof", + "theater curtain, theatre curtain", + "thimble", + "thresher, thrasher, threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop, tobacconist shop, tobacconist", + "toilet seat", + "torch", + "totem pole", + "tow truck, tow car, wrecker", + "toyshop", + "tractor", + "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "tray", + "trench coat", + "tricycle, trike, velocipede", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus, trolley coach, trackless trolley", + "trombone", + "tub, vat", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle, monocycle", + "upright, upright piano", + "vacuum, vacuum cleaner", + "vase", + "vault", + 
"velvet", + "vending machine", + "vestment", + "viaduct", + "violin, fiddle", + "volleyball", + "waffle iron", + "wall clock", + "wallet, billfold, notecase, pocketbook", + "wardrobe, closet, press", + "warplane, military plane", + "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "washer, automatic washer, washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool, woolen, woollen", + "worm fence, snake fence, snake-rail fence, Virginia fence", + "wreck", + "yawl", + "yurt", + "web site, website, internet site, site", + "comic book", + "crossword puzzle, crossword", + "street sign", + "traffic light, traffic signal, stoplight", + "book jacket, dust cover, dust jacket, dust wrapper", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot, hotpot", + "trifle", + "ice cream, icecream", + "ice lolly, lolly, lollipop, popsicle", + "French loaf", + "bagel, beigel", + "pretzel", + "cheeseburger", + "hotdog, hot dog, red hot", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini, courgette", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber, cuke", + "artichoke, globe artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple, ananas", + "banana", + "jackfruit, jak, jack", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce, chocolate syrup", + "dough", + "meat loaf, meatloaf", + "pizza, pizza pie", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff, drop, drop-off", + "coral reef", + "geyser", + "lakeside, lakeshore", + "promontory, headland, head, foreland", + "sandbar, sand bar", + "seashore, coast, seacoast, sea-coast", + "valley, vale", + "volcano", + "ballplayer, baseball player", + 
"groom, bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "corn", + "acorn", + "hip, rose hip, rosehip", + "buckeye, horse chestnut, conker", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn, carrion fungus", + "earthstar", + "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "bolete", + "ear, spike, capitulum", + "toilet tissue, toilet paper, bathroom tissue" + }; +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/LibtorchNativeClient.java b/android/test_app/app/src/main/java/org/pytorch/testapp/LibtorchNativeClient.java new file mode 100644 index 000000000000..6ff2d46166f7 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/LibtorchNativeClient.java @@ -0,0 +1,22 @@ +package org.pytorch.testapp; + +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; + +public final class LibtorchNativeClient { + + public static void loadAndForwardModel(final String modelPath) { + NativePeer.loadAndForwardModel(modelPath); + } + + private static class NativePeer { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("pytorch_testapp_jni"); + } + + private static native void loadAndForwardModel(final String modelPath); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java new file mode 100644 index 000000000000..3d30d9b78ecd --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java @@ -0,0 +1,171 @@ +package org.pytorch.testapp; + +import android.content.Context; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.SystemClock; +import android.util.Log; +import 
android.widget.TextView; +import androidx.annotation.Nullable; +import androidx.annotation.UiThread; +import androidx.annotation.WorkerThread; +import androidx.appcompat.app.AppCompatActivity; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.FloatBuffer; +import org.pytorch.Device; +import org.pytorch.IValue; +import org.pytorch.MemoryFormat; +import org.pytorch.Module; +import org.pytorch.PyTorchAndroid; +import org.pytorch.Tensor; + +public class MainActivity extends AppCompatActivity { + + private static final String TAG = BuildConfig.LOGCAT_TAG; + private static final int TEXT_TRIM_SIZE = 4096; + + private TextView mTextView; + + protected HandlerThread mBackgroundThread; + protected Handler mBackgroundHandler; + private Module mModule; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; + private StringBuilder mTextViewStringBuilder = new StringBuilder(); + + private final Runnable mModuleForwardRunnable = + new Runnable() { + @Override + public void run() { + final Result result = doModuleForward(); + runOnUiThread( + new Runnable() { + @Override + public void run() { + handleResult(result); + if (mBackgroundHandler != null) { + mBackgroundHandler.post(mModuleForwardRunnable); + } + } + }); + } + }; + + public static String assetFilePath(Context context, String assetName) { + File file = new File(context.getFilesDir(), assetName); + if (file.exists() && file.length() > 0) { + return file.getAbsolutePath(); + } + + try (InputStream is = context.getAssets().open(assetName)) { + try (OutputStream os = new FileOutputStream(file)) { + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + return file.getAbsolutePath(); + } catch (IOException e) { + Log.e(TAG, "Error process asset " + assetName + " to file path"); + } + return null; + } + + 
@Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + if (BuildConfig.NATIVE_BUILD) { + final String modelFileAbsoluteFilePath = + new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath(); + LibtorchNativeClient.loadAndForwardModel(modelFileAbsoluteFilePath); + return; + } + setContentView(R.layout.activity_main); + mTextView = findViewById(R.id.text); + startBackgroundThread(); + mBackgroundHandler.post(mModuleForwardRunnable); + } + + protected void startBackgroundThread() { + mBackgroundThread = new HandlerThread(TAG + "_bg"); + mBackgroundThread.start(); + mBackgroundHandler = new Handler(mBackgroundThread.getLooper()); + } + + @Override + protected void onDestroy() { + stopBackgroundThread(); + super.onDestroy(); + } + + protected void stopBackgroundThread() { + mBackgroundThread.quitSafely(); + try { + mBackgroundThread.join(); + mBackgroundThread = null; + mBackgroundHandler = null; + } catch (InterruptedException e) { + Log.e(TAG, "Error stopping background thread", e); + } + } + + @WorkerThread + @Nullable + protected Result doModuleForward() { + if (mModule == null) { + final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE; + long numElements = 1; + for (int i = 0; i < shape.length; i++) { + numElements *= shape[i]; + } + mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements); + mInputTensor = + Tensor.fromBlob( + mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST); + PyTorchAndroid.setNumThreads(1); + mModule = + BuildConfig.USE_VULKAN_DEVICE + ? 
PyTorchAndroid.loadModuleFromAsset( + getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN) + : PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); + } + + final long startTime = SystemClock.elapsedRealtime(); + final long moduleForwardStartTime = SystemClock.elapsedRealtime(); + final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor(); + final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; + final float[] scores = outputTensor.getDataAsFloatArray(); + final long analysisDuration = SystemClock.elapsedRealtime() - startTime; + return new Result(scores, moduleForwardDuration, analysisDuration); + } + + static class Result { + + private final float[] scores; + private final long totalDuration; + private final long moduleForwardDuration; + + public Result(float[] scores, long moduleForwardDuration, long totalDuration) { + this.scores = scores; + this.moduleForwardDuration = moduleForwardDuration; + this.totalDuration = totalDuration; + } + } + + @UiThread + protected void handleResult(Result result) { + String message = String.format("forwardDuration:%d", result.moduleForwardDuration); + mTextViewStringBuilder.insert(0, '\n').insert(0, message); + if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) { + mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length()); + } + mTextView.setText(mTextViewStringBuilder.toString()); + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java new file mode 100644 index 000000000000..7cc1148ede3d --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/Result.java @@ -0,0 +1,14 @@ +package org.pytorch.testapp; + +class Result { + + public final float[] scores; + public final long totalDuration; + public final long moduleForwardDuration; + + public Result(float[] scores, long moduleForwardDuration, long 
totalDuration) { + this.scores = scores; + this.moduleForwardDuration = moduleForwardDuration; + this.totalDuration = totalDuration; + } +} diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/Utils.java b/android/test_app/app/src/main/java/org/pytorch/testapp/Utils.java new file mode 100644 index 000000000000..230d39e80acd --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/Utils.java @@ -0,0 +1,28 @@ +package org.pytorch.testapp; + +import java.util.Arrays; + +public class Utils { + + public static int[] topK(float[] a, final int topk) { + float values[] = new float[topk]; + Arrays.fill(values, -Float.MAX_VALUE); + int ixs[] = new int[topk]; + Arrays.fill(ixs, -1); + + for (int i = 0; i < a.length; i++) { + for (int j = 0; j < topk; j++) { + if (a[i] > values[j]) { + for (int k = topk - 1; k >= j + 1; k--) { + values[k] = values[k - 1]; + ixs[k] = ixs[k - 1]; + } + values[j] = a[i]; + ixs[j] = i; + break; + } + } + } + return ixs; + } +} diff --git a/android/test_app/app/src/main/res/layout/activity_camera.xml b/android/test_app/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 000000000000..9331a2dea598 --- /dev/null +++ b/android/test_app/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,23 @@ + + + + + + + diff --git a/android/test_app/app/src/main/res/layout/activity_main.xml b/android/test_app/app/src/main/res/layout/activity_main.xml new file mode 100644 index 000000000000..556839a994c5 --- /dev/null +++ b/android/test_app/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,17 @@ + + + + + + diff --git a/android/test_app/app/src/main/res/layout/texture_view.xml b/android/test_app/app/src/main/res/layout/texture_view.xml new file mode 100644 index 000000000000..6518c6c84c69 --- /dev/null +++ b/android/test_app/app/src/main/res/layout/texture_view.xml @@ -0,0 +1,5 @@ + + diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png 
b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100644 index 000000000000..64ba76f75e9c Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png new file mode 100644 index 000000000000..dae5e082342f Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ diff --git a/android/test_app/app/src/main/res/values/colors.xml b/android/test_app/app/src/main/res/values/colors.xml new file mode 100644 index 000000000000..69b22338c651 --- /dev/null +++ b/android/test_app/app/src/main/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/android/test_app/app/src/main/res/values/strings.xml b/android/test_app/app/src/main/res/values/strings.xml new file mode 100644 index 000000000000..b8c9ca13d681 --- /dev/null +++ b/android/test_app/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + PyTest + diff --git a/android/test_app/app/src/main/res/values/styles.xml b/android/test_app/app/src/main/res/values/styles.xml new file mode 100644 index 000000000000..5885930df6d1 --- /dev/null +++ b/android/test_app/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py new file mode 100644 index 000000000000..81c1cf16bdb4 --- /dev/null +++ b/android/test_app/make_assets.py @@ -0,0 +1,24 @@ +from torchvision import models + +import torch + + +print(torch.version.__version__) + +resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) +resnet18.eval() +resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save( + "app/src/main/assets/resnet18.pt" +) + +resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) +resnet50.eval() +torch.jit.trace(resnet50, 
torch.rand(1, 3, 224, 224)).save( + "app/src/main/assets/resnet50.pt" +) + +mobilenet2q = models.quantization.mobilenet_v2(pretrained=True, quantize=True) +mobilenet2q.eval() +torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save( + "app/src/main/assets/mobilenet2q.pt" +) diff --git a/android/test_app/make_assets_custom.py b/android/test_app/make_assets_custom.py new file mode 100644 index 000000000000..fa31dfb57ac8 --- /dev/null +++ b/android/test_app/make_assets_custom.py @@ -0,0 +1,27 @@ +""" +This is a script for PyTorch Android custom selective build test. It prepares +MobileNetV2 TorchScript model, and dumps root ops used by the model for custom +build script to create a tailored build which only contains these used ops. +""" + +import yaml +from torchvision import models + +import torch + + +# Download and trace the model. +model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1) +model.eval() +example = torch.rand(1, 3, 224, 224) +# TODO: create script model with `torch.jit.script` +traced_script_module = torch.jit.trace(model, example) + +# Save traced TorchScript model. +traced_script_module.save("MobileNetV2.pt") + +# Dump root ops used by the model (for custom build optimization). +ops = torch.jit.export_opnames(traced_script_module) + +with open("MobileNetV2.yaml", "w") as output: + yaml.dump(ops, output)