Revert "Reapply "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)" (#153656)"

This reverts commit 7ed377f5776578aec4a6a9bc4eeef221a6b80a77.

Reverted https://github.com/pytorch/pytorch/pull/153656 on behalf of https://github.com/larryliu0820 because it is still being used internally and so can't be removed ([comment](https://github.com/pytorch/pytorch/pull/153656#issuecomment-2887665403))
PyTorch MergeBot
2025-05-16 21:00:11 +00:00
parent e4a636df80
commit 084c4aa614
119 changed files with 7700 additions and 3 deletions

View File

@@ -0,0 +1,184 @@
cmake_minimum_required(VERSION 3.5)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
message(
STATUS
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")
if(BUILD_LITE_INTERPRETER)
project(pytorch_jni_lite CXX)
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
else()
project(pytorch_jni CXX)
set(PYTORCH_JNI_TARGET pytorch_jni)
endif()
include(GNUInstallDirs)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
set(TRACE_ENABLED OFF)
if(DEFINED ENV{TRACE_ENABLED})
if($ENV{TRACE_ENABLED} STREQUAL "1")
message(STATUS "TRACE_ENABLED ON")
set(TRACE_ENABLED ON)
endif()
endif()
if(NOT TRACE_ENABLED)
message(STATUS "TRACE_ENABLED OFF")
endif()
set(USE_VULKAN OFF)
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
if(ANDROID_ABI)
set(USE_VULKAN ON)
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})
set(BUILD_SUBDIR ${ANDROID_ABI})
elseif(BUILD_LIBTORCH_WITH_JNI)
# Don't need LIBTORCH_HOME if we're building from within PyTorch.
else()
# Building against a pre-built libtorch.
if(NOT LIBTORCH_HOME)
message(FATAL_ERROR
"pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.")
endif()
set(libtorch_include_DIR ${LIBTORCH_HOME}/include)
link_directories(${LIBTORCH_HOME}/lib)
set(BUILD_SUBDIR host)
endif()
message(STATUS "libtorch dir:${libtorch_DIR}")
configure_file(
${pytorch_android_DIR}/cmake_macros.h.in
${pytorch_android_DIR}/cmake_macros.h)
if(BUILD_LITE_INTERPRETER)
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_lite.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
else()
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
endif()
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})
if(APPLE)
# Need to add rpath so dlopen can find dependencies.
add_custom_command(TARGET ${PYTORCH_JNI_TARGET}
POST_BUILD COMMAND
${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path"
$<TARGET_FILE:${PYTORCH_JNI_TARGET}>)
endif()
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
-fexceptions
)
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})
# ---[ Vulkan deps
if(USE_VULKAN)
set(Vulkan_LIBS)
set(Vulkan_INCLUDES)
include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake)
endif()
if(ANDROID_ABI)
function(import_static_lib name)
add_library(${name} STATIC IMPORTED)
set_property(
TARGET ${name}
PROPERTY IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a)
endfunction(import_static_lib)
import_static_lib(libtorch)
import_static_lib(libtorch_cpu)
import_static_lib(libc10)
import_static_lib(libnnpack)
import_static_lib(libXNNPACK)
import_static_lib(libmicrokernels-prod)
import_static_lib(libpytorch_qnnpack)
import_static_lib(libpthreadpool)
import_static_lib(libeigen_blas)
import_static_lib(libcpuinfo)
import_static_lib(libclog)
# Link most things statically on Android.
set(pytorch_jni_LIBS
fbjni
-Wl,--gc-sections
-Wl,--whole-archive
libtorch
libtorch_cpu
-Wl,--no-whole-archive
libc10
libnnpack
libXNNPACK
libmicrokernels-prod
libpytorch_qnnpack
libpthreadpool
libeigen_blas
libcpuinfo
libclog
)
else()
# Prefer dynamic linking on the host
set(pytorch_jni_LIBS
fbjni
torch
torch_cpu
c10
cpuinfo
)
if(USE_NNPACK)
list(APPEND pytorch_jni_LIBS nnpack)
endif()
if(USE_XNNPACK)
list(APPEND pytorch_jni_LIBS XNNPACK)
list(APPEND pytorch_jni_LIBS microkernels-prod)
endif()
if(USE_SYSTEM_PTHREADPOOL)
list(APPEND pytorch_jni_LIBS pthreadpool)
endif()
if(USE_PYTORCH_QNNPACK)
list(APPEND pytorch_jni_LIBS pytorch_qnnpack)
list(APPEND pytorch_jni_LIBS clog)
endif()
endif()
if(USE_VULKAN)
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
endif()
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})
install(TARGETS ${PYTORCH_JNI_TARGET}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) # For Windows
if(MSVC)
install(FILES $<TARGET_PDB_FILE:${PYTORCH_JNI_TARGET}> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()
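For context, the shared library this CMake file produces (libpytorch_jni_lite.so or libpytorch_jni.so) must be loaded into the process before any of the Java classes in this diff can call native code. Below is a minimal sketch of that load, assuming the com.facebook.soloader:nativeloader dependency declared in the Gradle file that follows; the helper class name is hypothetical (in the actual sources, peer classes such as LiteNativePeer perform this load):

import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

// Hypothetical helper; not part of this diff.
public final class PytorchJniLoader {
  private PytorchJniLoader() {}

  public static void ensureLoaded(boolean liteInterpreter) {
    if (!NativeLoader.isInitialized()) {
      // SystemDelegate falls back to plain System.loadLibrary.
      NativeLoader.init(new SystemDelegate());
    }
    // Pick the JNI target that matches BUILD_LITE_INTERPRETER above.
    NativeLoader.loadLibrary(liteInterpreter ? "pytorch_jni_lite" : "pytorch_jni");
  }
}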

View File

@@ -0,0 +1,163 @@
apply plugin: 'com.android.library'
apply plugin: 'maven'
android {
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 0
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
externalNativeBuild {
cmake {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
} else {
arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
}
}
}
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
sourceSets {
main {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build pytorch_jni'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
} else {
println 'Build pytorch_jni_lite'
}
}
jniLibs.srcDirs = ['src/main/jniLibs']
manifest.srcFile 'src/main/AndroidManifest.xml'
}
androidTest {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build test for full jit (pytorch_jni)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
} else {
println 'Build test for lite interpreter (pytorch_jni_lite)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
}
}
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
packagingOptions {
if (nativeLibsDoNotStrip.toBoolean()) {
doNotStrip "**/*.so"
logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included')
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}
dependencies {
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
}
apply from: rootProject.file('gradle/release.gradle')
task sourcesJar(type: Jar) {
from android.sourceSets.main.java.srcDirs
classifier = 'sources'
}
def getLibtorchHeadersDir() {
def abi = ABI_FILTERS.split(",")[0]
return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi"
}
afterEvaluate {
if (POM_PACKAGING == 'aar') {
android.libraryVariants.all { variant ->
variant.outputs.each { output ->
File f = output.outputFile
if (f.name.endsWith(".aar")) {
output.assemble.finalizedBy addFolderToAarTask(
"addHeadersToAar" + variant.name,
f.path,
getLibtorchHeadersDir(),
"headers")
}
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) {
doLast {
addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers')
}
}
}
def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) {
return tasks.register(taskName) {
doLast {
addFolderToAar(taskName, aarPath, folderPath, folderPathInAar)
}
}
}
def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) {
def tmpDir = file("${buildDir}/${taskName}")
tmpDir.mkdir()
def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}")
tmpDirFolder.mkdir()
copy {
from zipTree(aarPath)
into tmpDir
}
copy {
from fileTree(folderPath)
into tmpDirFolder
}
ant.zip(destfile: aarPath) {
fileset(dir: tmpDir.path)
}
delete tmpDir
}
artifacts.add('archives', sourcesJar)

View File

@@ -0,0 +1,20 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/jit.h>
#include <torch/script.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
int main(int argc, char* argv[]) {
// Guard against missing arguments before touching argv[1]/argv[2].
if (argc < 3) {
std::cerr << "Usage: " << argv[0] << " <input_script_file> <output_module_file>" << std::endl;
return 1;
}
std::string input_file_path{argv[1]};
std::string output_file_path{argv[2]};
std::ifstream ifs(input_file_path);
std::stringstream buffer;
buffer << ifs.rdbuf();
torch::jit::Module m("TestModule");
m.define(buffer.str());
m.save(output_file_path);
}

View File

@@ -0,0 +1,151 @@
from typing import Optional
import torch
from torch import Tensor
OUTPUT_DIR = "src/androidTest/assets/"
def scriptAndSave(module, fileName):
print("-" * 80)
script_module = torch.jit.script(module)
print(script_module.graph)
outputFileName = OUTPUT_DIR + fileName
# note that the lite interpreter model can also be used in full JIT
script_module._save_for_lite_interpreter(outputFileName)
print("Saved to " + outputFileName)
print("=" * 80)
class Test(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return None
@torch.jit.script_method
def eqBool(self, input: bool) -> bool:
return input
@torch.jit.script_method
def eqInt(self, input: int) -> int:
return input
@torch.jit.script_method
def eqFloat(self, input: float) -> float:
return input
@torch.jit.script_method
def eqStr(self, input: str) -> str:
return input
@torch.jit.script_method
def eqTensor(self, input: Tensor) -> Tensor:
return input
@torch.jit.script_method
def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
return input
@torch.jit.script_method
def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
return input
@torch.jit.script_method
def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
return input
@torch.jit.script_method
def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
@torch.jit.script_method
def listBoolConjunction(self, input: list[bool]) -> bool:
res = True
for x in input:
res = res and x
return res
@torch.jit.script_method
def listBoolDisjunction(self, input: list[bool]) -> bool:
res = False
for x in input:
res = res or x
return res
@torch.jit.script_method
def tupleIntSumReturnTuple(
self, input: tuple[int, int, int]
) -> tuple[tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
@torch.jit.script_method
def optionalIntIsNone(self, input: Optional[int]) -> bool:
return input is None
@torch.jit.script_method
def intEq0None(self, input: int) -> Optional[int]:
if input == 0:
return None
return input
@torch.jit.script_method
def str3Concat(self, input: str) -> str:
return input + input + input
@torch.jit.script_method
def newEmptyShapeWithItem(self, input):
return torch.tensor([int(input.item())])[0]
@torch.jit.script_method
def testAliasWithOffset(self) -> list[Tensor]:
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a
@torch.jit.script_method
def testNonContiguous(self):
x = torch.tensor([100, 200, 300])[::2]
assert not x.is_contiguous()
assert x[0] == 100
assert x[1] == 300
return x
@torch.jit.script_method
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.nn.functional.conv2d(x, w)
if toChannelsLast:
r = r.contiguous(memory_format=torch.channels_last)
else:
r = r.contiguous()
return r
@torch.jit.script_method
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.nn.functional.conv3d(x, w)
if toChannelsLast:
r = r.contiguous(memory_format=torch.channels_last_3d)
else:
r = r.contiguous()
return r
@torch.jit.script_method
def contiguous(self, x: Tensor) -> Tensor:
return x.contiguous()
@torch.jit.script_method
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
return x.contiguous(memory_format=torch.channels_last)
@torch.jit.script_method
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
return x.contiguous(memory_format=torch.channels_last_3d)
scriptAndSave(Test(), "test.pt")
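The script above saves the module with _save_for_lite_interpreter and, as its inline comment notes, such a model also loads under the full JIT. A minimal Java sketch of consuming the generated file through the API exercised by the tests below; the file path and class name are assumptions for illustration:

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;

// Sketch only: load the model written by scriptAndSave and round-trip an int.
public class TestModelSmokeCheck {
  public static void main(String[] args) {
    // Path assumption: run from pytorch_android/, where OUTPUT_DIR points.
    Module module = LiteModuleLoader.load("src/androidTest/assets/test.pt");
    IValue out = module.runMethod("eqInt", IValue.from(42L));
    if (out.toLong() != 42L) {
      throw new AssertionError("eqInt round-trip failed");
    }
  }
}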

View File

@@ -0,0 +1,4 @@
POM_NAME=pytorch_android_lite pytorch android api
POM_DESCRIPTION=pytorch_android_lite pytorch android api
POM_ARTIFACT_ID=pytorch_android_lite
POM_PACKAGING=aar

View File

@@ -0,0 +1,42 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the Apache-2 license found in the
// LICENSE file in the root directory of this source tree.
plugins {
id 'java-library'
}
repositories {
mavenLocal()
jcenter()
}
sourceSets {
main {
java {
srcDir '../src/main/java'
exclude 'org/pytorch/PyTorchAndroid.java'
exclude 'org/pytorch/LitePyTorchAndroid.java'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
}
}
test {
java {
srcDir '../src/androidTest/java'
exclude '**/PytorchInstrumented*'
exclude '**/PytorchLiteInstrumented*'
}
resources.srcDirs = ["../src/androidTest/assets"]
}
}
dependencies {
compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
implementation 'com.facebook.soloader:nativeloader:0.10.1'
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
testImplementation 'junit:junit:4.12'
}
apply from: rootProject.file('gradle/release.gradle')

View File

@@ -0,0 +1,4 @@
POM_NAME=pytorch_java_only pytorch java api
POM_DESCRIPTION=pytorch_java_only pytorch java api
POM_ARTIFACT_ID=pytorch_java_only
POM_PACKAGING=jar

Binary file not shown.

View File

@@ -0,0 +1,21 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <gtest/gtest.h>
#include <ATen/core/type_factory.h>
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
using namespace ::testing;
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
auto dict = c10::impl::GenericDict(
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
auto dictCallback = [](auto&&) {
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
};
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
dict, dictCallback, dictCallback));
}

View File

@@ -0,0 +1,25 @@
package org.pytorch;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Objects;
public class PytorchHostTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
Path tempFile = Files.createTempFile("test", ".pt");
try (InputStream resource =
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(assetName))) {
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
}
return tempFile.toAbsolutePath().toString();
}
}

View File

@@ -0,0 +1,42 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}

View File

@@ -0,0 +1,42 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return LiteModuleLoader.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}

View File

@@ -0,0 +1,694 @@
package org.pytorch;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.Ignore;
public abstract class PytorchTestBase {
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
@Test
public void testForwardNull() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
}
@Test
public void testEqBool() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.toBool());
}
}
@Test
public void testEqInt() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.toLong());
}
}
@Test
public void testEqFloat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[] {
-Double.MAX_VALUE,
Double.MAX_VALUE,
-Double.MIN_VALUE,
Double.MIN_VALUE,
-Math.exp(1.d),
-Math.sqrt(2.d),
-3.1415f,
3.1415f,
-1,
0,
1,
};
for (double value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.toDouble());
}
}
@Test
public void testEqTensor() throws IOException {
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
}
}
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[] {0, 1, 128}) {
long[] a = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.toTuple().length);
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@Test
public void testEqString() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.toStr()));
}
}
@Test
public void testStr3Concat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput =
new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.toStr()));
}
}
@Test
public void testEmptyShape() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {}, value.shape());
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
}
@Test
public void testAliasWithOffset() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
}
@Test
public void testNonContiguous() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {2}, value.shape());
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
}
@Test
public void testChannelsLast() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
}
@Test
public void testChannelsLast3d() throws IOException {
long[] shape = new long[] {1, 2, 2, 2, 2};
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
final IValue outputNHWDC =
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
}
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] dataNCHW = new long[] {
111, 112,
121, 122,
211, 212,
221, 222,
311, 312,
321, 322};
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNHWC = new long[] {
111, 211, 311, 112, 212, 312,
121, 221, 321, 122, 222, 322};
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[] {3, 3, 1, 1};
long[] dataWeightOIHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightOHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 2*112,
2*121, 2*122,
211, 212,
221, 222,
-311, -312,
-321, -322});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
assertIValueTensor(
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 211, -311, 2*112, 212, -312,
2*121, 221, -321, 2*122, 222, -322});
}
@Test
public void testChannelsLastConv3d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2, 2};
long[] dataNCDHW = new long[] {
1111, 1112,
1121, 1122,
1211, 1212,
1221, 1222,
2111, 2112,
2121, 2122,
2211, 2212,
2221, 2222,
3111, 3112,
3121, 3122,
3211, 3212,
3221, 3222};
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNDHWC = new long[] {
1111, 2111, 3111,
1112, 2112, 3112,
1121, 2121, 3121,
1122, 2122, 3122,
1211, 2211, 3211,
1212, 2212, 3212,
1221, 2221, 3221,
1222, 2222, 3222};
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
long[] weightShape = new long[] {3, 3, 1, 1, 1};
long[] dataWeightOIDHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightODHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCDHW =
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
assertIValueTensor(
outputNCDHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2*1112, 2*1121, 2*1122,
2*1211, 2*1212, 2*1221, 2*1222,
2111, 2112, 2121, 2122,
2211, 2212, 2221, 2222,
-3111, -3112, -3121, -3122,
-3211, -3212, -3221, -3222});
final IValue outputNDHWC =
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
assertIValueTensor(
outputNDHWC,
MemoryFormat.CHANNELS_LAST_3D,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2111, -3111, 2*1112, 2112, -3112,
2*1121, 2121, -3121, 2*1122, 2122, -3122,
2*1211, 2211, -3211, 2*1212, 2212, -3212,
2*1221, 2221, -3221, 2*1222, 2222, -3222});
}
@Test
public void testMobileNetV2() throws IOException {
try {
final Module module = loadModel("mobilenet_v2.ptl");
final IValue inputs = module.runMethod("get_all_bundled_inputs");
assertTrue(inputs.isList());
final IValue input = inputs.toList()[0];
assertTrue(input.isTuple());
module.forward(input.toTuple()[0]);
assertTrue(true);
} catch (Exception ex) {
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
}
}
@Test
public void testPointwiseOps() throws IOException {
runModel("pointwise_ops");
}
@Test
public void testReductionOps() throws IOException {
runModel("reduction_ops");
}
@Test
public void testComparisonOps() throws IOException {
runModel("comparison_ops");
}
@Test
public void testOtherMathOps() throws IOException {
runModel("other_math_ops");
}
@Test
@Ignore
public void testSpectralOps() throws IOException {
// NB: This model fails without lite interpreter. The error is as follows:
// RuntimeError: stft requires the return_complex parameter be given for real inputs
runModel("spectral_ops");
}
@Test
public void testBlasLapackOps() throws IOException {
runModel("blas_lapack_ops");
}
@Test
public void testSamplingOps() throws IOException {
runModel("sampling_ops");
}
@Test
public void testTensorOps() throws IOException {
runModel("tensor_general_ops");
}
@Test
public void testTensorCreationOps() throws IOException {
runModel("tensor_creation_ops");
}
@Test
public void testTensorIndexingOps() throws IOException {
runModel("tensor_indexing_ops");
}
@Test
public void testTensorTypingOps() throws IOException {
runModel("tensor_typing_ops");
}
@Test
public void testTensorViewOps() throws IOException {
runModel("tensor_view_ops");
}
@Test
public void testConvolutionOps() throws IOException {
runModel("convolution_ops");
}
@Test
public void testPoolingOps() throws IOException {
runModel("pooling_ops");
}
@Test
public void testPaddingOps() throws IOException {
runModel("padding_ops");
}
@Test
public void testActivationOps() throws IOException {
runModel("activation_ops");
}
@Test
public void testNormalizationOps() throws IOException {
runModel("normalization_ops");
}
@Test
public void testRecurrentOps() throws IOException {
runModel("recurrent_ops");
}
@Test
public void testTransformerOps() throws IOException {
runModel("transformer_ops");
}
@Test
public void testLinearOps() throws IOException {
runModel("linear_ops");
}
@Test
public void testDropoutOps() throws IOException {
runModel("dropout_ops");
}
@Test
public void testSparseOps() throws IOException {
runModel("sparse_ops");
}
@Test
public void testDistanceFunctionOps() throws IOException {
runModel("distance_function_ops");
}
@Test
public void testLossFunctionOps() throws IOException {
runModel("loss_function_ops");
}
@Test
public void testVisionFunctionOps() throws IOException {
runModel("vision_function_ops");
}
@Test
public void testShuffleOps() throws IOException {
runModel("shuffle_ops");
}
@Test
public void testNNUtilsOps() throws IOException {
runModel("nn_utils_ops");
}
@Test
public void testQuantOps() throws IOException {
runModel("general_quant_ops");
}
@Test
public void testDynamicQuantOps() throws IOException {
runModel("dynamic_quant_ops");
}
@Test
public void testStaticQuantOps() throws IOException {
runModel("static_quant_ops");
}
@Test
public void testFusedQuantOps() throws IOException {
runModel("fused_quant_ops");
}
@Test
public void testTorchScriptBuiltinQuantOps() throws IOException {
runModel("torchscript_builtin_ops");
}
@Test
public void testTorchScriptCollectionQuantOps() throws IOException {
runModel("torchscript_collection_ops");
}
static void assertIValueTensor(
final IValue ivalue,
final MemoryFormat memoryFormat,
final long[] expectedShape,
final long[] expectedData) {
assertTrue(ivalue.isTensor());
Tensor t = ivalue.toTensor();
assertEquals(memoryFormat, t.memoryFormat());
assertArrayEquals(expectedShape, t.shape());
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
void runModel(final String name) throws IOException {
final Module storage_module = loadModel(name + ".ptl");
storage_module.forward();
// TODO enable this once the on-the-fly script is ready
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
// on_the_fly_module.forward();
assertTrue(true);
}
protected abstract Module loadModel(String assetName) throws IOException;
}
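PytorchTestBase keeps every assertion runtime-agnostic: the only integration point is the abstract loadModel, which the host, instrumented, and lite subclasses earlier in this diff implement. A hypothetical sketch of wiring a further runtime into the same suite, with the directory layout assumed purely for illustration:

package org.pytorch;

import java.io.IOException;

// Hypothetical subclass; not part of this diff. Any runtime that can
// produce an org.pytorch.Module reuses every test in PytorchTestBase.
public class PytorchLocalDirTests extends PytorchTestBase {
  @Override
  protected Module loadModel(String assetName) throws IOException {
    // Directory assumed for illustration.
    return Module.load("/tmp/pytorch_test_models/" + assetName);
  }
}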

View File

@@ -0,0 +1,9 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchInstrumentedTests.class})
public class PytorchInstrumentedTestSuite {}

View File

@@ -0,0 +1,9 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchLiteInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
public class PytorchLiteInstrumentedTestSuite {}

View File

@@ -0,0 +1 @@
<manifest package="org.pytorch" />

View File

@@ -0,0 +1,3 @@
#pragma once
/* #undef TRACE_ENABLED */

View File

@@ -0,0 +1,3 @@
#pragma once
#cmakedefine TRACE_ENABLED

View File

@@ -0,0 +1,697 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <c10/core/MemoryFormat.h>
#include <c10/util/irange.h>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include "pytorch_jni_common.h"
#if defined(__ANDROID__)
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif
namespace pytorch_jni {
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
if (deviceJniCode == kDeviceCPU) {
return at::kCPU;
} else if (deviceJniCode == kDeviceVulkan) {
return at::kVulkan;
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
}
bool Trace::is_initialized_ = false;
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
#endif
void Trace::init() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
// Combine dlopen flags with bitwise OR; logical || would collapse them to 1.
void* lib = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
if (lib != NULL) {
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
dlsym(lib, "ATrace_beginSection"));
Trace::ATrace_endSection =
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
}
#endif
}
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
constexpr static int kTensorDTypeFloat32 = 4;
constexpr static int kTensorDTypeInt64 = 5;
constexpr static int kTensorDTypeFloat64 = 6;
constexpr static int kTensorMemoryFormatContiguous = 1;
constexpr static int kTensorMemoryFormatChannelsLast = 2;
constexpr static int kTensorMemoryFormatChannelsLast3d = 3;
template <typename K = jobject, typename V = jobject>
struct JHashMap
: facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";
using Super =
facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;
static facebook::jni::local_ref<JHashMap<K, V>> create() {
return Super::newInstance();
}
void put(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
static auto putMethod =
Super::javaClassStatic()
->template getMethod<facebook::jni::alias_ref<
facebook::jni::JObject::javaobject>(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
"put");
putMethod(Super::self(), key, value);
}
};
static at::Tensor newAtTensor(
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
facebook::jni::alias_ref<jlongArray> jshape,
jint jdtype,
jint jmemoryFormat) {
const auto rank = jshape->size();
const auto shapeArr = jshape->getRegion(0, rank);
std::vector<int64_t> shapeVec{};
shapeVec.reserve(rank);
auto numel = 1;
for (const auto i : c10::irange(rank)) {
shapeVec.push_back(shapeArr[i]);
numel *= shapeArr[i];
}
JNIEnv* jni = facebook::jni::Environment::current();
caffe2::TypeMeta typeMeta{};
int dataElementSizeBytes = 0;
if (kTensorDTypeFloat32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorDTypeInt32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeUInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
} else if (kTensorDTypeFloat64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<double>();
} else if (kTensorDTypeInt64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown Tensor jdtype %d",
jdtype);
}
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
if (dataCapacity != numel) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Tensor dimensions(elements number:%d, element byte size:%d, total "
"bytes:%d) inconsistent with buffer capacity(%d)",
numel,
dataElementSizeBytes,
numel * dataElementSizeBytes,
dataCapacity);
}
if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast));
} else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast3d));
}
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
torch::IntArrayRef(shapeVec),
at::TensorOptions(typeMeta));
}
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";
explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}
static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
static const auto jMethodMemoryFormatCode =
cls->getMethod<jint()>("memoryFormatJniCode");
static const auto jFieldShape = cls->getField<jlongArray>("shape");
static const auto jMethodGetDataBuffer = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
at::Tensor tensor = newAtTensor(
jMethodGetDataBuffer(jTensorThis),
jTensorThis->getFieldValue(jFieldShape),
jMethodDTypeCode(jTensorThis),
jMethodMemoryFormatCode(jTensorThis));
return makeCxxInstance(std::move(tensor));
}
static facebook::jni::local_ref<TensorHybrid::javaobject>
newJTensorFromAtTensor(const at::Tensor& input_tensor) {
// Java wrapper currently only supports contiguous tensors.
int jmemoryFormat = 0;
at::Tensor tensor{};
if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast;
} else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
} else {
tensor = input_tensor.contiguous();
jmemoryFormat = kTensorMemoryFormatContiguous;
}
const auto scalarType = tensor.scalar_type();
int jdtype = 0;
if (at::kFloat == scalarType) {
jdtype = kTensorDTypeFloat32;
} else if (at::kInt == scalarType) {
jdtype = kTensorDTypeInt32;
} else if (at::kByte == scalarType) {
jdtype = kTensorDTypeUInt8;
} else if (at::kChar == scalarType) {
jdtype = kTensorDTypeInt8;
} else if (at::kLong == scalarType) {
jdtype = kTensorDTypeInt64;
} else if (at::kDouble == scalarType) {
jdtype = kTensorDTypeFloat64;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"at::Tensor scalar type %s is not supported on java side",
c10::toString(scalarType));
}
const auto& tensorShape = tensor.sizes();
std::vector<jlong> tensorShapeVec;
for (const auto& s : tensorShape) {
tensorShapeVec.push_back(s);
}
facebook::jni::local_ref<jlongArray> jTensorShape =
facebook::jni::make_long_array(tensorShapeVec.size());
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
static auto cls = TensorHybrid::javaClassStatic();
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
facebook::jni::JByteBuffer::wrapBytes(
(uint8_t*)tensor.data_ptr(), tensor.nbytes());
jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
static const auto jMethodNewTensor =
cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
facebook::jni::alias_ref<jlongArray>,
jint,
jint,
facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
return jMethodNewTensor(
cls,
jTensorBuffer,
jTensorShape,
jdtype,
jmemoryFormat,
makeCxxInstance(tensor));
}
static at::Tensor newAtTensorFromJTensor(
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
jint jdtype = dtypeMethod(jtensor);
static const auto memoryFormatMethod =
cls->getMethod<jint()>("memoryFormatJniCode");
jint jmemoryFormat = memoryFormatMethod(jtensor);
static const auto shapeField = cls->getField<jlongArray>("shape");
auto jshape = jtensor->getFieldValue(shapeField);
static auto dataBufferMethod = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
dataBufferMethod(jtensor);
return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
}
at::Tensor tensor() const {
return tensor_;
}
private:
friend HybridBase;
at::Tensor tensor_;
};
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictStringKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::make_jstring(pair.key().toStringRef()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictLongKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::JLong::valueOf(pair.key().toInt()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback,
DictCallback intDictCallback) {
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
if (ivalue.isNone()) {
static auto jMethodOptionalNull =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
"optionalNull");
return jMethodOptionalNull(JIValue::javaClassStatic());
} else if (ivalue.isTensor()) {
static auto jMethodTensor =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
const auto& tensor = ivalue.toTensor();
return jMethodTensor(
JIValue::javaClassStatic(),
TensorHybrid::newJTensorFromAtTensor(tensor));
} else if (ivalue.isBool()) {
static auto jMethodBool =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
"from");
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
} else if (ivalue.isInt()) {
static auto jMethodInt =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
} else if (ivalue.isDouble()) {
static auto jMethodDouble =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
"from");
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
} else if (ivalue.isString()) {
static auto jMethodString =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
"from");
return jMethodString(
JIValue::javaClassStatic(),
facebook::jni::make_jstring(ivalue.toStringRef()));
} else if (ivalue.isTuple()) {
auto elementsVec = ivalue.toTupleRef().elements();
static auto jMethodTupleArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("tupleFrom");
auto jElementsArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
elementsVec.size());
auto index = 0;
for (const auto& e : elementsVec) {
(*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
} else if (ivalue.isBoolList()) {
auto list = ivalue.toBoolList();
static auto jMethodBoolListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_boolean_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isIntList()) {
auto list = ivalue.toIntList();
static auto jMethodLongListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jlongArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_long_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isDoubleList()) {
auto list = ivalue.toDoubleList();
static auto jMethoDoubleListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_double_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isTensorList()) {
auto list = ivalue.toTensorList();
static auto jMethodTensorListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
}
return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isList()) {
auto list = ivalue.toList();
static auto jMethodListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isGenericDict()) {
auto dict = ivalue.toGenericDict();
const auto keyType = dict.keyType();
if (!keyType) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue-Dict key type");
}
if (*keyType == *c10::StringType::get()) {
return stringDictCallback(std::move(dict));
} else if (*keyType == *c10::IntType::get()) {
return intDictCallback(std::move(dict));
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue-Dict key type: %s",
keyType->str().c_str());
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue type %s",
ivalue.tagKind().c_str());
}
at::IValue JIValue::JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue) {
Trace _s{"jni::JIValue::JIValueToAtIValue"};
static const auto typeCodeField =
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
const auto typeCode = jivalue->getFieldValue(typeCodeField);
if (JIValue::kTypeCodeNull == typeCode) {
return at::IValue{};
} else if (JIValue::kTypeCodeTensor == typeCode) {
static const auto jMethodGetTensor =
JIValue::javaClassStatic()
->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
"toTensor");
return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
} else if (JIValue::kTypeCodeBool == typeCode) {
static const auto jMethodGetBool =
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
// Explicit cast to bool: jboolean is defined as uint8_t, so without the
// cast the IValue constructor for int would be selected.
bool b = jMethodGetBool(jivalue);
return at::IValue{b};
} else if (JIValue::kTypeCodeLong == typeCode) {
static const auto jMethodGetLong =
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
return at::IValue{(int64_t)jMethodGetLong(jivalue)};
} else if (JIValue::kTypeCodeDouble == typeCode) {
static const auto jMethodGetDouble =
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
return at::IValue{jMethodGetDouble(jivalue)};
} else if (JIValue::kTypeCodeString == typeCode) {
static const auto jMethodGetString =
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
return at::IValue{jMethodGetString(jivalue)->toStdString()};
} else if (JIValue::kTypeCodeTuple == typeCode) {
static const auto jMethodGetTuple =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toTuple");
auto jarray = jMethodGetTuple(jivalue);
size_t n = jarray->size();
std::vector<at::IValue> elements;
elements.reserve(n);
for (const auto i : c10::irange(n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
elements.push_back(std::move(element));
}
return c10::ivalue::Tuple::create(std::move(elements));
} else if (JIValue::kTypeCodeBoolList == typeCode) {
static const auto jMethodGetBoolList =
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
auto jArray = jMethodGetBoolList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<bool> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeLongList == typeCode) {
static const auto jMethodGetLongList =
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
auto jArray = jMethodGetLongList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<int64_t> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
static const auto jMethodGetDoubleList =
JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
auto jArray = jMethodGetDoubleList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<double> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeTensorList == typeCode) {
static const auto jMethodGetTensorList =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject()>("toTensorList");
auto jArray = jMethodGetTensorList(jivalue);
size_t n = jArray->size();
c10::List<at::Tensor> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(
TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeList == typeCode) {
static const auto jMethodGetList =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toList");
auto jarray = jMethodGetList(jivalue);
size_t n = jarray->size();
if (n == 0) {
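      // The element type of an empty Java array cannot be inferred; fall back
      // to List[Tensor] as the placeholder element type.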
return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
}
auto jivalue_first_element = jarray->getElement(0);
auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
c10::impl::GenericList list{c10::unshapedType(first_element.type())};
list.reserve(n);
list.push_back(first_element);
for (const auto i : c10::irange(1, n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
list.push_back(element);
}
return at::IValue{list};
} else if (JIValue::kTypeCodeDictStringKey == typeCode) {
static const auto jMethodGetDictStringKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
javaobject()>("toDictStringKey");
auto jmap = jMethodGetDictStringKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
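      // An empty Java map gives no way to infer the value type; fall back to
      // Dict[str, Tensor].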
return at::IValue{c10::impl::GenericDict(
c10::StringType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert(it->first->toStdString(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
} else if (JIValue::kTypeCodeDictLongKey == typeCode) {
static const auto jMethodGetDictLongKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<
facebook::jni::JLong::javaobject,
JIValue::javaobject>::javaobject()>("toDictLongKey");
auto jmap = jMethodGetDictLongKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
return at::IValue{
c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert((int64_t)it->first->longValue(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
(int64_t)it->first->longValue(),
JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue typeCode %d",
typeCode);
}
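// Round-trip sketch (illustrative; 'jivalue' stands for any
// facebook::jni::alias_ref<JIValue> received from Java):
//
//   at::IValue native = JIValue::JIValueToAtIValue(jivalue);
//   auto back = JIValue::newJIValueFromAtIValue(native);
//
// Both directions are driven by the Java-side mTypeCode tag, so the
// kTypeCode* constants here must stay in sync with TYPE_CODE_* in IValue.java.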
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni

View File

@@ -0,0 +1,137 @@
#pragma once
#include <cstring>
#include <c10/util/FunctionRef.h>
#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "cmake_macros.h"
#ifdef __ANDROID__
#include <android/log.h>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
#endif
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
#include <android/trace.h>
#include <dlfcn.h>
#endif
namespace pytorch_jni {
constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
typedef void* (*fp_ATrace_endSection)(void);
static fp_ATrace_beginSection ATrace_beginSection;
static fp_ATrace_endSection ATrace_endSection;
#endif
static void ensureInit() {
if (!Trace::is_initialized_) {
init();
Trace::is_initialized_ = true;
}
}
static void beginSection(const char* name) {
Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_beginSection(name);
#endif
}
static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_endSection();
#endif
}
Trace(const char* name) {
ensureInit();
beginSection(name);
}
~Trace() {
endSection();
}
private:
static void init();
static bool is_initialized_;
};
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
public:
explicit MemoryReadAdapter(const void* data, off_t size)
      : data_(data), size_(size) {}
size_t size() const override {
return size_;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override {
memcpy(buf, (int8_t*)(data_) + pos, n);
return n;
}
~MemoryReadAdapter() {}
private:
const void* data_;
off_t size_;
};
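// Usage sketch (illustrative; 'buf' and 'len' are hypothetical): wrap an
// in-memory buffer so a module can be loaded without touching the filesystem:
//
//   auto adapter = std::make_unique<MemoryReadAdapter>(buf, len);
//   auto module = torch::jit::_load_for_mobile(std::move(adapter));
//
// This is how the Android-asset constructors in pytorch_jni_jit.cpp and
// pytorch_jni_lite.cpp load modules from APK assets.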
class JIValue : public facebook::jni::JavaClass<JIValue> {
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
c10::Dict<c10::IValue, c10::IValue>)>;
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
constexpr static int kTypeCodeNull = 1;
constexpr static int kTypeCodeTensor = 2;
constexpr static int kTypeCodeBool = 3;
constexpr static int kTypeCodeLong = 4;
constexpr static int kTypeCodeDouble = 5;
constexpr static int kTypeCodeString = 6;
constexpr static int kTypeCodeTuple = 7;
constexpr static int kTypeCodeBoolList = 8;
constexpr static int kTypeCodeLongList = 9;
constexpr static int kTypeCodeDoubleList = 10;
constexpr static int kTypeCodeTensorList = 11;
constexpr static int kTypeCodeList = 12;
constexpr static int kTypeCodeDictStringKey = 13;
constexpr static int kTypeCodeDictLongKey = 14;
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback = newJIValueFromStringDict,
DictCallback intDictCallback = newJIValueFromIntDict);
static at::IValue JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue);
private:
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue>);
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue>);
};
void common_registerNatives();
} // namespace pytorch_jni

View File

@@ -0,0 +1,245 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include <ATen/record_function.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "pytorch_jni_common.h"
#ifdef __ANDROID__
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#endif
namespace pytorch_jni {
namespace {
struct JITCallGuard {
  // Inference-only workload.
  c10::InferenceMode guard;
  // Disable the graph optimizer to ensure the list of unused ops is not
  // changed for custom mobile builds.
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::Module module_;
c10::DeviceType deviceType_;
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
return makeCxxInstance(modelPath, extraFiles, device);
}
#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif
#ifdef TRACE_ENABLED
static std::unique_ptr<at::ObserverContext> onFunctionEnter(
const at::RecordFunction& fn) {
Trace::beginSection(fn.name().str());
return nullptr;
}
static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
Trace::endSection();
}
#endif
static void preModuleLoadSetupOnce() {
auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
qengines.end()) {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
}
#ifdef __ANDROID__
torch::jit::setPrintHandler([](const std::string& s) {
__android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
});
#endif
#ifdef TRACE_ENABLED
at::addGlobalCallback(
at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .scopes({at::RecordScope::FUNCTION, at::RecordScope::USER_SCOPE}));
#endif
}
void preModuleLoadSetup() {
static const int once = []() {
preModuleLoadSetupOnce();
return 0;
}();
((void)once);
}
PytorchJni(
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
preModuleLoadSetup();
JITCallGuard guard;
std::unordered_map<std::string, std::string> extra_files;
const auto has_extra = extraFiles && extraFiles->size() > 0;
if (has_extra) {
for (const auto& e : *extraFiles) {
extra_files[e.first->toStdString()] = "";
}
}
deviceType_ = deviceJniCodeToDeviceType(device);
    module_ = torch::jit::load(
        modelPath->toStdString(), std::nullopt, extra_files);
if (has_extra) {
static auto putMethod =
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
javaClassStatic()
->template getMethod<facebook::jni::alias_ref<jobject>(
facebook::jni::alias_ref<jobject>,
facebook::jni::alias_ref<jobject>)>("put");
for (const auto& ef : extra_files) {
putMethod(
extraFiles,
facebook::jni::make_jstring(ef.first),
facebook::jni::make_jstring(ef.second));
}
}
module_.eval();
}
#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
preModuleLoadSetup();
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
if (!mgr) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unable to get asset manager");
}
AAsset* asset = AAssetManager_open(
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
if (!asset) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Failed to open asset '%s'",
assetName->toStdString().c_str());
}
auto assetBuffer = AAsset_getBuffer(asset);
if (!assetBuffer) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Could not get buffer for asset '%s'",
assetName->toStdString().c_str());
}
JITCallGuard guard;
module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
makeNativeMethod(
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
makeNativeMethod("forward", PytorchJni::forward),
makeNativeMethod("runMethod", PytorchJni::runMethod),
});
}
facebook::jni::local_ref<JIValue> forward(
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
Trace _s{"jni::Module::forward"};
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
JITCallGuard guard;
return module_.forward(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::local_ref<JIValue> runMethod(
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::string methodName = jmethodName->toStdString();
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
JITCallGuard guard;
return (*method)(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Undefined method %s",
methodName.c_str());
}
};
} // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
pytorch_jni::common_registerNatives();
pytorch_jni::PytorchJni::registerNatives();
});
}
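// Note: JNI_OnLoad is invoked by the JVM when the Java side loads this shared
// library (NativeLoader.loadLibrary("pytorch_jni") in NativePeer.java); the
// lambda above then registers the native methods on the Java peer classes.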

View File

@@ -0,0 +1,209 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "pytorch_jni_common.h"
#ifdef __ANDROID__
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#endif
namespace pytorch_jni {
namespace {
struct LiteJITCallGuard {
  // VariableType dispatch is not included in the default mobile build. We
  // need to set this guard globally to avoid dispatch errors (only for
  // dynamic dispatch). Thanks to the unification of the Variable and Tensor
  // classes, it's no longer required to toggle the NonVariableTypeMode per
  // op - so it doesn't hurt to always set NonVariableTypeMode for
  // inference-only use cases.
  // TODO: Ideally AutoNonVariableTypeMode in this file should be changed to
  // InferenceMode but it's blocked due to typeahead application on Oculus
  // (D27943428). To unblock, we need to find out which op is making inplace
  // update to an inference tensor outside InferenceMode and properly guard it.
torch::AutoNonVariableTypeMode non_var_guard;
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::mobile::Module module_;
c10::DeviceType deviceType_;
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LiteNativePeer;";
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
return makeCxxInstance(modelPath, extraFiles, device);
}
#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif
PytorchJni(
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
LiteJITCallGuard guard;
std::unordered_map<std::string, std::string> extra_files;
const auto has_extra = extraFiles && extraFiles->size() > 0;
if (has_extra) {
for (const auto& e : *extraFiles) {
extra_files[e.first->toStdString()] = "";
}
}
deviceType_ = deviceJniCodeToDeviceType(device);
    module_ = torch::jit::_load_for_mobile(
        modelPath->toStdString(), std::nullopt, extra_files);
    torch::jit::_load_extra_only_for_mobile(
        modelPath->toStdString(), std::nullopt, extra_files);
if (has_extra) {
static auto putMethod =
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
javaClassStatic()
->template getMethod<facebook::jni::alias_ref<jobject>(
facebook::jni::alias_ref<jobject>,
facebook::jni::alias_ref<jobject>)>("put");
for (const auto& ef : extra_files) {
putMethod(
extraFiles,
facebook::jni::make_jstring(ef.first),
facebook::jni::make_jstring(ef.second));
}
}
}
#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
if (!mgr) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unable to get asset manager");
}
AAsset* asset = AAssetManager_open(
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
if (!asset) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Failed to open asset '%s'",
assetName->toStdString().c_str());
}
auto assetBuffer = AAsset_getBuffer(asset);
if (!assetBuffer) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Could not get buffer for asset '%s'",
assetName->toStdString().c_str());
}
LiteJITCallGuard guard;
module_ =
torch::jit::_load_for_mobile(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
makeNativeMethod(
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
makeNativeMethod("forward", PytorchJni::forward),
makeNativeMethod("runMethod", PytorchJni::runMethod),
});
}
facebook::jni::local_ref<JIValue> forward(
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
LiteJITCallGuard guard;
return module_.forward(inputs);
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::local_ref<JIValue> runMethod(
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::string methodName = jmethodName->toStdString();
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
if (auto method = module_.find_method(methodName)) {
      auto output = [&]() {
        LiteJITCallGuard guard;
        // Use the Method found above rather than looking it up a second time.
        return (*method)(std::move(inputs));
      }();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Undefined method %s",
methodName.c_str());
}
};
} // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
pytorch_jni::common_registerNatives();
pytorch_jni::PytorchJni::registerNatives();
});
}

View File

@@ -0,0 +1,27 @@
package org.pytorch;
/** Codes representing tensor data types. */
public enum DType {
// NOTE: "jniCode" must be kept in sync with pytorch_jni_common.cpp.
// NOTE: Never serialize "jniCode", because it can change between releases.
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
UINT8(1),
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
INT8(2),
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
INT32(3),
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
FLOAT32(4),
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
INT64(5),
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
FLOAT64(6),
;
final int jniCode;
DType(int jniCode) {
this.jniCode = jniCode;
}
}
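// Illustrative example of the mapping (see the Tensor.fromBlob overloads):
//   Tensor.fromBlob(new int[] {1, 2}, new long[] {2}).dtype() == DType.INT32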

View File

@@ -0,0 +1,15 @@
package org.pytorch;
public enum Device {
  // Must be in sync with kDeviceCPU, kDeviceVulkan in
  // pytorch_android/src/main/cpp/pytorch_jni_common.h
CPU(1),
VULKAN(2),
;
final int jniCode;
Device(int jniCode) {
this.jniCode = jniCode;
}
}

View File

@@ -0,0 +1,9 @@
package org.pytorch;
interface INativePeer {
void resetNative();
IValue forward(IValue... inputs);
IValue runMethod(String methodName, IValue... inputs);
}

View File

@@ -0,0 +1,343 @@
package org.pytorch;
import com.facebook.jni.annotations.DoNotStrip;
import java.util.Locale;
import java.util.Map;
/**
 * Java representation of a TorchScript value, which is implemented as a tagged union that can be
 * one of the supported types: https://pytorch.org/docs/stable/jit.html#types .
*
* <p>Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
*
* <p>{@code IValue} objects are constructed with {@code IValue.from(value)}, {@code
* IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, or one of
* the {@code dict} methods, depending on the key type.
*
* <p>Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that {@code
* str}-type IValues must be extracted with {@link #toStr()}, rather than {@link #toString()}.
*
* <p>{@code IValue} objects may retain references to objects passed into their constructors, and
* may return references to their internal state from {@code toX()}.
*/
@DoNotStrip
public class IValue {
private static final int TYPE_CODE_NULL = 1;
private static final int TYPE_CODE_TENSOR = 2;
private static final int TYPE_CODE_BOOL = 3;
private static final int TYPE_CODE_LONG = 4;
private static final int TYPE_CODE_DOUBLE = 5;
private static final int TYPE_CODE_STRING = 6;
private static final int TYPE_CODE_TUPLE = 7;
private static final int TYPE_CODE_BOOL_LIST = 8;
private static final int TYPE_CODE_LONG_LIST = 9;
private static final int TYPE_CODE_DOUBLE_LIST = 10;
private static final int TYPE_CODE_TENSOR_LIST = 11;
private static final int TYPE_CODE_LIST = 12;
private static final int TYPE_CODE_DICT_STRING_KEY = 13;
private static final int TYPE_CODE_DICT_LONG_KEY = 14;
  private static final String[] TYPE_NAMES = {
"Unknown",
"Null",
"Tensor",
"Bool",
"Long",
"Double",
"String",
"Tuple",
"BoolList",
"LongList",
"DoubleList",
"TensorList",
"GenericList",
"DictStringKey",
"DictLongKey",
};
@DoNotStrip private final int mTypeCode;
@DoNotStrip private Object mData;
@DoNotStrip
private IValue(int typeCode) {
this.mTypeCode = typeCode;
}
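  // Construction/extraction sketch (illustrative):
  //
  //   IValue a = IValue.from(42L);         // int
  //   IValue b = IValue.listFrom(1L, 2L);  // List[int]
  //   IValue t = IValue.tupleFrom(a, b);   // Tuple[int, List[int]]
  //   long x = t.toTuple()[0].toLong();    // 42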
@DoNotStrip
public boolean isNull() {
return TYPE_CODE_NULL == this.mTypeCode;
}
@DoNotStrip
public boolean isTensor() {
return TYPE_CODE_TENSOR == this.mTypeCode;
}
@DoNotStrip
public boolean isBool() {
return TYPE_CODE_BOOL == this.mTypeCode;
}
@DoNotStrip
public boolean isLong() {
return TYPE_CODE_LONG == this.mTypeCode;
}
@DoNotStrip
public boolean isDouble() {
return TYPE_CODE_DOUBLE == this.mTypeCode;
}
@DoNotStrip
public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}
@DoNotStrip
public boolean isTuple() {
return TYPE_CODE_TUPLE == this.mTypeCode;
}
@DoNotStrip
public boolean isBoolList() {
return TYPE_CODE_BOOL_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isLongList() {
return TYPE_CODE_LONG_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isDoubleList() {
return TYPE_CODE_DOUBLE_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isTensorList() {
return TYPE_CODE_TENSOR_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isList() {
return TYPE_CODE_LIST == this.mTypeCode;
}
@DoNotStrip
public boolean isDictStringKey() {
return TYPE_CODE_DICT_STRING_KEY == this.mTypeCode;
}
@DoNotStrip
public boolean isDictLongKey() {
return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode;
}
/** Creates a new {@code IValue} of type {@code Optional} that contains no value. */
@DoNotStrip
public static IValue optionalNull() {
return new IValue(TYPE_CODE_NULL);
}
/** Creates a new {@code IValue} of type {@code Tensor}. */
@DoNotStrip
public static IValue from(Tensor tensor) {
final IValue iv = new IValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
return iv;
}
/** Creates a new {@code IValue} of type {@code bool}. */
@DoNotStrip
public static IValue from(boolean value) {
final IValue iv = new IValue(TYPE_CODE_BOOL);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code int}. */
@DoNotStrip
public static IValue from(long value) {
final IValue iv = new IValue(TYPE_CODE_LONG);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code float}. */
@DoNotStrip
public static IValue from(double value) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code str}. */
@DoNotStrip
public static IValue from(String value) {
final IValue iv = new IValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[bool]}. */
@DoNotStrip
public static IValue listFrom(boolean... list) {
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[int]}. */
@DoNotStrip
public static IValue listFrom(long... list) {
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[float]}. */
@DoNotStrip
public static IValue listFrom(double... list) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[Tensor]}. */
@DoNotStrip
public static IValue listFrom(Tensor... list) {
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
iv.mData = list;
return iv;
}
/** Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type. */
@DoNotStrip
public static IValue listFrom(IValue... array) {
final int size = array.length;
if (size > 0) {
final int typeCode0 = array[0].mTypeCode;
for (int i = 1; i < size; i++) {
if (typeCode0 != array[i].mTypeCode) {
throw new IllegalArgumentException("List must contain items of the same type");
}
}
}
final IValue iv = new IValue(TYPE_CODE_LIST);
iv.mData = array;
return iv;
}
/** Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. */
@DoNotStrip
public static IValue tupleFrom(IValue... array) {
final IValue iv = new IValue(TYPE_CODE_TUPLE);
iv.mData = array;
return iv;
}
/** Creates a new {@code IValue} of type {@code Dict[str, V]}. */
@DoNotStrip
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
iv.mData = map;
return iv;
}
/** Creates a new {@code IValue} of type {@code Dict[int, V]}. */
@DoNotStrip
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
iv.mData = map;
return iv;
}
@DoNotStrip
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}
@DoNotStrip
public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}
@DoNotStrip
public long toLong() {
preconditionType(TYPE_CODE_LONG, mTypeCode);
return (long) mData;
}
@DoNotStrip
public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}
@DoNotStrip
public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
}
@DoNotStrip
public boolean[] toBoolList() {
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
return (boolean[]) mData;
}
@DoNotStrip
public long[] toLongList() {
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
return (long[]) mData;
}
@DoNotStrip
public double[] toDoubleList() {
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
return (double[]) mData;
}
@DoNotStrip
public Tensor[] toTensorList() {
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
return (Tensor[]) mData;
}
@DoNotStrip
public IValue[] toList() {
preconditionType(TYPE_CODE_LIST, mTypeCode);
return (IValue[]) mData;
}
@DoNotStrip
public IValue[] toTuple() {
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
return (IValue[]) mData;
}
@DoNotStrip
public Map<String, IValue> toDictStringKey() {
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
return (Map<String, IValue>) mData;
}
@DoNotStrip
public Map<Long, IValue> toDictLongKey() {
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
return (Map<Long, IValue>) mData;
}
private void preconditionType(int typeCodeExpected, int typeCode) {
if (typeCode != typeCodeExpected) {
throw new IllegalStateException(
String.format(
Locale.US,
"Expected IValue type %s, actual type %s",
getTypeName(typeCodeExpected),
getTypeName(typeCode)));
}
}
private String getTypeName(int typeCode) {
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
}
}

View File

@@ -0,0 +1,49 @@
package org.pytorch;
import android.content.res.AssetManager;
import java.util.Map;
public class LiteModuleLoader {
/**
   * Loads a serialized TorchScript module from the specified path on the disk to run on the
   * specified device. The model should be generated with the {@code _save_for_lite_interpreter()}
   * API.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @param extraFiles map with extra file names as keys; their content will be loaded into the
   *     values.
   * @param device {@link org.pytorch.Device} to use for running the specified module.
   * @return a new {@link org.pytorch.Module} object that owns the torch::jit::mobile::Module.
*/
public static Module load(
final String modelPath, final Map<String, String> extraFiles, final Device device) {
return new Module(new LiteNativePeer(modelPath, extraFiles, device));
}
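  // Usage sketch (illustrative; the .ptl file is a hypothetical model saved in
  // Python via torch.jit.script(model)._save_for_lite_interpreter("model.ptl")):
  //
  //   Module module = LiteModuleLoader.load("/data/local/tmp/model.ptl");
  //   IValue output = module.forward(IValue.from(inputTensor));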
/**
   * Loads a serialized TorchScript module from the specified path on the disk to run on CPU. The
   * model should be generated with the {@code _save_for_lite_interpreter()} API.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @return a new {@link org.pytorch.Module} object that owns the torch::jit::mobile::Module.
*/
public static Module load(final String modelPath) {
return new Module(new LiteNativePeer(modelPath, null, Device.CPU));
}
/**
   * Attention: This is not the recommended way of loading production modules, as prepackaged
   * assets increase apk size etc. For production usage consider loading from a file on the disk
   * via {@link org.pytorch.Module#load(String)}.
   *
   * <p>This method is meant to be used in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new LiteNativePeer(assetName, assetManager, device));
}
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
}
}

View File

@@ -0,0 +1,64 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
class LiteNativePeer implements INativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
private final HybridData mHybridData;
private static native HybridData initHybrid(
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
int deviceJniCode);
LiteNativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
}
LiteNativePeer(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}
/**
* Explicitly destroys the native torch::jit::mobile::Module. Calling this method is not required,
* as the native object will be destroyed when this object is garbage-collected. However, the
* timing of garbage collection is not guaranteed, so proactively calling {@code resetNative} can
* free memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
*/
public void resetNative() {
mHybridData.resetNative();
}
/**
* Runs the 'forward' method of this module with the specified arguments.
*
* @param inputs arguments for the TorchScript module's 'forward' method.
* @return return value from the 'forward' method.
*/
public native IValue forward(IValue... inputs);
/**
* Runs the specified method of this module with the specified arguments.
*
* @param methodName name of the TorchScript method to run.
* @param inputs arguments that will be passed to TorchScript method.
* @return return value from the method.
*/
public native IValue runMethod(String methodName, IValue... inputs);
}

View File

@@ -0,0 +1,14 @@
package org.pytorch;
public enum MemoryFormat {
CONTIGUOUS(1),
CHANNELS_LAST(2),
CHANNELS_LAST_3D(3),
;
final int jniCode;
MemoryFormat(int jniCode) {
this.jniCode = jniCode;
}
}
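// Illustrative example ('pixels' is a hypothetical float[] with
// 1 * 3 * 224 * 224 elements):
//   Tensor.fromBlob(pixels, new long[] {1, 3, 224, 224}, MemoryFormat.CHANNELS_LAST)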

View File

@@ -0,0 +1,75 @@
// Copyright 2004-present Facebook. All Rights Reserved.
package org.pytorch;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
/** Java wrapper for torch::jit::Module. */
public class Module {
private INativePeer mNativePeer;
/**
   * Loads a serialized TorchScript module from the specified path on the disk to run on the
   * specified device.
   *
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @param extraFiles map with extra file names as keys; their content will be loaded into the
   *     values.
   * @param device {@link org.pytorch.Device} to use for running the specified module.
   * @return a new {@link org.pytorch.Module} object that owns the torch::jit::Module.
*/
public static Module load(
final String modelPath, final Map<String, String> extraFiles, final Device device) {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
return new Module(new NativePeer(modelPath, extraFiles, device));
}
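  // Usage sketch (illustrative; the model path and input tensor are
  // hypothetical):
  //
  //   Module module = Module.load("/data/local/tmp/model.pt");
  //   Tensor output = module.forward(IValue.from(inputTensor)).toTensor();
  //   float[] scores = output.getDataAsFloatArray();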
/**
* Loads a serialized TorchScript module from the specified path on the disk to run on CPU.
*
   * @param modelPath path to the file that contains the serialized TorchScript module.
   * @return a new {@link org.pytorch.Module} object that owns the torch::jit::Module.
*/
public static Module load(final String modelPath) {
return load(modelPath, null, Device.CPU);
}
Module(INativePeer nativePeer) {
this.mNativePeer = nativePeer;
}
/**
* Runs the 'forward' method of this module with the specified arguments.
*
* @param inputs arguments for the TorchScript module's 'forward' method.
* @return return value from the 'forward' method.
*/
public IValue forward(IValue... inputs) {
return mNativePeer.forward(inputs);
}
/**
* Runs the specified method of this module with the specified arguments.
*
* @param methodName name of the TorchScript method to run.
* @param inputs arguments that will be passed to TorchScript method.
* @return return value from the method.
*/
public IValue runMethod(String methodName, IValue... inputs) {
return mNativePeer.runMethod(methodName, inputs);
}
/**
* Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the
* native object will be destroyed when this object is garbage-collected. However, the timing of
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
*/
public void destroy() {
mNativePeer.resetNative();
}
}

View File

@@ -0,0 +1,46 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import java.util.Map;
class NativePeer implements INativePeer {
static {
NativeLoader.loadLibrary("pytorch_jni");
PyTorchCodegenLoader.loadNativeLibs();
}
private final HybridData mHybridData;
@DoNotStrip
private static native HybridData initHybrid(
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
@DoNotStrip
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
int deviceJniCode);
NativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
}
NativePeer(
String assetName, /* android.content.res.AssetManager */
Object androidAssetManager,
Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}
public void resetNative() {
mHybridData.resetNative();
}
@DoNotStrip
public native IValue forward(IValue... inputs);
@DoNotStrip
public native IValue runMethod(String methodName, IValue... inputs);
}

View File

@@ -0,0 +1,50 @@
package org.pytorch;
import android.content.res.AssetManager;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class PyTorchAndroid {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
/**
   * Attention: This is not the recommended way of loading production modules, as prepackaged
   * assets increase apk size etc. For production usage consider loading from a file on the disk
   * via {@link org.pytorch.Module#load(String)}.
   *
   * <p>This method is meant to be used in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new NativePeer(assetName, assetManager, device));
}
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new NativePeer(assetName, assetManager, Device.CPU));
}
/**
   * Globally sets the number of threads used on the native side. Attention: this has a global
   * effect; all modules use one thread pool with the specified number of threads.
   *
   * @param numThreads number of threads; must be a positive number.
*/
public static void setNumThreads(int numThreads) {
if (numThreads < 1) {
throw new IllegalArgumentException("Number of threads cannot be less than 1");
}
nativeSetNumThreads(numThreads);
}
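  // e.g. PyTorchAndroid.setNumThreads(4); // affects every module in the process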
@DoNotStrip
private static native void nativeSetNumThreads(int numThreads);
}

View File

@@ -0,0 +1,16 @@
package org.pytorch;
import com.facebook.soloader.nativeloader.NativeLoader;
public class PyTorchCodegenLoader {
public static void loadNativeLibs() {
try {
NativeLoader.loadLibrary("torch-code-gen");
} catch (Throwable t) {
      // Loading the codegen lib is best-effort since it's only there for query-based builds.
}
}
private PyTorchCodegenLoader() {}
}

View File

@@ -0,0 +1,735 @@
package org.pytorch;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Locale;
/**
* Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
*
* <p>Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data}
* can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided
* to allocate buffers properly.
*
* <p>To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*}
* methods.
*
 * <p>When constructing {@code Tensor} objects with {@code data} as an array, it is not specified
 * whether this data is copied or retained as a reference, so it is recommended not to modify it
 * after construction. {@code data} passed as a {@link Buffer} is not copied, so it can be modified
* between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects
* may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape}
* is always copied.
*/
public abstract class Tensor {
  private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must not be null";
  private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must not be null";
  private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must not be null";
  private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non-negative";
private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
"Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
@DoNotStrip final long[] shape;
final MemoryFormat memoryFormat;
private static final int INT_SIZE_BYTES = 4;
private static final int FLOAT_SIZE_BYTES = 4;
private static final int LONG_SIZE_BYTES = 8;
private static final int DOUBLE_SIZE_BYTES = 8;
/**
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link
* Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static ByteBuffer allocateByteBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
}
/**
* Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static IntBuffer allocateIntBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asIntBuffer();
}
/**
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static FloatBuffer allocateFloatBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asFloatBuffer();
}
/**
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static LongBuffer allocateLongBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asLongBuffer();
}
/**
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
public static DoubleBuffer allocateDoubleBuffer(int numElements) {
return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES)
.order(ByteOrder.nativeOrder())
.asDoubleBuffer();
}
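  // Buffer-based construction sketch (illustrative): the direct buffer is used
  // by the tensor without copying, so it can be refilled between Module calls.
  //
  //   FloatBuffer buf = Tensor.allocateFloatBuffer(4);
  //   buf.put(new float[] {1f, 2f, 3f, 4f});
  //   Tensor t = Tensor.fromBlob(buf, new long[] {2, 2});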
/**
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
* bytes.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlobUnsigned(byte[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_uint8(byteBuffer, shape, memoryFormat);
}
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
* bytes.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(byte[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_int8(byteBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(byte[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
* ints.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(int[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape));
intBuffer.put(data);
return new Tensor_int32(intBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(int[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
* of floats.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(float[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape));
floatBuffer.put(data);
return new Tensor_float32(floatBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(float[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
* longs.
*
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor fromBlob(long[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
longBuffer.put(data);
return new Tensor_int64(longBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(long[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array
* of doubles.
*
   * @param data Tensor elements
   * @param shape Tensor shape
*/
public static Tensor fromBlob(double[] data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
doubleBuffer.put(data);
return new Tensor_float64(doubleBuffer, shape, memoryFormat);
}
public static Tensor fromBlob(double[] data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_uint8(data, shape, memoryFormat);
}
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int8(data, shape, memoryFormat);
}
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int32(data, shape, memoryFormat);
}
public static Tensor fromBlob(IntBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float32(data, shape, memoryFormat);
}
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int64(data, shape, memoryFormat);
}
public static Tensor fromBlob(LongBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
/**
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
*
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor fromBlob(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float64(data, shape, memoryFormat);
}
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
}
@DoNotStrip private HybridData mHybridData;
private Tensor(long[] shape, MemoryFormat memoryFormat) {
checkShape(shape);
this.shape = Arrays.copyOf(shape, shape.length);
this.memoryFormat = memoryFormat;
}
/** Returns the number of elements in this tensor. */
public long numel() {
return numel(this.shape);
}
/** Calculates the number of elements in a tensor with the specified shape. */
public static long numel(long[] shape) {
checkShape(shape);
    // Accumulate in a long: with an int accumulator the compound assignment
    // would silently truncate the product.
    long result = 1;
for (long s : shape) {
result *= s;
}
return result;
}
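  // e.g. numel(new long[] {2, 3, 4}) == 24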
/** Returns the shape of this tensor. (The array is a fresh copy.) */
public long[] shape() {
return Arrays.copyOf(shape, shape.length);
}
/** Returns the memory format of this tensor. */
public MemoryFormat memoryFormat() {
return memoryFormat;
}
/** @return data type of this tensor. */
public abstract DType dtype();
// Called from native
@DoNotStrip
int dtypeJniCode() {
return dtype().jniCode;
}
// Called from native
@DoNotStrip
int memoryFormatJniCode() {
return memoryFormat.jniCode;
}
/**
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int8 tensor.
*/
public byte[] getDataAsByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
/**
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-uint8 tensor.
*/
public byte[] getDataAsUnsignedByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
/**
* @return a Java int array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int32 tensor.
*/
public int[] getDataAsIntArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
}
/**
* @return a Java float array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-float32 tensor.
*/
public float[] getDataAsFloatArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
}
/**
* @return a Java long array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int64 tensor.
*/
public long[] getDataAsLongArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as long array.");
}
/**
* @return a Java double array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-float64 tensor.
*/
public double[] getDataAsDoubleArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
}
@DoNotStrip
Buffer getRawDataBuffer() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
}
static class Tensor_uint8 extends Tensor {
private final ByteBuffer data;
private Tensor_uint8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.UINT8;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public byte[] getDataAsUnsignedByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
}
}
static class Tensor_int8 extends Tensor {
private final ByteBuffer data;
private Tensor_int8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT8;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public byte[] getDataAsByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
}
}
static class Tensor_int32 extends Tensor {
private final IntBuffer data;
private Tensor_int32(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT32;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public int[] getDataAsIntArray() {
data.rewind();
int[] arr = new int[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
}
}
static class Tensor_float32 extends Tensor {
private final FloatBuffer data;
    private Tensor_float32(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public float[] getDataAsFloatArray() {
data.rewind();
float[] arr = new float[data.remaining()];
data.get(arr);
return arr;
}
@Override
public DType dtype() {
return DType.FLOAT32;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
}
}
static class Tensor_int64 extends Tensor {
private final LongBuffer data;
private Tensor_int64(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.INT64;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public long[] getDataAsLongArray() {
data.rewind();
long[] arr = new long[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
}
}
static class Tensor_float64 extends Tensor {
private final DoubleBuffer data;
private Tensor_float64(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
super(shape, memoryFormat);
this.data = data;
}
@Override
public DType dtype() {
return DType.FLOAT64;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public double[] getDataAsDoubleArray() {
data.rewind();
double[] arr = new double[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
}
}
// region checks
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
if (!expression) {
throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args));
}
}
private static void checkShape(long[] shape) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
for (int i = 0; i < shape.length; i++) {
checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE);
}
}
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
final long numel = numel(shape);
    checkArgument(
        numel == dataCapacity,
        "Data capacity %d does not match the number of elements %d for shape %s",
        dataCapacity,
        numel,
        Arrays.toString(shape));
}
// endregion checks
// Called from native
@DoNotStrip
private static Tensor nativeNewTensor(
ByteBuffer data, long[] shape, int dtype, int memoryFormatCode, HybridData hybridData) {
Tensor tensor = null;
MemoryFormat memoryFormat = MemoryFormat.CONTIGUOUS;
if (MemoryFormat.CHANNELS_LAST.jniCode == memoryFormatCode) {
memoryFormat = MemoryFormat.CHANNELS_LAST;
} else if (MemoryFormat.CHANNELS_LAST_3D.jniCode == memoryFormatCode) {
memoryFormat = MemoryFormat.CHANNELS_LAST_3D;
}
if (DType.FLOAT32.jniCode == dtype) {
tensor = new Tensor_float32(data.asFloatBuffer(), shape, memoryFormat);
} else if (DType.INT32.jniCode == dtype) {
tensor = new Tensor_int32(data.asIntBuffer(), shape, memoryFormat);
} else if (DType.INT64.jniCode == dtype) {
tensor = new Tensor_int64(data.asLongBuffer(), shape, memoryFormat);
} else if (DType.FLOAT64.jniCode == dtype) {
tensor = new Tensor_float64(data.asDoubleBuffer(), shape, memoryFormat);
} else if (DType.UINT8.jniCode == dtype) {
tensor = new Tensor_uint8(data, shape, memoryFormat);
} else if (DType.INT8.jniCode == dtype) {
tensor = new Tensor_int8(data, shape, memoryFormat);
} else {
      throw new IllegalArgumentException("Unknown Tensor dtype");
}
tensor.mHybridData = hybridData;
return tensor;
}
}
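
For context, a minimal sketch of how a caller is expected to use the fromBlob overloads above, assuming the allocateFloatBuffer helper declared earlier in this class (it returns a direct buffer in native byte order, which satisfies the checks fromBlob enforces); the example class name is illustrative, not part of the library:

import java.nio.FloatBuffer;

import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;

class TensorFromBlobExample {
  public static void main(String[] args) {
    // A 2x3 float32 tensor needs Tensor.numel(shape) = 6 elements.
    FloatBuffer data = Tensor.allocateFloatBuffer(6);
    for (int i = 0; i < 6; i++) {
      data.put((float) i);
    }
    // The buffer is wrapped, not copied: later writes to `data` remain
    // visible through the tensor.
    Tensor t = Tensor.fromBlob(data, new long[] {2, 3}, MemoryFormat.CONTIGUOUS);
    System.out.println(t.numel()); // 6
    System.out.println(t); // Tensor([2, 3], dtype=torch.float32)
  }
}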


@@ -0,0 +1,3 @@
<resources>
<string name="app_name">pytorch_android</string>
</resources>


@@ -0,0 +1,105 @@
# Reconstructed header (assumed; the opening lines of the file are not
# shown in this hunk): imports that the methods below rely on. These
# functions are written as methods of a TorchScript module used to
# generate test assets for the Android instrumentation tests.
import torch
from torch import Tensor
from typing import Dict, List, Optional, Tuple

def forward(self, input):
return None
def eqBool(self, input: bool) -> bool:
return input
def eqInt(self, input: int) -> int:
return input
def eqFloat(self, input: float) -> float:
return input
def eqStr(self, input: str) -> str:
return input
def eqTensor(self, input: Tensor) -> Tensor:
return input
def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
return input
def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
return input
def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
return input
def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
def listBoolConjunction(self, input: List[bool]) -> bool:
res = True
for x in input:
res = res and x
return res
def listBoolDisjunction(self, input: List[bool]) -> bool:
res = False
for x in input:
res = res or x
return res
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
def optionalIntIsNone(self, input: Optional[int]) -> bool:
return input is None
def intEq0None(self, input: int) -> Optional[int]:
if input == 0:
return None
return input
def str3Concat(self, input: str) -> str:
return input + input + input
def newEmptyShapeWithItem(self, input):
return torch.tensor([int(input.item())])[0]
def testAliasWithOffset(self) -> List[Tensor]:
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a
def testNonContiguous(self):
x = torch.tensor([100, 200, 300])[::2]
assert not x.is_contiguous()
assert x[0] == 100
assert x[1] == 300
return x
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.conv2d(x, w)
    if toChannelsLast:
# memory_format=torch.channels_last
r = r.contiguous(memory_format=2)
else:
r = r.contiguous()
return r
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.conv3d(x, w)
    if toChannelsLast:
# memory_format=torch.channels_last_3d
        r = r.contiguous(memory_format=3)
else:
r = r.contiguous()
return r
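# Note (assumption, based on the ordering of c10::MemoryFormat): the raw
# integers passed to contiguous(memory_format=...) above decode as
# 0 = contiguous_format, 1 = preserve_format, 2 = channels_last,
# 3 = channels_last_3d, which is why each call carries a comment naming
# the intended torch.* constant.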
def contiguous(self, x: Tensor) -> Tensor:
return x.contiguous()
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
# memory_format=torch.channels_last
return x.contiguous(memory_format=2)
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
# memory_format=torch.channels_last_3d
return x.contiguous(memory_format=3)
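
# A minimal, self-contained sketch of the pattern this file implements:
# wrap the test methods in a module, script it, and save the result where
# the Android instrumentation tests can load it. The class name and output
# path below are illustrative, not taken from this file.
class _ExampleModule(torch.nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input

    @torch.jit.export
    def eqInt(self, input: int) -> int:
        return input


if __name__ == "__main__":
    scripted = torch.jit.script(_ExampleModule())
    # @torch.jit.export keeps eqInt callable on the scripted module.
    assert scripted.eqInt(42) == 42
    scripted.save("example_test_asset.pt")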