mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Reapply "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)" (#153656)"
This reverts commit 7ed377f5776578aec4a6a9bc4eeef221a6b80a77. Reverted https://github.com/pytorch/pytorch/pull/153656 on behalf of https://github.com/larryliu0820 due to Still being used internally so can't remove ([comment](https://github.com/pytorch/pytorch/pull/153656#issuecomment-2887665403))
This commit is contained in:
184
android/pytorch_android/CMakeLists.txt
Normal file
184
android/pytorch_android/CMakeLists.txt
Normal file
@ -0,0 +1,184 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
|
||||
message(
|
||||
STATUS
|
||||
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")
|
||||
|
||||
if(BUILD_LITE_INTERPRETER)
|
||||
project(pytorch_jni_lite CXX)
|
||||
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
|
||||
else()
|
||||
project(pytorch_jni CXX)
|
||||
set(PYTORCH_JNI_TARGET pytorch_jni)
|
||||
endif()
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
message(STATUS "ANDROID_STL:${ANDROID_STL}")
|
||||
|
||||
set(TRACE_ENABLED OFF)
|
||||
if(DEFINED ENV{TRACE_ENABLED})
|
||||
if($ENV{TRACE_ENABLED} STREQUAL "1")
|
||||
message(STATUS "TRACE_ENABLED ON")
|
||||
set(TRACE_ENABLED ON)
|
||||
endif()
|
||||
endif()
|
||||
if(NOT TRACE_ENABLED)
|
||||
message(STATUS "TRACE_ENABLED OFF")
|
||||
endif()
|
||||
|
||||
set(USE_VULKAN OFF)
|
||||
|
||||
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
||||
|
||||
if(ANDROID_ABI)
|
||||
set(USE_VULKAN ON)
|
||||
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})
|
||||
set(BUILD_SUBDIR ${ANDROID_ABI})
|
||||
elseif(BUILD_LIBTORCH_WITH_JNI)
|
||||
# Don't need LIBTORCH_HOME if we're building from within PyTorch.
|
||||
else()
|
||||
# Building against a pre-built libtorch.
|
||||
if(NOT LIBTORCH_HOME)
|
||||
message(FATAL_ERROR
|
||||
"pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.")
|
||||
endif()
|
||||
set(libtorch_include_DIR ${LIBTORCH_HOME}/include)
|
||||
link_directories(${LIBTORCH_HOME}/lib)
|
||||
set(BUILD_SUBDIR host)
|
||||
endif()
|
||||
|
||||
message(STATUS "libtorch dir:${libtorch_DIR}")
|
||||
|
||||
configure_file(
|
||||
${pytorch_android_DIR}/cmake_macros.h.in
|
||||
${pytorch_android_DIR}/cmake_macros.h)
|
||||
|
||||
|
||||
if(BUILD_LITE_INTERPRETER)
|
||||
file(GLOB pytorch_android_SOURCES
|
||||
${pytorch_android_DIR}/pytorch_jni_lite.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.h
|
||||
)
|
||||
else()
|
||||
file(GLOB pytorch_android_SOURCES
|
||||
${pytorch_android_DIR}/pytorch_jni_jit.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
||||
${pytorch_android_DIR}/pytorch_jni_common.h
|
||||
)
|
||||
endif()
|
||||
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})
|
||||
|
||||
if(APPLE)
|
||||
# Need to add rpath so dlopen can find dependencies.
|
||||
add_custom_command(TARGET pytorch_jni
|
||||
POST_BUILD COMMAND
|
||||
${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path"
|
||||
$<TARGET_FILE:pytorch_jni>)
|
||||
endif()
|
||||
|
||||
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
|
||||
-fexceptions
|
||||
)
|
||||
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
|
||||
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)
|
||||
|
||||
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
|
||||
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
|
||||
|
||||
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})
|
||||
|
||||
# ---[ Vulkan deps
|
||||
if(USE_VULKAN)
|
||||
set(Vulkan_LIBS)
|
||||
set(Vulkan_INCLUDES)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake)
|
||||
endif()
|
||||
|
||||
if(ANDROID_ABI)
|
||||
function(import_static_lib name)
|
||||
add_library(${name} STATIC IMPORTED)
|
||||
set_property(
|
||||
TARGET ${name}
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a)
|
||||
endfunction(import_static_lib)
|
||||
|
||||
import_static_lib(libtorch)
|
||||
import_static_lib(libtorch_cpu)
|
||||
import_static_lib(libc10)
|
||||
import_static_lib(libnnpack)
|
||||
import_static_lib(libXNNPACK)
|
||||
import_static_lib(libmicrokernels-prod)
|
||||
import_static_lib(libpytorch_qnnpack)
|
||||
import_static_lib(libpthreadpool)
|
||||
import_static_lib(libeigen_blas)
|
||||
import_static_lib(libcpuinfo)
|
||||
import_static_lib(libclog)
|
||||
|
||||
# Link most things statically on Android.
|
||||
set(pytorch_jni_LIBS
|
||||
fbjni
|
||||
-Wl,--gc-sections
|
||||
-Wl,--whole-archive
|
||||
libtorch
|
||||
libtorch_cpu
|
||||
-Wl,--no-whole-archive
|
||||
libc10
|
||||
libnnpack
|
||||
libXNNPACK
|
||||
libmicrokernels-prod
|
||||
libpytorch_qnnpack
|
||||
libpthreadpool
|
||||
libeigen_blas
|
||||
libcpuinfo
|
||||
libclog
|
||||
)
|
||||
else()
|
||||
# Prefer dynamic linking on the host
|
||||
set(pytorch_jni_LIBS
|
||||
fbjni
|
||||
torch
|
||||
torch_cpu
|
||||
c10
|
||||
cpuinfo
|
||||
)
|
||||
|
||||
if(USE_NNPACK)
|
||||
list(APPEND pytorch_jni_LIBS nnpack)
|
||||
endif()
|
||||
|
||||
if(USE_XNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS XNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS microkernels-prod)
|
||||
endif()
|
||||
|
||||
if(USE_SYSTEM_PTHREADPOOL)
|
||||
list(APPEND pytorch_jni_LIBS pthreadpool)
|
||||
endif()
|
||||
|
||||
if(USE_PYTORCH_QNNPACK)
|
||||
list(APPEND pytorch_jni_LIBS pytorch_qnnpack)
|
||||
list(APPEND pytorch_jni_LIBS clog)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if(USE_VULKAN)
|
||||
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
|
||||
endif()
|
||||
|
||||
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})
|
||||
|
||||
install(TARGETS ${PYTORCH_JNI_TARGET}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows
|
||||
|
||||
if(MSVC)
|
||||
install(FILES $<TARGET_PDB_FILE:pytorch_jni> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
|
||||
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
endif()
|
163
android/pytorch_android/build.gradle
Normal file
163
android/pytorch_android/build.gradle
Normal file
@ -0,0 +1,163 @@
|
||||
apply plugin: 'com.android.library'
|
||||
apply plugin: 'maven'
|
||||
|
||||
android {
|
||||
compileSdkVersion rootProject.compileSdkVersion
|
||||
buildToolsVersion rootProject.buildToolsVersion
|
||||
|
||||
defaultConfig {
|
||||
minSdkVersion rootProject.minSdkVersion
|
||||
targetSdkVersion rootProject.targetSdkVersion
|
||||
versionCode 0
|
||||
versionName "0.1"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
ndk {
|
||||
abiFilters ABI_FILTERS.split(",")
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
|
||||
} else {
|
||||
arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buildTypes {
|
||||
debug {
|
||||
minifyEnabled false
|
||||
debuggable true
|
||||
}
|
||||
release {
|
||||
minifyEnabled false
|
||||
}
|
||||
}
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
println 'Build pytorch_jni'
|
||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||
exclude 'org/pytorch/LiteNativePeer.java'
|
||||
} else {
|
||||
println 'Build pytorch_jni_lite'
|
||||
}
|
||||
}
|
||||
jniLibs.srcDirs = ['src/main/jniLibs']
|
||||
manifest.srcFile 'src/main/AndroidManifest.xml'
|
||||
}
|
||||
androidTest {
|
||||
java {
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
println 'Build test for full jit (pytorch_jni)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
|
||||
} else {
|
||||
println 'Build test for lite interpreter (pytorch_jni_lite)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
|
||||
packagingOptions {
|
||||
if (nativeLibsDoNotStrip.toBoolean()) {
|
||||
doNotStrip "**/*.so"
|
||||
logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included')
|
||||
}
|
||||
}
|
||||
|
||||
useLibrary 'android.test.runner'
|
||||
useLibrary 'android.test.base'
|
||||
useLibrary 'android.test.mock'
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
|
||||
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
|
||||
|
||||
testImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
testImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
|
||||
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
|
||||
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
|
||||
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
||||
|
||||
task sourcesJar(type: Jar) {
|
||||
from android.sourceSets.main.java.srcDirs
|
||||
classifier = 'sources'
|
||||
}
|
||||
|
||||
def getLibtorchHeadersDir() {
|
||||
def abi = ABI_FILTERS.split(",")[0]
|
||||
return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi"
|
||||
}
|
||||
|
||||
afterEvaluate {
|
||||
if (POM_PACKAGING == 'aar') {
|
||||
android.libraryVariants.all { variant ->
|
||||
variant.outputs.each { output ->
|
||||
File f = output.outputFile
|
||||
if (f.name.endsWith(".aar")) {
|
||||
output.assemble.finalizedBy addFolderToAarTask(
|
||||
"addHeadersToAar" + variant.name,
|
||||
f.path,
|
||||
getLibtorchHeadersDir(),
|
||||
"headers")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) {
|
||||
doLast {
|
||||
addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) {
|
||||
return tasks.register(taskName) {
|
||||
doLast {
|
||||
addFolderToAar(taskName, aarPath, folderPath, folderPathInAar)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) {
|
||||
def tmpDir = file("${buildDir}/${taskName}")
|
||||
tmpDir.mkdir()
|
||||
def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}")
|
||||
tmpDirFolder.mkdir()
|
||||
copy {
|
||||
from zipTree(aarPath)
|
||||
into tmpDir
|
||||
}
|
||||
copy {
|
||||
from fileTree(folderPath)
|
||||
into tmpDirFolder
|
||||
}
|
||||
ant.zip(destfile: aarPath) {
|
||||
fileset(dir: tmpDir.path)
|
||||
}
|
||||
delete tmpDir
|
||||
}
|
||||
|
||||
artifacts.add('archives', sourcesJar)
|
20
android/pytorch_android/generate_test_asset.cpp
Normal file
20
android/pytorch_android/generate_test_asset.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/jit.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string input_file_path{argv[1]};
|
||||
std::string output_file_path{argv[2]};
|
||||
|
||||
std::ifstream ifs(input_file_path);
|
||||
std::stringstream buffer;
|
||||
buffer << ifs.rdbuf();
|
||||
torch::jit::Module m("TestModule");
|
||||
|
||||
m.define(buffer.str());
|
||||
m.save(output_file_path);
|
||||
}
|
151
android/pytorch_android/generate_test_torchscripts.py
Normal file
151
android/pytorch_android/generate_test_torchscripts.py
Normal file
@ -0,0 +1,151 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
OUTPUT_DIR = "src/androidTest/assets/"
|
||||
|
||||
|
||||
def scriptAndSave(module, fileName):
|
||||
print("-" * 80)
|
||||
script_module = torch.jit.script(module)
|
||||
print(script_module.graph)
|
||||
outputFileName = OUTPUT_DIR + fileName
|
||||
# note that the lite interpreter model can also be used in full JIT
|
||||
script_module._save_for_lite_interpreter(outputFileName)
|
||||
print("Saved to " + outputFileName)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
class Test(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
return None
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqBool(self, input: bool) -> bool:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqInt(self, input: int) -> int:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqFloat(self, input: float) -> float:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqStr(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqTensor(self, input: Tensor) -> Tensor:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
@torch.jit.script_method
|
||||
def listBoolConjunction(self, input: list[bool]) -> bool:
|
||||
res = True
|
||||
for x in input:
|
||||
res = res and x
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def listBoolDisjunction(self, input: list[bool]) -> bool:
|
||||
res = False
|
||||
for x in input:
|
||||
res = res or x
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def tupleIntSumReturnTuple(
|
||||
self, input: tuple[int, int, int]
|
||||
) -> tuple[tuple[int, int, int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
@torch.jit.script_method
|
||||
def optionalIntIsNone(self, input: Optional[int]) -> bool:
|
||||
return input is None
|
||||
|
||||
@torch.jit.script_method
|
||||
def intEq0None(self, input: int) -> Optional[int]:
|
||||
if input == 0:
|
||||
return None
|
||||
return input
|
||||
|
||||
@torch.jit.script_method
|
||||
def str3Concat(self, input: str) -> str:
|
||||
return input + input + input
|
||||
|
||||
@torch.jit.script_method
|
||||
def newEmptyShapeWithItem(self, input):
|
||||
return torch.tensor([int(input.item())])[0]
|
||||
|
||||
@torch.jit.script_method
|
||||
def testAliasWithOffset(self) -> list[Tensor]:
|
||||
x = torch.tensor([100, 200])
|
||||
a = [x[0], x[1]]
|
||||
return a
|
||||
|
||||
@torch.jit.script_method
|
||||
def testNonContiguous(self):
|
||||
x = torch.tensor([100, 200, 300])[::2]
|
||||
assert not x.is_contiguous()
|
||||
assert x[0] == 100
|
||||
assert x[1] == 300
|
||||
return x
|
||||
|
||||
@torch.jit.script_method
|
||||
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv2d(x, w)
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv3d(x, w)
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last_3d)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguous(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous()
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last_3d)
|
||||
|
||||
|
||||
scriptAndSave(Test(), "test.pt")
|
4
android/pytorch_android/gradle.properties
Normal file
4
android/pytorch_android/gradle.properties
Normal file
@ -0,0 +1,4 @@
|
||||
POM_NAME=pytorch_android_lite pytorch android api
|
||||
POM_DESCRIPTION=pytorch_android_lite pytorch android api
|
||||
POM_ARTIFACT_ID=pytorch_android_lite
|
||||
POM_PACKAGING=aar
|
42
android/pytorch_android/host/build.gradle
Normal file
42
android/pytorch_android/host/build.gradle
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the Apache-2 license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
plugins {
|
||||
id 'java-library'
|
||||
}
|
||||
|
||||
repositories {
|
||||
mavenLocal()
|
||||
jcenter()
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
srcDir '../src/main/java'
|
||||
exclude 'org/pytorch/PyTorchAndroid.java'
|
||||
exclude 'org/pytorch/LitePyTorchAndroid.java'
|
||||
exclude 'org/pytorch/LiteModuleLoader.java'
|
||||
exclude 'org/pytorch/LiteNativePeer.java'
|
||||
}
|
||||
}
|
||||
test {
|
||||
java {
|
||||
srcDir '../src/androidTest/java'
|
||||
exclude '**/PytorchInstrumented*'
|
||||
exclude '**/PytorchLiteInstrumented*'
|
||||
}
|
||||
resources.srcDirs = ["../src/androidTest/assets"]
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
|
||||
implementation 'com.facebook.soloader:nativeloader:0.10.1'
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
|
||||
testImplementation 'junit:junit:4.12'
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
4
android/pytorch_android/host/gradle.properties
Normal file
4
android/pytorch_android/host/gradle.properties
Normal file
@ -0,0 +1,4 @@
|
||||
POM_NAME=pytorch_java_only pytorch java api
|
||||
POM_DESCRIPTION=pytorch_java_only pytorch java api
|
||||
POM_ARTIFACT_ID=pytorch_java_only
|
||||
POM_PACKAGING=jar
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/dropout_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/dropout_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/linear_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/linear_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/mobilenet_v2.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/nn_utils_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/padding_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/padding_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/pointwise_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/pooling_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/pooling_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/recurrent_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/reduction_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/reduction_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/sampling_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/sampling_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/shuffle_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/sparse_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/sparse_ops.ptl
Normal file
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/spectral_ops.ptl
Normal file
BIN
android/pytorch_android/src/androidTest/assets/spectral_ops.ptl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
android/pytorch_android/src/androidTest/assets/test.pt
Normal file
BIN
android/pytorch_android/src/androidTest/assets/test.pt
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,21 @@
|
||||
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
|
||||
|
||||
using namespace ::testing;
|
||||
|
||||
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
|
||||
auto dict = c10::impl::GenericDict(
|
||||
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
|
||||
auto dictCallback = [](auto&&) {
|
||||
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
|
||||
};
|
||||
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
|
||||
dict, dictCallback, dictCallback));
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package org.pytorch;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PytorchHostTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
Path tempFile = Files.createTempFile("test", ".pt");
|
||||
try (InputStream resource =
|
||||
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
|
||||
}
|
||||
return tempFile.toAbsolutePath().toString();
|
||||
}
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
import androidx.test.InstrumentationRegistry;
|
||||
import androidx.test.runner.AndroidJUnit4;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
return file.getAbsolutePath();
|
||||
}
|
||||
|
||||
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||
try (OutputStream os = new FileOutputStream(file)) {
|
||||
byte[] buffer = new byte[4 * 1024];
|
||||
int read;
|
||||
while ((read = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, read);
|
||||
}
|
||||
os.flush();
|
||||
}
|
||||
return file.getAbsolutePath();
|
||||
} catch (IOException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
import androidx.test.InstrumentationRegistry;
|
||||
import androidx.test.runner.AndroidJUnit4;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return LiteModuleLoader.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
return file.getAbsolutePath();
|
||||
}
|
||||
|
||||
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||
try (OutputStream os = new FileOutputStream(file)) {
|
||||
byte[] buffer = new byte[4 * 1024];
|
||||
int read;
|
||||
while ((read = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, read);
|
||||
}
|
||||
os.flush();
|
||||
}
|
||||
return file.getAbsolutePath();
|
||||
} catch (IOException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,694 @@
|
||||
package org.pytorch;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
import org.junit.Ignore;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
Double.MAX_VALUE,
|
||||
-Double.MIN_VALUE,
|
||||
Double.MIN_VALUE,
|
||||
-Math.exp(1.d),
|
||||
-Math.sqrt(2.d),
|
||||
-3.1415f,
|
||||
3.1415f,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqTensor() throws IOException {
|
||||
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
|
||||
final long numElements = Tensor.numel(inputTensorShape);
|
||||
final float[] inputTensorData = new float[(int) numElements];
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
long sum = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
"проверка не латинских символов", // not latin symbols check
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput =
|
||||
new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyShape() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final long someNumber = 43;
|
||||
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {}, value.shape());
|
||||
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAliasWithOffset() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testAliasWithOffset");
|
||||
assertTrue(output.isTensorList());
|
||||
Tensor[] tensors = output.toTensorList();
|
||||
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
|
||||
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonContiguous() throws IOException {
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testNonContiguous");
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
assertArrayEquals(new long[] {2}, value.shape());
|
||||
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
|
||||
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
|
||||
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLast3d() throws IOException {
|
||||
long[] shape = new long[] {1, 2, 2, 2, 2};
|
||||
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
|
||||
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||
|
||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||
|
||||
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
|
||||
final IValue outputNHWDC =
|
||||
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
|
||||
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv2d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] dataNCHW = new long[] {
|
||||
111, 112,
|
||||
121, 122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
311, 312,
|
||||
321, 322};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[] {
|
||||
111, 211, 311, 112, 212, 312,
|
||||
|
||||
121, 221, 321, 122, 222, 322};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
long[] weightShape = new long[] {3, 3, 1, 1};
|
||||
long[] dataWeightOIHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightOHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCHW =
|
||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {
|
||||
2*111, 2*112,
|
||||
2*121, 2*122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
-311, -312,
|
||||
-321, -322});
|
||||
|
||||
final IValue outputNHWC =
|
||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNHWC,
|
||||
MemoryFormat.CHANNELS_LAST,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {
|
||||
2*111, 211, -311, 2*112, 212, -312,
|
||||
2*121, 221, -321, 2*122, 222, -322});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv3d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2, 2};
|
||||
long[] dataNCDHW = new long[] {
|
||||
1111, 1112,
|
||||
1121, 1122,
|
||||
1211, 1212,
|
||||
1221, 1222,
|
||||
|
||||
2111, 2112,
|
||||
2121, 2122,
|
||||
2211, 2212,
|
||||
2221, 2222,
|
||||
|
||||
3111, 3112,
|
||||
3121, 3122,
|
||||
3211, 3212,
|
||||
3221, 3222};
|
||||
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNDHWC = new long[] {
|
||||
1111, 2111, 3111,
|
||||
1112, 2112, 3112,
|
||||
|
||||
1121, 2121, 3121,
|
||||
1122, 2122, 3122,
|
||||
|
||||
1211, 2211, 3211,
|
||||
1212, 2212, 3212,
|
||||
|
||||
1221, 2221, 3221,
|
||||
1222, 2222, 3222};
|
||||
|
||||
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
long[] weightShape = new long[] {3, 3, 1, 1, 1};
|
||||
long[] dataWeightOIDHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightODHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCDHW =
|
||||
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCDHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2*1112, 2*1121, 2*1122,
|
||||
2*1211, 2*1212, 2*1221, 2*1222,
|
||||
|
||||
2111, 2112, 2121, 2122,
|
||||
2211, 2212, 2221, 2222,
|
||||
|
||||
-3111, -3112, -3121, -3122,
|
||||
-3211, -3212, -3221, -3222});
|
||||
|
||||
final IValue outputNDHWC =
|
||||
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNDHWC,
|
||||
MemoryFormat.CHANNELS_LAST_3D,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2111, -3111, 2*1112, 2112, -3112,
|
||||
2*1121, 2121, -3121, 2*1122, 2122, -3122,
|
||||
|
||||
2*1211, 2211, -3211, 2*1212, 2212, -3212,
|
||||
2*1221, 2221, -3221, 2*1222, 2222, -3222});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetV2() throws IOException {
|
||||
try {
|
||||
final Module module = loadModel("mobilenet_v2.ptl");
|
||||
final IValue inputs = module.runMethod("get_all_bundled_inputs");
|
||||
assertTrue(inputs.isList());
|
||||
final IValue input = inputs.toList()[0];
|
||||
assertTrue(input.isTuple());
|
||||
module.forward(input.toTuple()[0]);
|
||||
assertTrue(true);
|
||||
} catch (Exception ex) {
|
||||
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPointwiseOps() throws IOException {
|
||||
runModel("pointwise_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionOps() throws IOException {
|
||||
runModel("reduction_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparisonOps() throws IOException {
|
||||
runModel("comparison_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOtherMathOps() throws IOException {
|
||||
runModel("other_math_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testSpectralOps() throws IOException {
|
||||
// NB: This model fails without lite interpreter. The error is as follows:
|
||||
// RuntimeError: stft requires the return_complex parameter be given for real inputs
|
||||
runModel("spectral_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBlasLapackOps() throws IOException {
|
||||
runModel("blas_lapack_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSamplingOps() throws IOException {
|
||||
runModel("sampling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorOps() throws IOException {
|
||||
runModel("tensor_general_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorCreationOps() throws IOException {
|
||||
runModel("tensor_creation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorIndexingOps() throws IOException {
|
||||
runModel("tensor_indexing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorTypingOps() throws IOException {
|
||||
runModel("tensor_typing_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorViewOps() throws IOException {
|
||||
runModel("tensor_view_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvolutionOps() throws IOException {
|
||||
runModel("convolution_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoolingOps() throws IOException {
|
||||
runModel("pooling_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPaddingOps() throws IOException {
|
||||
runModel("padding_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActivationOps() throws IOException {
|
||||
runModel("activation_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizationOps() throws IOException {
|
||||
runModel("normalization_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecurrentOps() throws IOException {
|
||||
runModel("recurrent_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransformerOps() throws IOException {
|
||||
runModel("transformer_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearOps() throws IOException {
|
||||
runModel("linear_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDropoutOps() throws IOException {
|
||||
runModel("dropout_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSparseOps() throws IOException {
|
||||
runModel("sparse_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistanceFunctionOps() throws IOException {
|
||||
runModel("distance_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLossFunctionOps() throws IOException {
|
||||
runModel("loss_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVisionFunctionOps() throws IOException {
|
||||
runModel("vision_function_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShuffleOps() throws IOException {
|
||||
runModel("shuffle_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNNUtilsOps() throws IOException {
|
||||
runModel("nn_utils_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantOps() throws IOException {
|
||||
runModel("general_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicQuantOps() throws IOException {
|
||||
runModel("dynamic_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticQuantOps() throws IOException {
|
||||
runModel("static_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedQuantOps() throws IOException {
|
||||
runModel("fused_quant_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptBuiltinQuantOps() throws IOException {
|
||||
runModel("torchscript_builtin_ops");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTorchScriptCollectionQuantOps() throws IOException {
|
||||
runModel("torchscript_collection_ops");
|
||||
}
|
||||
|
||||
static void assertIValueTensor(
|
||||
final IValue ivalue,
|
||||
final MemoryFormat memoryFormat,
|
||||
final long[] expectedShape,
|
||||
final long[] expectedData) {
|
||||
assertTrue(ivalue.isTensor());
|
||||
Tensor t = ivalue.toTensor();
|
||||
assertEquals(memoryFormat, t.memoryFormat());
|
||||
assertArrayEquals(expectedShape, t.shape());
|
||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
void runModel(final String name) throws IOException {
|
||||
final Module storage_module = loadModel(name + ".ptl");
|
||||
storage_module.forward();
|
||||
|
||||
// TODO enable this once the on-the-fly script is ready
|
||||
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
|
||||
// on_the_fly_module.forward();
|
||||
assertTrue(true);
|
||||
}
|
||||
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
package org.pytorch.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.PytorchInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({PytorchInstrumentedTests.class})
|
||||
public class PytorchInstrumentedTestSuite {}
|
@ -0,0 +1,9 @@
|
||||
package org.pytorch.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.PytorchLiteInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
|
||||
public class PytorchLiteInstrumentedTestSuite {}
|
1
android/pytorch_android/src/main/AndroidManifest.xml
Normal file
1
android/pytorch_android/src/main/AndroidManifest.xml
Normal file
@ -0,0 +1 @@
|
||||
<manifest package="org.pytorch" />
|
3
android/pytorch_android/src/main/cpp/cmake_macros.h
Normal file
3
android/pytorch_android/src/main/cpp/cmake_macros.h
Normal file
@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
/* #undef TRACE_ENABLED */
|
3
android/pytorch_android/src/main/cpp/cmake_macros.h.in
Normal file
3
android/pytorch_android/src/main/cpp/cmake_macros.h.in
Normal file
@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#cmakedefine TRACE_ENABLED
|
697
android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
Normal file
697
android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
Normal file
@ -0,0 +1,697 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <fbjni/ByteBuffer.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
#if defined(__ANDROID__)
|
||||
#ifndef USE_PTHREADPOOL
|
||||
#define USE_PTHREADPOOL
|
||||
#endif /* USE_PTHREADPOOL */
|
||||
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
|
||||
if (deviceJniCode == kDeviceCPU) {
|
||||
return at::kCPU;
|
||||
} else if (deviceJniCode == kDeviceVulkan) {
|
||||
return at::kVulkan;
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
|
||||
}
|
||||
|
||||
bool Trace::is_initialized_ = false;
|
||||
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
|
||||
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
|
||||
#endif
|
||||
|
||||
void Trace::init() {
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
void* lib = dlopen("libandroid.so", RTLD_NOW || RTLD_LOCAL);
|
||||
if (lib != NULL) {
|
||||
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
|
||||
dlsym(lib, "ATrace_beginSection"));
|
||||
Trace::ATrace_endSection =
|
||||
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// NOTE: Codes must be kept in sync with DType.java.
|
||||
// NOTE: Never serialize these, because they can change between releases.
|
||||
constexpr static int kTensorDTypeUInt8 = 1;
|
||||
constexpr static int kTensorDTypeInt8 = 2;
|
||||
constexpr static int kTensorDTypeInt32 = 3;
|
||||
constexpr static int kTensorDTypeFloat32 = 4;
|
||||
constexpr static int kTensorDTypeInt64 = 5;
|
||||
constexpr static int kTensorDTypeFloat64 = 6;
|
||||
|
||||
constexpr static int kTensorMemoryFormatContiguous = 1;
|
||||
constexpr static int kTensorMemoryFormatChannelsLast = 2;
|
||||
constexpr static int kTensorMemoryFormatChannelsLast3d = 3;
|
||||
|
||||
template <typename K = jobject, typename V = jobject>
|
||||
struct JHashMap
|
||||
: facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
|
||||
constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";
|
||||
|
||||
using Super =
|
||||
facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;
|
||||
|
||||
static facebook::jni::local_ref<JHashMap<K, V>> create() {
|
||||
return Super::newInstance();
|
||||
}
|
||||
|
||||
void put(
|
||||
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
|
||||
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
|
||||
static auto putMethod =
|
||||
Super::javaClassStatic()
|
||||
->template getMethod<facebook::jni::alias_ref<
|
||||
facebook::jni::JObject::javaobject>(
|
||||
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
|
||||
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
|
||||
"put");
|
||||
putMethod(Super::self(), key, value);
|
||||
}
|
||||
};
|
||||
|
||||
static at::Tensor newAtTensor(
|
||||
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
|
||||
facebook::jni::alias_ref<jlongArray> jshape,
|
||||
jint jdtype,
|
||||
jint jmemoryFormat) {
|
||||
const auto rank = jshape->size();
|
||||
const auto shapeArr = jshape->getRegion(0, rank);
|
||||
std::vector<int64_t> shapeVec{};
|
||||
shapeVec.reserve(rank);
|
||||
auto numel = 1;
|
||||
for (const auto i : c10::irange(rank)) {
|
||||
shapeVec.push_back(shapeArr[i]);
|
||||
numel *= shapeArr[i];
|
||||
}
|
||||
JNIEnv* jni = facebook::jni::Environment::current();
|
||||
caffe2::TypeMeta typeMeta{};
|
||||
int dataElementSizeBytes = 0;
|
||||
if (kTensorDTypeFloat32 == jdtype) {
|
||||
dataElementSizeBytes = 4;
|
||||
typeMeta = caffe2::TypeMeta::Make<float>();
|
||||
} else if (kTensorDTypeInt32 == jdtype) {
|
||||
dataElementSizeBytes = 4;
|
||||
typeMeta = caffe2::TypeMeta::Make<int32_t>();
|
||||
} else if (kTensorDTypeInt8 == jdtype) {
|
||||
dataElementSizeBytes = 1;
|
||||
typeMeta = caffe2::TypeMeta::Make<int8_t>();
|
||||
} else if (kTensorDTypeUInt8 == jdtype) {
|
||||
dataElementSizeBytes = 1;
|
||||
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
|
||||
} else if (kTensorDTypeFloat64 == jdtype) {
|
||||
dataElementSizeBytes = 8;
|
||||
typeMeta = caffe2::TypeMeta::Make<double>();
|
||||
} else if (kTensorDTypeInt64 == jdtype) {
|
||||
dataElementSizeBytes = 8;
|
||||
typeMeta = caffe2::TypeMeta::Make<int64_t>();
|
||||
} else {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unknown Tensor jdtype %d",
|
||||
jdtype);
|
||||
}
|
||||
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
|
||||
if (dataCapacity != numel) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Tensor dimensions(elements number:%d, element byte size:%d, total "
|
||||
"bytes:%d) inconsistent with buffer capacity(%d)",
|
||||
numel,
|
||||
dataElementSizeBytes,
|
||||
numel * dataElementSizeBytes,
|
||||
dataCapacity);
|
||||
}
|
||||
|
||||
if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
|
||||
auto sizes = torch::IntArrayRef(shapeVec);
|
||||
return torch::from_blob(
|
||||
jni->GetDirectBufferAddress(jbuffer.get()),
|
||||
sizes,
|
||||
torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
|
||||
at::TensorOptions(typeMeta).memory_format(
|
||||
at::MemoryFormat::ChannelsLast));
|
||||
} else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
|
||||
auto sizes = torch::IntArrayRef(shapeVec);
|
||||
return torch::from_blob(
|
||||
jni->GetDirectBufferAddress(jbuffer.get()),
|
||||
sizes,
|
||||
torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
|
||||
at::TensorOptions(typeMeta).memory_format(
|
||||
at::MemoryFormat::ChannelsLast3d));
|
||||
}
|
||||
return torch::from_blob(
|
||||
jni->GetDirectBufferAddress(jbuffer.get()),
|
||||
torch::IntArrayRef(shapeVec),
|
||||
at::TensorOptions(typeMeta));
|
||||
}
|
||||
|
||||
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
|
||||
public:
|
||||
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";
|
||||
|
||||
explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}
|
||||
|
||||
static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
|
||||
facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
|
||||
static auto cls = TensorHybrid::javaClassStatic();
|
||||
static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
|
||||
static const auto jMethodMemoryFormatCode =
|
||||
cls->getMethod<jint()>("memoryFormatJniCode");
|
||||
static const auto jFieldShape = cls->getField<jlongArray>("shape");
|
||||
static const auto jMethodGetDataBuffer = cls->getMethod<
|
||||
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
|
||||
"getRawDataBuffer");
|
||||
|
||||
at::Tensor tensor = newAtTensor(
|
||||
jMethodGetDataBuffer(jTensorThis),
|
||||
jTensorThis->getFieldValue(jFieldShape),
|
||||
jMethodDTypeCode(jTensorThis),
|
||||
jMethodMemoryFormatCode(jTensorThis));
|
||||
return makeCxxInstance(std::move(tensor));
|
||||
}
|
||||
|
||||
static facebook::jni::local_ref<TensorHybrid::javaobject>
|
||||
newJTensorFromAtTensor(const at::Tensor& input_tensor) {
|
||||
// Java wrapper currently only supports contiguous tensors.
|
||||
|
||||
int jmemoryFormat = 0;
|
||||
at::Tensor tensor{};
|
||||
if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
|
||||
tensor = input_tensor;
|
||||
jmemoryFormat = kTensorMemoryFormatChannelsLast;
|
||||
} else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
|
||||
tensor = input_tensor;
|
||||
jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
|
||||
} else {
|
||||
tensor = input_tensor.contiguous();
|
||||
jmemoryFormat = kTensorMemoryFormatContiguous;
|
||||
}
|
||||
|
||||
const auto scalarType = tensor.scalar_type();
|
||||
int jdtype = 0;
|
||||
if (at::kFloat == scalarType) {
|
||||
jdtype = kTensorDTypeFloat32;
|
||||
} else if (at::kInt == scalarType) {
|
||||
jdtype = kTensorDTypeInt32;
|
||||
} else if (at::kByte == scalarType) {
|
||||
jdtype = kTensorDTypeUInt8;
|
||||
} else if (at::kChar == scalarType) {
|
||||
jdtype = kTensorDTypeInt8;
|
||||
} else if (at::kLong == scalarType) {
|
||||
jdtype = kTensorDTypeInt64;
|
||||
} else if (at::kDouble == scalarType) {
|
||||
jdtype = kTensorDTypeFloat64;
|
||||
} else {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"at::Tensor scalar type %s is not supported on java side",
|
||||
c10::toString(scalarType));
|
||||
}
|
||||
|
||||
const auto& tensorShape = tensor.sizes();
|
||||
std::vector<jlong> tensorShapeVec;
|
||||
for (const auto& s : tensorShape) {
|
||||
tensorShapeVec.push_back(s);
|
||||
}
|
||||
facebook::jni::local_ref<jlongArray> jTensorShape =
|
||||
facebook::jni::make_long_array(tensorShapeVec.size());
|
||||
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
|
||||
|
||||
static auto cls = TensorHybrid::javaClassStatic();
|
||||
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
|
||||
facebook::jni::JByteBuffer::wrapBytes(
|
||||
(uint8_t*)tensor.data_ptr(), tensor.nbytes());
|
||||
jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
|
||||
|
||||
static const auto jMethodNewTensor =
|
||||
cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
|
||||
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
|
||||
facebook::jni::alias_ref<jlongArray>,
|
||||
jint,
|
||||
jint,
|
||||
facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
|
||||
return jMethodNewTensor(
|
||||
cls,
|
||||
jTensorBuffer,
|
||||
jTensorShape,
|
||||
jdtype,
|
||||
jmemoryFormat,
|
||||
makeCxxInstance(tensor));
|
||||
}
|
||||
|
||||
static at::Tensor newAtTensorFromJTensor(
|
||||
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
|
||||
static auto cls = TensorHybrid::javaClassStatic();
|
||||
static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
|
||||
jint jdtype = dtypeMethod(jtensor);
|
||||
|
||||
static const auto memoryFormatMethod =
|
||||
cls->getMethod<jint()>("memoryFormatJniCode");
|
||||
jint jmemoryFormat = memoryFormatMethod(jtensor);
|
||||
|
||||
static const auto shapeField = cls->getField<jlongArray>("shape");
|
||||
auto jshape = jtensor->getFieldValue(shapeField);
|
||||
|
||||
static auto dataBufferMethod = cls->getMethod<
|
||||
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
|
||||
"getRawDataBuffer");
|
||||
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
|
||||
dataBufferMethod(jtensor);
|
||||
return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
|
||||
}
|
||||
|
||||
at::Tensor tensor() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend HybridBase;
|
||||
at::Tensor tensor_;
|
||||
};
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
|
||||
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||
static auto jMethodDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictStringKeyFrom");
|
||||
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::make_jstring(pair.key().toStringRef()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
|
||||
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||
static auto jMethodDictLongKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictLongKeyFrom");
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::JLong::valueOf(pair.key().toInt()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||
const at::IValue& ivalue,
|
||||
DictCallback stringDictCallback,
|
||||
DictCallback intDictCallback) {
|
||||
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
|
||||
if (ivalue.isNone()) {
|
||||
static auto jMethodOptionalNull =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
|
||||
"optionalNull");
|
||||
return jMethodOptionalNull(JIValue::javaClassStatic());
|
||||
} else if (ivalue.isTensor()) {
|
||||
static auto jMethodTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
|
||||
const auto& tensor = ivalue.toTensor();
|
||||
return jMethodTensor(
|
||||
JIValue::javaClassStatic(),
|
||||
TensorHybrid::newJTensorFromAtTensor(tensor));
|
||||
} else if (ivalue.isBool()) {
|
||||
static auto jMethodBool =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
|
||||
"from");
|
||||
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
|
||||
} else if (ivalue.isInt()) {
|
||||
static auto jMethodInt =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
|
||||
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
|
||||
} else if (ivalue.isDouble()) {
|
||||
static auto jMethodDouble =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
|
||||
"from");
|
||||
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
|
||||
} else if (ivalue.isString()) {
|
||||
static auto jMethodString =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
|
||||
"from");
|
||||
return jMethodString(
|
||||
JIValue::javaClassStatic(),
|
||||
facebook::jni::make_jstring(ivalue.toStringRef()));
|
||||
} else if (ivalue.isTuple()) {
|
||||
auto elementsVec = ivalue.toTupleRef().elements();
|
||||
static auto jMethodTupleArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("tupleFrom");
|
||||
auto jElementsArray =
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
elementsVec.size());
|
||||
auto index = 0;
|
||||
for (const auto& e : elementsVec) {
|
||||
(*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
|
||||
}
|
||||
return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
|
||||
} else if (ivalue.isBoolList()) {
|
||||
auto list = ivalue.toBoolList();
|
||||
static auto jMethodBoolListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_boolean_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
auto index = 0;
|
||||
for (const auto& e : list) {
|
||||
jArrayPinned[index++] = e;
|
||||
}
|
||||
return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
|
||||
} else if (ivalue.isIntList()) {
|
||||
auto list = ivalue.toIntList();
|
||||
static auto jMethodLongListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jlongArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_long_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
auto index = 0;
|
||||
for (const auto& e : list) {
|
||||
jArrayPinned[index++] = e;
|
||||
}
|
||||
return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
|
||||
} else if (ivalue.isDoubleList()) {
|
||||
auto list = ivalue.toDoubleList();
|
||||
static auto jMethoDoubleListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_double_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
auto index = 0;
|
||||
for (const auto& e : list) {
|
||||
jArrayPinned[index++] = e;
|
||||
}
|
||||
return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
|
||||
} else if (ivalue.isTensorList()) {
|
||||
auto list = ivalue.toTensorList();
|
||||
static auto jMethodTensorListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
TensorHybrid::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray =
|
||||
facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
for (const auto& e : list) {
|
||||
(*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
|
||||
}
|
||||
return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
|
||||
} else if (ivalue.isList()) {
|
||||
auto list = ivalue.toList();
|
||||
static auto jMethodListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray =
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
|
||||
auto index = 0;
|
||||
for (const auto& e : list) {
|
||||
(*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
|
||||
}
|
||||
return jMethodListArr(JIValue::javaClassStatic(), jArray);
|
||||
} else if (ivalue.isGenericDict()) {
|
||||
auto dict = ivalue.toGenericDict();
|
||||
const auto keyType = dict.keyType();
|
||||
|
||||
if (!keyType) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unknown IValue-Dict key type");
|
||||
}
|
||||
|
||||
if (*keyType == *c10::StringType::get()) {
|
||||
return stringDictCallback(std::move(dict));
|
||||
} else if (*keyType == *c10::IntType::get()) {
|
||||
return intDictCallback(std::move(dict));
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unsupported IValue-Dict key type: %s",
|
||||
keyType->str().c_str());
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unsupported IValue type %s",
|
||||
ivalue.tagKind().c_str());
|
||||
}
|
||||
|
||||
at::IValue JIValue::JIValueToAtIValue(
|
||||
facebook::jni::alias_ref<JIValue> jivalue) {
|
||||
Trace _s{"jni::JIValue::JIValueToAtIValue"};
|
||||
static const auto typeCodeField =
|
||||
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
|
||||
const auto typeCode = jivalue->getFieldValue(typeCodeField);
|
||||
if (JIValue::kTypeCodeNull == typeCode) {
|
||||
return at::IValue{};
|
||||
} else if (JIValue::kTypeCodeTensor == typeCode) {
|
||||
static const auto jMethodGetTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
|
||||
"toTensor");
|
||||
return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
|
||||
} else if (JIValue::kTypeCodeBool == typeCode) {
|
||||
static const auto jMethodGetBool =
|
||||
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
|
||||
// explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
|
||||
// for int will be called for jboolean
|
||||
bool b = jMethodGetBool(jivalue);
|
||||
return at::IValue{b};
|
||||
} else if (JIValue::kTypeCodeLong == typeCode) {
|
||||
static const auto jMethodGetLong =
|
||||
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
|
||||
return at::IValue{(int64_t)jMethodGetLong(jivalue)};
|
||||
} else if (JIValue::kTypeCodeDouble == typeCode) {
|
||||
static const auto jMethodGetDouble =
|
||||
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
|
||||
return at::IValue{jMethodGetDouble(jivalue)};
|
||||
} else if (JIValue::kTypeCodeString == typeCode) {
|
||||
static const auto jMethodGetString =
|
||||
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
|
||||
return at::IValue{jMethodGetString(jivalue)->toStdString()};
|
||||
} else if (JIValue::kTypeCodeTuple == typeCode) {
|
||||
static const auto jMethodGetTuple =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
|
||||
"toTuple");
|
||||
auto jarray = jMethodGetTuple(jivalue);
|
||||
size_t n = jarray->size();
|
||||
|
||||
std::vector<at::IValue> elements;
|
||||
elements.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
auto jivalue_element = jarray->getElement(i);
|
||||
auto element = JIValue::JIValueToAtIValue(jivalue_element);
|
||||
elements.push_back(std::move(element));
|
||||
}
|
||||
return c10::ivalue::Tuple::create(std::move(elements));
|
||||
} else if (JIValue::kTypeCodeBoolList == typeCode) {
|
||||
static const auto jMethodGetBoolList =
|
||||
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
|
||||
auto jArray = jMethodGetBoolList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
c10::List<bool> list{};
|
||||
list.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
list.push_back(jArrayPinned[i]);
|
||||
}
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeLongList == typeCode) {
|
||||
static const auto jMethodGetLongList =
|
||||
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
|
||||
auto jArray = jMethodGetLongList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
c10::List<int64_t> list{};
|
||||
list.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
list.push_back(jArrayPinned[i]);
|
||||
}
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
|
||||
static const auto jMethodGetDoubleList =
|
||||
JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
|
||||
auto jArray = jMethodGetDoubleList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
c10::List<double> list{};
|
||||
list.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
list.push_back(jArrayPinned[i]);
|
||||
}
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeTensorList == typeCode) {
|
||||
static const auto jMethodGetTensorList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
TensorHybrid::javaobject>::javaobject()>("toTensorList");
|
||||
auto jArray = jMethodGetTensorList(jivalue);
|
||||
size_t n = jArray->size();
|
||||
c10::List<at::Tensor> list{};
|
||||
list.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
list.push_back(
|
||||
TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
|
||||
}
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeList == typeCode) {
|
||||
static const auto jMethodGetList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
|
||||
"toList");
|
||||
auto jarray = jMethodGetList(jivalue);
|
||||
size_t n = jarray->size();
|
||||
if (n == 0) {
|
||||
return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
|
||||
}
|
||||
|
||||
auto jivalue_first_element = jarray->getElement(0);
|
||||
auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
|
||||
c10::impl::GenericList list{c10::unshapedType(first_element.type())};
|
||||
list.reserve(n);
|
||||
list.push_back(first_element);
|
||||
for (const auto i : c10::irange(1, n)) {
|
||||
auto jivalue_element = jarray->getElement(i);
|
||||
auto element = JIValue::JIValueToAtIValue(jivalue_element);
|
||||
list.push_back(element);
|
||||
}
|
||||
return at::IValue{list};
|
||||
} else if (JIValue::kTypeCodeDictStringKey == typeCode) {
|
||||
static const auto jMethodGetDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
|
||||
javaobject()>("toDictStringKey");
|
||||
auto jmap = jMethodGetDictStringKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
return at::IValue{c10::impl::GenericDict(
|
||||
c10::StringType::get(), c10::TensorType::get())};
|
||||
}
|
||||
|
||||
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
|
||||
c10::impl::GenericDict dict{
|
||||
c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
|
||||
dict.insert(it->first->toStdString(), firstEntryValue);
|
||||
it++;
|
||||
for (; it != jmap->end(); it++) {
|
||||
dict.insert(
|
||||
it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
|
||||
}
|
||||
return at::IValue{dict};
|
||||
} else if (JIValue::kTypeCodeDictLongKey == typeCode) {
|
||||
static const auto jMethodGetDictLongKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<
|
||||
facebook::jni::JLong::javaobject,
|
||||
JIValue::javaobject>::javaobject()>("toDictLongKey");
|
||||
auto jmap = jMethodGetDictLongKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
return at::IValue{
|
||||
c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
|
||||
}
|
||||
|
||||
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
|
||||
c10::impl::GenericDict dict{
|
||||
c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
|
||||
dict.insert((int64_t)it->first->longValue(), firstEntryValue);
|
||||
it++;
|
||||
for (; it != jmap->end(); it++) {
|
||||
dict.insert(
|
||||
(int64_t)it->first->longValue(),
|
||||
JIValue::JIValueToAtIValue(it->second));
|
||||
}
|
||||
return at::IValue{dict};
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unknown IValue typeCode %d",
|
||||
typeCode);
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
|
||||
|
||||
static void registerNatives() {
|
||||
javaClassStatic()->registerNatives({
|
||||
makeNativeMethod(
|
||||
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
|
||||
});
|
||||
}
|
||||
|
||||
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
|
||||
caffe2::pthreadpool()->set_thread_count(numThreads);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
void common_registerNatives() {
|
||||
static const int once = []() {
|
||||
#if defined(__ANDROID__)
|
||||
pytorch_jni::PyTorchAndroidJni::registerNatives();
|
||||
#endif
|
||||
return 0;
|
||||
}();
|
||||
((void)once);
|
||||
}
|
||||
|
||||
} // namespace pytorch_jni
|
137
android/pytorch_android/src/main/cpp/pytorch_jni_common.h
Normal file
137
android/pytorch_android/src/main/cpp/pytorch_jni_common.h
Normal file
@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/FunctionRef.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
#include <torch/csrc/api/include/torch/types.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "cmake_macros.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <android/log.h>
|
||||
#define ALOGI(...) \
|
||||
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
|
||||
#define ALOGE(...) \
|
||||
__android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
#include <android/trace.h>
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
constexpr static int kDeviceCPU = 1;
|
||||
constexpr static int kDeviceVulkan = 2;
|
||||
|
||||
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
|
||||
|
||||
class Trace {
|
||||
public:
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
|
||||
typedef void* (*fp_ATrace_endSection)(void);
|
||||
|
||||
static fp_ATrace_beginSection ATrace_beginSection;
|
||||
static fp_ATrace_endSection ATrace_endSection;
|
||||
#endif
|
||||
|
||||
static void ensureInit() {
|
||||
if (!Trace::is_initialized_) {
|
||||
init();
|
||||
Trace::is_initialized_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
static void beginSection(const char* name) {
|
||||
Trace::ensureInit();
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
ATrace_beginSection(name);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void endSection() {
|
||||
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||
ATrace_endSection();
|
||||
#endif
|
||||
}
|
||||
|
||||
Trace(const char* name) {
|
||||
ensureInit();
|
||||
beginSection(name);
|
||||
}
|
||||
|
||||
~Trace() {
|
||||
endSection();
|
||||
}
|
||||
|
||||
private:
|
||||
static void init();
|
||||
static bool is_initialized_;
|
||||
};
|
||||
|
||||
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
||||
public:
|
||||
explicit MemoryReadAdapter(const void* data, off_t size)
|
||||
: data_(data), size_(size){};
|
||||
|
||||
size_t size() const override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override {
|
||||
memcpy(buf, (int8_t*)(data_) + pos, n);
|
||||
return n;
|
||||
}
|
||||
|
||||
~MemoryReadAdapter() {}
|
||||
|
||||
private:
|
||||
const void* data_;
|
||||
off_t size_;
|
||||
};
|
||||
|
||||
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
|
||||
c10::Dict<c10::IValue, c10::IValue>)>;
|
||||
|
||||
public:
|
||||
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
||||
|
||||
constexpr static int kTypeCodeNull = 1;
|
||||
|
||||
constexpr static int kTypeCodeTensor = 2;
|
||||
constexpr static int kTypeCodeBool = 3;
|
||||
constexpr static int kTypeCodeLong = 4;
|
||||
constexpr static int kTypeCodeDouble = 5;
|
||||
constexpr static int kTypeCodeString = 6;
|
||||
|
||||
constexpr static int kTypeCodeTuple = 7;
|
||||
constexpr static int kTypeCodeBoolList = 8;
|
||||
constexpr static int kTypeCodeLongList = 9;
|
||||
constexpr static int kTypeCodeDoubleList = 10;
|
||||
constexpr static int kTypeCodeTensorList = 11;
|
||||
constexpr static int kTypeCodeList = 12;
|
||||
|
||||
constexpr static int kTypeCodeDictStringKey = 13;
|
||||
constexpr static int kTypeCodeDictLongKey = 14;
|
||||
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
|
||||
const at::IValue& ivalue,
|
||||
DictCallback stringDictCallback = newJIValueFromStringDict,
|
||||
DictCallback intDictCallback = newJIValueFromIntDict);
|
||||
|
||||
static at::IValue JIValueToAtIValue(
|
||||
facebook::jni::alias_ref<JIValue> jivalue);
|
||||
|
||||
private:
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
|
||||
c10::Dict<c10::IValue, c10::IValue>);
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
|
||||
c10::Dict<c10::IValue, c10::IValue>);
|
||||
};
|
||||
|
||||
void common_registerNatives();
|
||||
} // namespace pytorch_jni
|
245
android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
Normal file
245
android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
Normal file
@ -0,0 +1,245 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <fbjni/ByteBuffer.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
#include <torch/csrc/jit/runtime/print_handler.h>
|
||||
#include <torch/script.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
namespace {
|
||||
|
||||
struct JITCallGuard {
|
||||
// Inference only workload.
|
||||
c10::InferenceMode guard;
|
||||
// Disable graph optimizer to ensure list of unused ops are not changed for
|
||||
// custom mobile build.
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
private:
|
||||
friend HybridBase;
|
||||
torch::jit::Module module_;
|
||||
c10::DeviceType deviceType_;
|
||||
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
|
||||
|
||||
static facebook::jni::local_ref<jhybriddata> initHybrid(
|
||||
facebook::jni::alias_ref<jclass>,
|
||||
facebook::jni::alias_ref<jstring> modelPath,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
|
||||
extraFiles,
|
||||
jint device) {
|
||||
return makeCxxInstance(modelPath, extraFiles, device);
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
|
||||
facebook::jni::alias_ref<jclass>,
|
||||
facebook::jni::alias_ref<jstring> assetName,
|
||||
facebook::jni::alias_ref<jobject> assetManager,
|
||||
jint device) {
|
||||
return makeCxxInstance(assetName, assetManager, device);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef TRACE_ENABLED
|
||||
static std::unique_ptr<at::ObserverContext> onFunctionEnter(
|
||||
const at::RecordFunction& fn) {
|
||||
Trace::beginSection(fn.name().str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
|
||||
Trace::endSection();
|
||||
}
|
||||
#endif
|
||||
|
||||
static void preModuleLoadSetupOnce() {
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
|
||||
qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
torch::jit::setPrintHandler([](const std::string& s) {
|
||||
__android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
|
||||
});
|
||||
#endif
|
||||
|
||||
#ifdef TRACE_ENABLED
|
||||
at::addGlobalCallback(
|
||||
at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
|
||||
.scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
|
||||
#endif
|
||||
}
|
||||
|
||||
void preModuleLoadSetup() {
|
||||
static const int once = []() {
|
||||
preModuleLoadSetupOnce();
|
||||
return 0;
|
||||
}();
|
||||
((void)once);
|
||||
}
|
||||
|
||||
PytorchJni(
|
||||
facebook::jni::alias_ref<jstring> modelPath,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
|
||||
extraFiles,
|
||||
jint device) {
|
||||
preModuleLoadSetup();
|
||||
JITCallGuard guard;
|
||||
std::unordered_map<std::string, std::string> extra_files;
|
||||
const auto has_extra = extraFiles && extraFiles->size() > 0;
|
||||
if (has_extra) {
|
||||
for (const auto& e : *extraFiles) {
|
||||
extra_files[e.first->toStdString()] = "";
|
||||
}
|
||||
}
|
||||
deviceType_ = deviceJniCodeToDeviceType(device);
|
||||
module_ = torch::jit::load(
|
||||
std::move(modelPath->toStdString()), std::nullopt, extra_files);
|
||||
if (has_extra) {
|
||||
static auto putMethod =
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
|
||||
javaClassStatic()
|
||||
->template getMethod<facebook::jni::alias_ref<jobject>(
|
||||
facebook::jni::alias_ref<jobject>,
|
||||
facebook::jni::alias_ref<jobject>)>("put");
|
||||
for (const auto& ef : extra_files) {
|
||||
putMethod(
|
||||
extraFiles,
|
||||
facebook::jni::make_jstring(ef.first),
|
||||
facebook::jni::make_jstring(ef.second));
|
||||
}
|
||||
}
|
||||
|
||||
module_.eval();
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
PytorchJni(
|
||||
facebook::jni::alias_ref<jstring> assetName,
|
||||
facebook::jni::alias_ref<jobject> assetManager,
|
||||
jint device) {
|
||||
preModuleLoadSetup();
|
||||
JNIEnv* env = facebook::jni::Environment::current();
|
||||
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
|
||||
if (!mgr) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unable to get asset manager");
|
||||
}
|
||||
AAsset* asset = AAssetManager_open(
|
||||
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
|
||||
if (!asset) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Failed to open asset '%s'",
|
||||
assetName->toStdString().c_str());
|
||||
}
|
||||
auto assetBuffer = AAsset_getBuffer(asset);
|
||||
if (!assetBuffer) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Could not get buffer for asset '%s'",
|
||||
assetName->toStdString().c_str());
|
||||
}
|
||||
JITCallGuard guard;
|
||||
module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
|
||||
assetBuffer, AAsset_getLength(asset)));
|
||||
AAsset_close(asset);
|
||||
module_.eval();
|
||||
deviceType_ = deviceJniCodeToDeviceType(device);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void registerNatives() {
|
||||
registerHybrid({
|
||||
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
|
||||
#ifdef __ANDROID__
|
||||
makeNativeMethod(
|
||||
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
|
||||
#endif
|
||||
makeNativeMethod("forward", PytorchJni::forward),
|
||||
makeNativeMethod("runMethod", PytorchJni::runMethod),
|
||||
});
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> forward(
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
||||
jinputs) {
|
||||
Trace _s{"jni::Module::forward"};
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
auto output = [&]() {
|
||||
JITCallGuard guard;
|
||||
return module_.forward(std::move(inputs));
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> runMethod(
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
||||
jinputs) {
|
||||
std::string methodName = jmethodName->toStdString();
|
||||
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
JITCallGuard guard;
|
||||
return (*method)(std::move(inputs));
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Undefined method %s",
|
||||
methodName.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||
return facebook::jni::initialize(vm, [] {
|
||||
pytorch_jni::common_registerNatives();
|
||||
pytorch_jni::PytorchJni::registerNatives();
|
||||
});
|
||||
}
|
209
android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
Normal file
209
android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
Normal file
@ -0,0 +1,209 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <fbjni/ByteBuffer.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/script.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#endif
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
namespace {
|
||||
|
||||
struct LiteJITCallGuard {
|
||||
// VariableType dispatch is not included in default mobile build. We need set
|
||||
// this guard globally to avoid dispatch error (only for dynamic dispatch).
|
||||
// Thanks to the unification of Variable class and Tensor class it's no longer
|
||||
// required to toggle the NonVariableTypeMode per op - so it doesn't hurt to
|
||||
// always set NonVariableTypeMode for inference only use case.
|
||||
// TODO: Ideally AutoNonVariableTypeMode in this file should be changed to
|
||||
// InferenceMode but it's blocked due to typeahead application on Oculus
|
||||
// (D27943428). To unblock, we need to find out which op is making inplace
|
||||
// update to an inference tensor outside InferenceMode and properly guard it.
|
||||
torch::AutoNonVariableTypeMode non_var_guard;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
private:
|
||||
friend HybridBase;
|
||||
torch::jit::mobile::Module module_;
|
||||
c10::DeviceType deviceType_;
|
||||
|
||||
public:
|
||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/LiteNativePeer;";
|
||||
|
||||
static facebook::jni::local_ref<jhybriddata> initHybrid(
|
||||
facebook::jni::alias_ref<jclass>,
|
||||
facebook::jni::alias_ref<jstring> modelPath,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
|
||||
extraFiles,
|
||||
jint device) {
|
||||
return makeCxxInstance(modelPath, extraFiles, device);
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
|
||||
facebook::jni::alias_ref<jclass>,
|
||||
facebook::jni::alias_ref<jstring> assetName,
|
||||
facebook::jni::alias_ref<jobject> assetManager,
|
||||
jint device) {
|
||||
return makeCxxInstance(assetName, assetManager, device);
|
||||
}
|
||||
#endif
|
||||
|
||||
PytorchJni(
|
||||
facebook::jni::alias_ref<jstring> modelPath,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
|
||||
extraFiles,
|
||||
jint device) {
|
||||
LiteJITCallGuard guard;
|
||||
std::unordered_map<std::string, std::string> extra_files;
|
||||
const auto has_extra = extraFiles && extraFiles->size() > 0;
|
||||
if (has_extra) {
|
||||
for (const auto& e : *extraFiles) {
|
||||
extra_files[e.first->toStdString()] = "";
|
||||
}
|
||||
}
|
||||
deviceType_ = deviceJniCodeToDeviceType(device);
|
||||
module_ = torch::jit::_load_for_mobile(
|
||||
std::move(modelPath->toStdString()), std::nullopt, extra_files);
|
||||
torch::jit::_load_extra_only_for_mobile(
|
||||
std::move(modelPath->toStdString()), std::nullopt, extra_files);
|
||||
if (has_extra) {
|
||||
static auto putMethod =
|
||||
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
|
||||
javaClassStatic()
|
||||
->template getMethod<facebook::jni::alias_ref<jobject>(
|
||||
facebook::jni::alias_ref<jobject>,
|
||||
facebook::jni::alias_ref<jobject>)>("put");
|
||||
for (const auto& ef : extra_files) {
|
||||
putMethod(
|
||||
extraFiles,
|
||||
facebook::jni::make_jstring(ef.first),
|
||||
facebook::jni::make_jstring(ef.second));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
PytorchJni(
|
||||
facebook::jni::alias_ref<jstring> assetName,
|
||||
facebook::jni::alias_ref<jobject> assetManager,
|
||||
jint device) {
|
||||
JNIEnv* env = facebook::jni::Environment::current();
|
||||
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
|
||||
if (!mgr) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unable to get asset manager");
|
||||
}
|
||||
AAsset* asset = AAssetManager_open(
|
||||
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
|
||||
if (!asset) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Failed to open asset '%s'",
|
||||
assetName->toStdString().c_str());
|
||||
}
|
||||
auto assetBuffer = AAsset_getBuffer(asset);
|
||||
if (!assetBuffer) {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Could not get buffer for asset '%s'",
|
||||
assetName->toStdString().c_str());
|
||||
}
|
||||
LiteJITCallGuard guard;
|
||||
module_ =
|
||||
torch::jit::_load_for_mobile(std::make_unique<MemoryReadAdapter>(
|
||||
assetBuffer, AAsset_getLength(asset)));
|
||||
AAsset_close(asset);
|
||||
deviceType_ = deviceJniCodeToDeviceType(device);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void registerNatives() {
|
||||
registerHybrid({
|
||||
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
|
||||
#ifdef __ANDROID__
|
||||
makeNativeMethod(
|
||||
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
|
||||
#endif
|
||||
makeNativeMethod("forward", PytorchJni::forward),
|
||||
makeNativeMethod("runMethod", PytorchJni::runMethod),
|
||||
});
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> forward(
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
||||
jinputs) {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
|
||||
auto output = [&]() {
|
||||
LiteJITCallGuard guard;
|
||||
return module_.forward(inputs);
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> runMethod(
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
||||
jinputs) {
|
||||
std::string methodName = jmethodName->toStdString();
|
||||
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
LiteJITCallGuard guard;
|
||||
return module_.get_method(methodName)(inputs);
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Undefined method %s",
|
||||
methodName.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace pytorch_jni
|
||||
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
|
||||
return facebook::jni::initialize(vm, [] {
|
||||
pytorch_jni::common_registerNatives();
|
||||
pytorch_jni::PytorchJni::registerNatives();
|
||||
});
|
||||
}
|
27
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
27
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
@ -0,0 +1,27 @@
|
||||
package org.pytorch;
|
||||
|
||||
/** Codes representing tensor data types. */
|
||||
public enum DType {
|
||||
// NOTE: "jniCode" must be kept in sync with pytorch_jni_common.cpp.
|
||||
// NOTE: Never serialize "jniCode", because it can change between releases.
|
||||
|
||||
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
|
||||
UINT8(1),
|
||||
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
|
||||
INT8(2),
|
||||
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
|
||||
INT32(3),
|
||||
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
|
||||
FLOAT32(4),
|
||||
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
|
||||
INT64(5),
|
||||
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
|
||||
FLOAT64(6),
|
||||
;
|
||||
|
||||
final int jniCode;
|
||||
|
||||
DType(int jniCode) {
|
||||
this.jniCode = jniCode;
|
||||
}
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package org.pytorch;
|
||||
|
||||
public enum Device {
|
||||
// Must be in sync with kDeviceCPU, kDeviceVulkan in
|
||||
// pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
|
||||
CPU(1),
|
||||
VULKAN(2),
|
||||
;
|
||||
|
||||
final int jniCode;
|
||||
|
||||
Device(int jniCode) {
|
||||
this.jniCode = jniCode;
|
||||
}
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
package org.pytorch;
|
||||
|
||||
interface INativePeer {
|
||||
void resetNative();
|
||||
|
||||
IValue forward(IValue... inputs);
|
||||
|
||||
IValue runMethod(String methodName, IValue... inputs);
|
||||
}
|
343
android/pytorch_android/src/main/java/org/pytorch/IValue.java
Normal file
343
android/pytorch_android/src/main/java/org/pytorch/IValue.java
Normal file
@ -0,0 +1,343 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Java representation of a TorchScript value, which is implemented as tagged union that can be one
|
||||
* of the supported types: https://pytorch.org/docs/stable/jit.html#types .
|
||||
*
|
||||
* <p>Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
|
||||
*
|
||||
* <p>{@code IValue} objects are constructed with {@code IValue.from(value)}, {@code
|
||||
* IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, or one of
|
||||
* the {@code dict} methods, depending on the key type.
|
||||
*
|
||||
* <p>Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that {@code
|
||||
* str}-type IValues must be extracted with {@link #toStr()}, rather than {@link #toString()}.
|
||||
*
|
||||
* <p>{@code IValue} objects may retain references to objects passed into their constructors, and
|
||||
* may return references to their internal state from {@code toX()}.
|
||||
*/
|
||||
@DoNotStrip
|
||||
public class IValue {
|
||||
private static final int TYPE_CODE_NULL = 1;
|
||||
|
||||
private static final int TYPE_CODE_TENSOR = 2;
|
||||
private static final int TYPE_CODE_BOOL = 3;
|
||||
private static final int TYPE_CODE_LONG = 4;
|
||||
private static final int TYPE_CODE_DOUBLE = 5;
|
||||
private static final int TYPE_CODE_STRING = 6;
|
||||
|
||||
private static final int TYPE_CODE_TUPLE = 7;
|
||||
private static final int TYPE_CODE_BOOL_LIST = 8;
|
||||
private static final int TYPE_CODE_LONG_LIST = 9;
|
||||
private static final int TYPE_CODE_DOUBLE_LIST = 10;
|
||||
private static final int TYPE_CODE_TENSOR_LIST = 11;
|
||||
private static final int TYPE_CODE_LIST = 12;
|
||||
|
||||
private static final int TYPE_CODE_DICT_STRING_KEY = 13;
|
||||
private static final int TYPE_CODE_DICT_LONG_KEY = 14;
|
||||
|
||||
private String[] TYPE_NAMES = {
|
||||
"Unknown",
|
||||
"Null",
|
||||
"Tensor",
|
||||
"Bool",
|
||||
"Long",
|
||||
"Double",
|
||||
"String",
|
||||
"Tuple",
|
||||
"BoolList",
|
||||
"LongList",
|
||||
"DoubleList",
|
||||
"TensorList",
|
||||
"GenericList",
|
||||
"DictStringKey",
|
||||
"DictLongKey",
|
||||
};
|
||||
|
||||
@DoNotStrip private final int mTypeCode;
|
||||
@DoNotStrip private Object mData;
|
||||
|
||||
@DoNotStrip
|
||||
private IValue(int typeCode) {
|
||||
this.mTypeCode = typeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isNull() {
|
||||
return TYPE_CODE_NULL == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isTensor() {
|
||||
return TYPE_CODE_TENSOR == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isBool() {
|
||||
return TYPE_CODE_BOOL == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isLong() {
|
||||
return TYPE_CODE_LONG == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isDouble() {
|
||||
return TYPE_CODE_DOUBLE == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isString() {
|
||||
return TYPE_CODE_STRING == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isTuple() {
|
||||
return TYPE_CODE_TUPLE == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isBoolList() {
|
||||
return TYPE_CODE_BOOL_LIST == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isLongList() {
|
||||
return TYPE_CODE_LONG_LIST == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isDoubleList() {
|
||||
return TYPE_CODE_DOUBLE_LIST == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isTensorList() {
|
||||
return TYPE_CODE_TENSOR_LIST == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isList() {
|
||||
return TYPE_CODE_LIST == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isDictStringKey() {
|
||||
return TYPE_CODE_DICT_STRING_KEY == this.mTypeCode;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean isDictLongKey() {
|
||||
return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code Optional} that contains no value. */
|
||||
@DoNotStrip
|
||||
public static IValue optionalNull() {
|
||||
return new IValue(TYPE_CODE_NULL);
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code Tensor}. */
|
||||
@DoNotStrip
|
||||
public static IValue from(Tensor tensor) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR);
|
||||
iv.mData = tensor;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code bool}. */
|
||||
@DoNotStrip
|
||||
public static IValue from(boolean value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code int}. */
|
||||
@DoNotStrip
|
||||
public static IValue from(long value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code float}. */
|
||||
@DoNotStrip
|
||||
public static IValue from(double value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code str}. */
|
||||
@DoNotStrip
|
||||
public static IValue from(String value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_STRING);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code List[bool]}. */
|
||||
@DoNotStrip
|
||||
public static IValue listFrom(boolean... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code List[int]}. */
|
||||
@DoNotStrip
|
||||
public static IValue listFrom(long... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code List[float]}. */
|
||||
@DoNotStrip
|
||||
public static IValue listFrom(double... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code List[Tensor]}. */
|
||||
@DoNotStrip
|
||||
public static IValue listFrom(Tensor... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type. */
|
||||
@DoNotStrip
|
||||
public static IValue listFrom(IValue... array) {
|
||||
final int size = array.length;
|
||||
if (size > 0) {
|
||||
final int typeCode0 = array[0].mTypeCode;
|
||||
for (int i = 1; i < size; i++) {
|
||||
if (typeCode0 != array[i].mTypeCode) {
|
||||
throw new IllegalArgumentException("List must contain items of the same type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final IValue iv = new IValue(TYPE_CODE_LIST);
|
||||
iv.mData = array;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. */
|
||||
@DoNotStrip
|
||||
public static IValue tupleFrom(IValue... array) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TUPLE);
|
||||
iv.mData = array;
|
||||
return iv;
|
||||
}
|
||||
|
||||
/** Creates a new {@code IValue} of type {@code Dict[str, V]}. */
|
||||
@DoNotStrip
|
||||
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
}
|
||||
/** Creates a new {@code IValue} of type {@code Dict[int, V]}. */
|
||||
@DoNotStrip
|
||||
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public Tensor toTensor() {
|
||||
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
|
||||
return (Tensor) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean toBool() {
|
||||
preconditionType(TYPE_CODE_BOOL, mTypeCode);
|
||||
return (boolean) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public long toLong() {
|
||||
preconditionType(TYPE_CODE_LONG, mTypeCode);
|
||||
return (long) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public double toDouble() {
|
||||
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
|
||||
return (double) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public String toStr() {
|
||||
preconditionType(TYPE_CODE_STRING, mTypeCode);
|
||||
return (String) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public boolean[] toBoolList() {
|
||||
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
|
||||
return (boolean[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public long[] toLongList() {
|
||||
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
|
||||
return (long[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public double[] toDoubleList() {
|
||||
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
|
||||
return (double[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public Tensor[] toTensorList() {
|
||||
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
|
||||
return (Tensor[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public IValue[] toList() {
|
||||
preconditionType(TYPE_CODE_LIST, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public IValue[] toTuple() {
|
||||
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public Map<String, IValue> toDictStringKey() {
|
||||
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
|
||||
return (Map<String, IValue>) mData;
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public Map<Long, IValue> toDictLongKey() {
|
||||
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
|
||||
return (Map<Long, IValue>) mData;
|
||||
}
|
||||
|
||||
private void preconditionType(int typeCodeExpected, int typeCode) {
|
||||
if (typeCode != typeCodeExpected) {
|
||||
throw new IllegalStateException(
|
||||
String.format(
|
||||
Locale.US,
|
||||
"Expected IValue type %s, actual type %s",
|
||||
getTypeName(typeCodeExpected),
|
||||
getTypeName(typeCode)));
|
||||
}
|
||||
}
|
||||
|
||||
private String getTypeName(int typeCode) {
|
||||
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import java.util.Map;
|
||||
|
||||
public class LiteModuleLoader {
|
||||
|
||||
/**
|
||||
* Loads a serialized TorchScript module from the specified path on the disk to run on specified
|
||||
* device. The model should be generated from this api _save_for_lite_interpreter().
|
||||
*
|
||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||
* @param extraFiles map with extra files names as keys, content of them will be loaded to values.
|
||||
* @param device {@link org.pytorch.Device} to use for running specified module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module.
|
||||
*/
|
||||
public static Module load(
|
||||
final String modelPath, final Map<String, String> extraFiles, final Device device) {
|
||||
return new Module(new LiteNativePeer(modelPath, extraFiles, device));
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a serialized TorchScript module from the specified path on the disk to run on CPU. The
|
||||
* model should be generated from this api _save_for_lite_interpreter().
|
||||
*
|
||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::mobile::Module.
|
||||
*/
|
||||
public static Module load(final String modelPath) {
|
||||
return new Module(new LiteNativePeer(modelPath, null, Device.CPU));
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention: This is not recommended way of loading production modules, as prepackaged assets
|
||||
* increase apk size etc. For production usage consider using loading from file on the disk {@link
|
||||
* org.pytorch.Module#load(String)}.
|
||||
*
|
||||
* <p>This method is meant to use in tests and demos.
|
||||
*/
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName, final Device device) {
|
||||
return new Module(new LiteNativePeer(assetName, assetManager, device));
|
||||
}
|
||||
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName) {
|
||||
return new Module(new LiteNativePeer(assetName, assetManager, Device.CPU));
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.jni.HybridData;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
import java.util.Map;
|
||||
|
||||
class LiteNativePeer implements INativePeer {
|
||||
static {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_jni_lite");
|
||||
PyTorchCodegenLoader.loadNativeLibs();
|
||||
}
|
||||
|
||||
private final HybridData mHybridData;
|
||||
|
||||
private static native HybridData initHybrid(
|
||||
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
|
||||
|
||||
private static native HybridData initHybridAndroidAsset(
|
||||
String assetName, /* android.content.res.AssetManager */
|
||||
Object androidAssetManager,
|
||||
int deviceJniCode);
|
||||
|
||||
LiteNativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
|
||||
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
|
||||
}
|
||||
|
||||
LiteNativePeer(
|
||||
String assetName, /* android.content.res.AssetManager */
|
||||
Object androidAssetManager,
|
||||
Device device) {
|
||||
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
|
||||
}
|
||||
|
||||
/**
|
||||
* Explicitly destroys the native torch::jit::mobile::Module. Calling this method is not required,
|
||||
* as the native object will be destroyed when this object is garbage-collected. However, the
|
||||
* timing of garbage collection is not guaranteed, so proactively calling {@code resetNative} can
|
||||
* free memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
||||
*/
|
||||
public void resetNative() {
|
||||
mHybridData.resetNative();
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the 'forward' method of this module with the specified arguments.
|
||||
*
|
||||
* @param inputs arguments for the TorchScript module's 'forward' method.
|
||||
* @return return value from the 'forward' method.
|
||||
*/
|
||||
public native IValue forward(IValue... inputs);
|
||||
|
||||
/**
|
||||
* Runs the specified method of this module with the specified arguments.
|
||||
*
|
||||
* @param methodName name of the TorchScript method to run.
|
||||
* @param inputs arguments that will be passed to TorchScript method.
|
||||
* @return return value from the method.
|
||||
*/
|
||||
public native IValue runMethod(String methodName, IValue... inputs);
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package org.pytorch;
|
||||
|
||||
public enum MemoryFormat {
|
||||
CONTIGUOUS(1),
|
||||
CHANNELS_LAST(2),
|
||||
CHANNELS_LAST_3D(3),
|
||||
;
|
||||
|
||||
final int jniCode;
|
||||
|
||||
MemoryFormat(int jniCode) {
|
||||
this.jniCode = jniCode;
|
||||
}
|
||||
}
|
@ -0,0 +1,75 @@
|
||||
// Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
import java.util.Map;
|
||||
|
||||
/** Java wrapper for torch::jit::Module. */
|
||||
public class Module {
|
||||
|
||||
private INativePeer mNativePeer;
|
||||
|
||||
/**
|
||||
* Loads a serialized TorchScript module from the specified path on the disk to run on specified
|
||||
* device.
|
||||
*
|
||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||
* @param extraFiles map with extra files names as keys, content of them will be loaded to values.
|
||||
* @param device {@link org.pytorch.Device} to use for running specified module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
|
||||
*/
|
||||
public static Module load(
|
||||
final String modelPath, final Map<String, String> extraFiles, final Device device) {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
return new Module(new NativePeer(modelPath, extraFiles, device));
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a serialized TorchScript module from the specified path on the disk to run on CPU.
|
||||
*
|
||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
|
||||
*/
|
||||
public static Module load(final String modelPath) {
|
||||
return load(modelPath, null, Device.CPU);
|
||||
}
|
||||
|
||||
Module(INativePeer nativePeer) {
|
||||
this.mNativePeer = nativePeer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the 'forward' method of this module with the specified arguments.
|
||||
*
|
||||
* @param inputs arguments for the TorchScript module's 'forward' method.
|
||||
* @return return value from the 'forward' method.
|
||||
*/
|
||||
public IValue forward(IValue... inputs) {
|
||||
return mNativePeer.forward(inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the specified method of this module with the specified arguments.
|
||||
*
|
||||
* @param methodName name of the TorchScript method to run.
|
||||
* @param inputs arguments that will be passed to TorchScript method.
|
||||
* @return return value from the method.
|
||||
*/
|
||||
public IValue runMethod(String methodName, IValue... inputs) {
|
||||
return mNativePeer.runMethod(methodName, inputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the
|
||||
* native object will be destroyed when this object is garbage-collected. However, the timing of
|
||||
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
|
||||
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
||||
*/
|
||||
public void destroy() {
|
||||
mNativePeer.resetNative();
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.jni.HybridData;
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import java.util.Map;
|
||||
|
||||
class NativePeer implements INativePeer {
|
||||
static {
|
||||
NativeLoader.loadLibrary("pytorch_jni");
|
||||
PyTorchCodegenLoader.loadNativeLibs();
|
||||
}
|
||||
|
||||
private final HybridData mHybridData;
|
||||
|
||||
@DoNotStrip
|
||||
private static native HybridData initHybrid(
|
||||
String moduleAbsolutePath, Map<String, String> extraFiles, int deviceJniCode);
|
||||
|
||||
@DoNotStrip
|
||||
private static native HybridData initHybridAndroidAsset(
|
||||
String assetName, /* android.content.res.AssetManager */
|
||||
Object androidAssetManager,
|
||||
int deviceJniCode);
|
||||
|
||||
NativePeer(String moduleAbsolutePath, Map<String, String> extraFiles, Device device) {
|
||||
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, device.jniCode);
|
||||
}
|
||||
|
||||
NativePeer(
|
||||
String assetName, /* android.content.res.AssetManager */
|
||||
Object androidAssetManager,
|
||||
Device device) {
|
||||
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
|
||||
}
|
||||
|
||||
public void resetNative() {
|
||||
mHybridData.resetNative();
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
public native IValue forward(IValue... inputs);
|
||||
|
||||
@DoNotStrip
|
||||
public native IValue runMethod(String methodName, IValue... inputs);
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
|
||||
public final class PyTorchAndroid {
|
||||
static {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_jni_lite");
|
||||
PyTorchCodegenLoader.loadNativeLibs();
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention: This is not recommended way of loading production modules, as prepackaged assets
|
||||
* increase apk size etc. For production usage consider using loading from file on the disk {@link
|
||||
* org.pytorch.Module#load(String)}.
|
||||
*
|
||||
* <p>This method is meant to use in tests and demos.
|
||||
*/
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName, final Device device) {
|
||||
return new Module(new NativePeer(assetName, assetManager, device));
|
||||
}
|
||||
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName) {
|
||||
return new Module(new NativePeer(assetName, assetManager, Device.CPU));
|
||||
}
|
||||
|
||||
/**
|
||||
* Globally sets the number of threads used on native side. Attention: Has global effect, all
|
||||
* modules use one thread pool with specified number of threads.
|
||||
*
|
||||
* @param numThreads number of threads, must be positive number.
|
||||
*/
|
||||
public static void setNumThreads(int numThreads) {
|
||||
if (numThreads < 1) {
|
||||
throw new IllegalArgumentException("Number of threads cannot be less than 1");
|
||||
}
|
||||
|
||||
nativeSetNumThreads(numThreads);
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
private static native void nativeSetNumThreads(int numThreads);
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
|
||||
public class PyTorchCodegenLoader {
|
||||
|
||||
public static void loadNativeLibs() {
|
||||
try {
|
||||
NativeLoader.loadLibrary("torch-code-gen");
|
||||
} catch (Throwable t) {
|
||||
// Loading the codegen lib is best-effort since it's only there for query based builds.
|
||||
}
|
||||
}
|
||||
|
||||
private PyTorchCodegenLoader() {}
|
||||
}
|
735
android/pytorch_android/src/main/java/org/pytorch/Tensor.java
Normal file
735
android/pytorch_android/src/main/java/org/pytorch/Tensor.java
Normal file
@ -0,0 +1,735 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.jni.HybridData;
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Representation of a Tensor. Behavior is similar to PyTorch's tensor objects.
|
||||
*
|
||||
* <p>Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data}
|
||||
* can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided
|
||||
* to allocate buffers properly.
|
||||
*
|
||||
* <p>To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*}
|
||||
* methods.
|
||||
*
|
||||
* <p>When constructing {@code Tensor} objects with {@code data} as an array, it is not specified
|
||||
* whether this data is copied or retained as a reference so it is recommended not to modify it
|
||||
* after constructing. {@code data} passed as a {@link Buffer} is not copied, so it can be modified
|
||||
* between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects
|
||||
* may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape}
|
||||
* is always copied.
|
||||
*/
|
||||
public abstract class Tensor {
|
||||
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
|
||||
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative";
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
|
||||
"Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
|
||||
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
|
||||
|
||||
@DoNotStrip final long[] shape;
|
||||
final MemoryFormat memoryFormat;
|
||||
|
||||
private static final int INT_SIZE_BYTES = 4;
|
||||
private static final int FLOAT_SIZE_BYTES = 4;
|
||||
private static final int LONG_SIZE_BYTES = 8;
|
||||
private static final int DOUBLE_SIZE_BYTES = 8;
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link
|
||||
* Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static ByteBuffer allocateByteBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static IntBuffer allocateIntBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asIntBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static FloatBuffer allocateFloatBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asFloatBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static LongBuffer allocateLongBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asLongBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
public static DoubleBuffer allocateDoubleBuffer(int numElements) {
|
||||
return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asDoubleBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlobUnsigned(byte[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
|
||||
byteBuffer.put(data);
|
||||
return new Tensor_uint8(byteBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
|
||||
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(byte[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
|
||||
byteBuffer.put(data);
|
||||
return new Tensor_int8(byteBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(byte[] data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
|
||||
* ints.
|
||||
*
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(int[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape));
|
||||
intBuffer.put(data);
|
||||
return new Tensor_int32(intBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(int[] data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
|
||||
* of floats.
|
||||
*
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(float[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape));
|
||||
floatBuffer.put(data);
|
||||
return new Tensor_float32(floatBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(float[] data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
|
||||
* longs.
|
||||
*
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(long[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
|
||||
longBuffer.put(data);
|
||||
return new Tensor_int64(longBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(long[] data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array
|
||||
* of doubles.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
*/
|
||||
public static Tensor fromBlob(double[] data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
|
||||
doubleBuffer.put(data);
|
||||
return new Tensor_float64(doubleBuffer, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(double[] data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_uint8(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
|
||||
return fromBlobUnsigned(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_int8(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_int32(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(IntBuffer data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_float32(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_int64(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(LongBuffer data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
|
||||
*
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor fromBlob(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_float64(data, shape, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
|
||||
return fromBlob(data, shape, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
@DoNotStrip private HybridData mHybridData;
|
||||
|
||||
private Tensor(long[] shape, MemoryFormat memoryFormat) {
|
||||
checkShape(shape);
|
||||
this.shape = Arrays.copyOf(shape, shape.length);
|
||||
this.memoryFormat = memoryFormat;
|
||||
}
|
||||
|
||||
/** Returns the number of elements in this tensor. */
|
||||
public long numel() {
|
||||
return numel(this.shape);
|
||||
}
|
||||
|
||||
/** Calculates the number of elements in a tensor with the specified shape. */
|
||||
public static long numel(long[] shape) {
|
||||
checkShape(shape);
|
||||
int result = 1;
|
||||
for (long s : shape) {
|
||||
result *= s;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Returns the shape of this tensor. (The array is a fresh copy.) */
|
||||
public long[] shape() {
|
||||
return Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/** Returns the memory format of this tensor. */
|
||||
public MemoryFormat memoryFormat() {
|
||||
return memoryFormat;
|
||||
}
|
||||
|
||||
/** @return data type of this tensor. */
|
||||
public abstract DType dtype();
|
||||
|
||||
// Called from native
|
||||
@DoNotStrip
|
||||
int dtypeJniCode() {
|
||||
return dtype().jniCode;
|
||||
}
|
||||
|
||||
// Called from native
|
||||
@DoNotStrip
|
||||
int memoryFormatJniCode() {
|
||||
return memoryFormat.jniCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-int8 tensor.
|
||||
*/
|
||||
public byte[] getDataAsByteArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-uint8 tensor.
|
||||
*/
|
||||
public byte[] getDataAsUnsignedByteArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java int array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-int32 tensor.
|
||||
*/
|
||||
public int[] getDataAsIntArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java float array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-float32 tensor.
|
||||
*/
|
||||
public float[] getDataAsFloatArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java long array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-int64 tensor.
|
||||
*/
|
||||
public long[] getDataAsLongArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as long array.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a Java double array that contains the tensor data. This may be a copy or reference.
|
||||
* @throws IllegalStateException if it is called for a non-float64 tensor.
|
||||
*/
|
||||
public double[] getDataAsDoubleArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
|
||||
}
|
||||
|
||||
@DoNotStrip
|
||||
Buffer getRawDataBuffer() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
|
||||
}
|
||||
|
||||
static class Tensor_uint8 extends Tensor {
|
||||
private final ByteBuffer data;
|
||||
|
||||
private Tensor_uint8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.UINT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getDataAsUnsignedByteArray() {
|
||||
data.rewind();
|
||||
byte[] arr = new byte[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_int8 extends Tensor {
|
||||
private final ByteBuffer data;
|
||||
|
||||
private Tensor_int8(ByteBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.INT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getDataAsByteArray() {
|
||||
data.rewind();
|
||||
byte[] arr = new byte[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_int32 extends Tensor {
|
||||
private final IntBuffer data;
|
||||
|
||||
private Tensor_int32(IntBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.INT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] getDataAsIntArray() {
|
||||
data.rewind();
|
||||
int[] arr = new int[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_float32 extends Tensor {
|
||||
private final FloatBuffer data;
|
||||
|
||||
Tensor_float32(FloatBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] getDataAsFloatArray() {
|
||||
data.rewind();
|
||||
float[] arr = new float[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_int64 extends Tensor {
|
||||
private final LongBuffer data;
|
||||
|
||||
private Tensor_int64(LongBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getDataAsLongArray() {
|
||||
data.rewind();
|
||||
long[] arr = new long[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_float64 extends Tensor {
|
||||
private final DoubleBuffer data;
|
||||
|
||||
private Tensor_float64(DoubleBuffer data, long[] shape, MemoryFormat memoryFormat) {
|
||||
super(shape, memoryFormat);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DType dtype() {
|
||||
return DType.FLOAT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] getDataAsDoubleArray() {
|
||||
data.rewind();
|
||||
double[] arr = new double[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
// region checks
|
||||
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
|
||||
if (!expression) {
|
||||
throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args));
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkShape(long[] shape) {
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE);
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
|
||||
final long numel = numel(shape);
|
||||
checkArgument(
|
||||
numel == dataCapacity,
|
||||
"Inconsistent data capacity:%d and shape number elements:%d shape:%s",
|
||||
dataCapacity,
|
||||
numel,
|
||||
Arrays.toString(shape));
|
||||
}
|
||||
// endregion checks
|
||||
|
||||
// Called from native
|
||||
@DoNotStrip
|
||||
private static Tensor nativeNewTensor(
|
||||
ByteBuffer data, long[] shape, int dtype, int memoryFormatCode, HybridData hybridData) {
|
||||
Tensor tensor = null;
|
||||
|
||||
MemoryFormat memoryFormat = MemoryFormat.CONTIGUOUS;
|
||||
if (MemoryFormat.CHANNELS_LAST.jniCode == memoryFormatCode) {
|
||||
memoryFormat = MemoryFormat.CHANNELS_LAST;
|
||||
} else if (MemoryFormat.CHANNELS_LAST_3D.jniCode == memoryFormatCode) {
|
||||
memoryFormat = MemoryFormat.CHANNELS_LAST_3D;
|
||||
}
|
||||
|
||||
if (DType.FLOAT32.jniCode == dtype) {
|
||||
tensor = new Tensor_float32(data.asFloatBuffer(), shape, memoryFormat);
|
||||
} else if (DType.INT32.jniCode == dtype) {
|
||||
tensor = new Tensor_int32(data.asIntBuffer(), shape, memoryFormat);
|
||||
} else if (DType.INT64.jniCode == dtype) {
|
||||
tensor = new Tensor_int64(data.asLongBuffer(), shape, memoryFormat);
|
||||
} else if (DType.FLOAT64.jniCode == dtype) {
|
||||
tensor = new Tensor_float64(data.asDoubleBuffer(), shape, memoryFormat);
|
||||
} else if (DType.UINT8.jniCode == dtype) {
|
||||
tensor = new Tensor_uint8(data, shape, memoryFormat);
|
||||
} else if (DType.INT8.jniCode == dtype) {
|
||||
tensor = new Tensor_int8(data, shape, memoryFormat);
|
||||
} else {
|
||||
new IllegalArgumentException("Unknown Tensor dtype");
|
||||
}
|
||||
tensor.mHybridData = hybridData;
|
||||
return tensor;
|
||||
}
|
||||
}
|
3
android/pytorch_android/src/main/res/values/strings.xml
Normal file
3
android/pytorch_android/src/main/res/values/strings.xml
Normal file
@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">pytorch_android</string>
|
||||
</resources>
|
105
android/pytorch_android/test_asset.jit
Normal file
105
android/pytorch_android/test_asset.jit
Normal file
@ -0,0 +1,105 @@
|
||||
def forward(self, input):
|
||||
return None
|
||||
|
||||
def eqBool(self, input: bool) -> bool:
|
||||
return input
|
||||
|
||||
def eqInt(self, input: int) -> int:
|
||||
return input
|
||||
|
||||
def eqFloat(self, input: float) -> float:
|
||||
return input
|
||||
|
||||
def eqStr(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
def eqTensor(self, input: Tensor) -> Tensor:
|
||||
return input
|
||||
|
||||
def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
|
||||
return input
|
||||
|
||||
def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
|
||||
return input
|
||||
|
||||
def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
|
||||
return input
|
||||
|
||||
def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
def listBoolConjunction(self, input: List[bool]) -> bool:
|
||||
res = True
|
||||
for x in input:
|
||||
res = res and x
|
||||
return res
|
||||
|
||||
def listBoolDisjunction(self, input: List[bool]) -> bool:
|
||||
res = False
|
||||
for x in input:
|
||||
res = res or x
|
||||
return res
|
||||
|
||||
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
return (input, sum)
|
||||
|
||||
def optionalIntIsNone(self, input: Optional[int]) -> bool:
|
||||
return input is None
|
||||
|
||||
def intEq0None(self, input: int) -> Optional[int]:
|
||||
if input == 0:
|
||||
return None
|
||||
return input
|
||||
|
||||
def str3Concat(self, input: str) -> str:
|
||||
return input + input + input
|
||||
|
||||
def newEmptyShapeWithItem(self, input):
|
||||
return torch.tensor([int(input.item())])[0]
|
||||
|
||||
def testAliasWithOffset(self) -> List[Tensor]:
|
||||
x = torch.tensor([100, 200])
|
||||
a = [x[0], x[1]]
|
||||
return a
|
||||
|
||||
def testNonContiguous(self):
|
||||
x = torch.tensor([100, 200, 300])[::2]
|
||||
assert not x.is_contiguous()
|
||||
assert x[0] == 100
|
||||
assert x[1] == 300
|
||||
return x
|
||||
|
||||
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.conv2d(x, w)
|
||||
if (toChannelsLast):
|
||||
# memory_format=torch.channels_last
|
||||
r = r.contiguous(memory_format=2)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.conv3d(x, w)
|
||||
if (toChannelsLast):
|
||||
# memory_format=torch.channels_last_3d
|
||||
r = r.contiguous(memory_format=2)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
def contiguous(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous()
|
||||
|
||||
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
|
||||
# memory_format=torch.channels_last
|
||||
return x.contiguous(memory_format=2)
|
||||
|
||||
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
|
||||
# memory_format=torch.channels_last_3d
|
||||
return x.contiguous(memory_format=3)
|
Reference in New Issue
Block a user