mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)"
This reverts commit b22f01fcb9d69bb7d77e08d69004c7265ef7fa4a. Reverted https://github.com/pytorch/pytorch/pull/153633 on behalf of https://github.com/malfet due to But libtorch build regressions are real, fbjni is still used for C++ builds ([comment](https://github.com/pytorch/pytorch/pull/153633#issuecomment-2884951805))
This commit is contained in:
22
android/pytorch_android_torchvision/CMakeLists.txt
Normal file
22
android/pytorch_android_torchvision/CMakeLists.txt
Normal file
@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
project(pytorch_vision_jni CXX)
|
||||
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
|
||||
set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
||||
|
||||
file(GLOB pytorch_vision_SOURCES
|
||||
${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp
|
||||
)
|
||||
|
||||
add_library(pytorch_vision_jni SHARED
|
||||
${pytorch_vision_SOURCES}
|
||||
)
|
||||
|
||||
target_compile_options(pytorch_vision_jni PRIVATE
|
||||
-fexceptions
|
||||
)
|
||||
|
||||
set(BUILD_SUBDIR ${ANDROID_ABI})
|
||||
|
||||
target_link_libraries(pytorch_vision_jni)
|
64
android/pytorch_android_torchvision/build.gradle
Normal file
64
android/pytorch_android_torchvision/build.gradle
Normal file
@ -0,0 +1,64 @@
|
||||
apply plugin: 'com.android.library'
|
||||
apply plugin: 'maven'
|
||||
|
||||
android {
|
||||
compileSdkVersion rootProject.compileSdkVersion
|
||||
buildToolsVersion rootProject.buildToolsVersion
|
||||
|
||||
|
||||
defaultConfig {
|
||||
minSdkVersion rootProject.minSdkVersion
|
||||
targetSdkVersion rootProject.targetSdkVersion
|
||||
versionCode 0
|
||||
versionName "0.1"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
ndk {
|
||||
abiFilters ABI_FILTERS.split(",")
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
debug {
|
||||
minifyEnabled false
|
||||
debuggable true
|
||||
}
|
||||
release {
|
||||
minifyEnabled false
|
||||
}
|
||||
}
|
||||
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
|
||||
useLibrary 'android.test.runner'
|
||||
useLibrary 'android.test.base'
|
||||
useLibrary 'android.test.mock'
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation project(':pytorch_android')
|
||||
|
||||
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
|
||||
|
||||
testImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
testImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
|
||||
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
|
||||
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
|
||||
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
|
||||
}
|
||||
|
||||
apply from: rootProject.file('gradle/release.gradle')
|
||||
|
||||
task sourcesJar(type: Jar) {
|
||||
from android.sourceSets.main.java.srcDirs
|
||||
classifier = 'sources'
|
||||
}
|
||||
|
||||
artifacts.add('archives', sourcesJar)
|
4
android/pytorch_android_torchvision/gradle.properties
Normal file
4
android/pytorch_android_torchvision/gradle.properties
Normal file
@ -0,0 +1,4 @@
|
||||
POM_NAME=pytorch_android_torchvision_lite
|
||||
POM_DESCRIPTION=pytorch_android_torchvision_lite
|
||||
POM_ARTIFACT_ID=pytorch_android_torchvision_lite
|
||||
POM_PACKAGING=aar
|
@ -0,0 +1,24 @@
|
||||
package org.pytorch.torchvision;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class TorchVisionInstrumentedTests {
|
||||
|
||||
@Test
|
||||
public void smokeTest() {
|
||||
Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888);
|
||||
Tensor tensor =
|
||||
TensorImageUtils.bitmapToFloat32Tensor(
|
||||
bitmap,
|
||||
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
|
||||
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
|
||||
assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape());
|
||||
}
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
package org.pytorch.torchvision.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.torchvision.TorchVisionInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({TorchVisionInstrumentedTests.class})
|
||||
public class TorchVisionInstrumentedTestSuite {}
|
@ -0,0 +1 @@
|
||||
<manifest package="org.pytorch.torchvision" />
|
@ -0,0 +1,187 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "jni.h"
|
||||
|
||||
namespace pytorch_vision_jni {
|
||||
|
||||
static void imageYUV420CenterCropToFloatBuffer(
|
||||
JNIEnv* jniEnv,
|
||||
jclass,
|
||||
jobject yBuffer,
|
||||
jint yRowStride,
|
||||
jint yPixelStride,
|
||||
jobject uBuffer,
|
||||
jobject vBuffer,
|
||||
jint uRowStride,
|
||||
jint uvPixelStride,
|
||||
jint imageWidth,
|
||||
jint imageHeight,
|
||||
jint rotateCWDegrees,
|
||||
jint tensorWidth,
|
||||
jint tensorHeight,
|
||||
jfloatArray jnormMeanRGB,
|
||||
jfloatArray jnormStdRGB,
|
||||
jobject outBuffer,
|
||||
jint outOffset,
|
||||
jint memoryFormatCode) {
|
||||
constexpr static int32_t kMemoryFormatContiguous = 1;
|
||||
constexpr static int32_t kMemoryFormatChannelsLast = 2;
|
||||
|
||||
float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
|
||||
|
||||
jfloat normMeanRGB[3];
|
||||
jfloat normStdRGB[3];
|
||||
jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB);
|
||||
jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB);
|
||||
int widthAfterRtn = imageWidth;
|
||||
int heightAfterRtn = imageHeight;
|
||||
bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270;
|
||||
if (oddRotation) {
|
||||
widthAfterRtn = imageHeight;
|
||||
heightAfterRtn = imageWidth;
|
||||
}
|
||||
|
||||
int cropWidthAfterRtn = widthAfterRtn;
|
||||
int cropHeightAfterRtn = heightAfterRtn;
|
||||
|
||||
if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) {
|
||||
cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight;
|
||||
} else {
|
||||
cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth;
|
||||
}
|
||||
|
||||
int cropWidthBeforeRtn = cropWidthAfterRtn;
|
||||
int cropHeightBeforeRtn = cropHeightAfterRtn;
|
||||
if (oddRotation) {
|
||||
cropWidthBeforeRtn = cropHeightAfterRtn;
|
||||
cropHeightBeforeRtn = cropWidthAfterRtn;
|
||||
}
|
||||
|
||||
const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f;
|
||||
const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f;
|
||||
|
||||
const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer);
|
||||
const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer);
|
||||
const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);
|
||||
|
||||
float scale = cropWidthAfterRtn / tensorWidth;
|
||||
int uvRowStride = uRowStride;
|
||||
int cropXMult = 1;
|
||||
int cropYMult = 1;
|
||||
int cropXAdd = offsetX;
|
||||
int cropYAdd = offsetY;
|
||||
if (rotateCWDegrees == 90) {
|
||||
cropYMult = -1;
|
||||
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
|
||||
} else if (rotateCWDegrees == 180) {
|
||||
cropXMult = -1;
|
||||
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
|
||||
cropYMult = -1;
|
||||
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
|
||||
} else if (rotateCWDegrees == 270) {
|
||||
cropXMult = -1;
|
||||
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
|
||||
}
|
||||
|
||||
float normMeanRm255 = 255 * normMeanRGB[0];
|
||||
float normMeanGm255 = 255 * normMeanRGB[1];
|
||||
float normMeanBm255 = 255 * normMeanRGB[2];
|
||||
float normStdRm255 = 255 * normStdRGB[0];
|
||||
float normStdGm255 = 255 * normStdRGB[1];
|
||||
float normStdBm255 = 255 * normStdRGB[2];
|
||||
|
||||
int xBeforeRtn, yBeforeRtn;
|
||||
int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
|
||||
int channelSize = tensorWidth * tensorHeight;
|
||||
// A bit of code duplication to avoid branching in the cycles
|
||||
if (memoryFormatCode == kMemoryFormatContiguous) {
|
||||
int wr = outOffset;
|
||||
int wg = wr + channelSize;
|
||||
int wb = wg + channelSize;
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
|
||||
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
|
||||
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
|
||||
uvIdx =
|
||||
(yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
|
||||
ui = uData[uvIdx];
|
||||
vi = vData[uvIdx];
|
||||
yi = yData[yIdx];
|
||||
yi = (yi - 16) < 0 ? 0 : (yi - 16);
|
||||
ui -= 128;
|
||||
vi -= 128;
|
||||
a0 = 1192 * yi;
|
||||
ri = (a0 + 1634 * vi) >> 10;
|
||||
gi = (a0 - 833 * vi - 400 * ui) >> 10;
|
||||
bi = (a0 + 2066 * ui) >> 10;
|
||||
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
|
||||
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
|
||||
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
|
||||
outData[wr++] = (ri - normMeanRm255) / normStdRm255;
|
||||
outData[wg++] = (gi - normMeanGm255) / normStdGm255;
|
||||
outData[wb++] = (bi - normMeanBm255) / normStdBm255;
|
||||
}
|
||||
}
|
||||
} else if (memoryFormatCode == kMemoryFormatChannelsLast) {
|
||||
int wc = outOffset;
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
|
||||
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
|
||||
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
|
||||
uvIdx =
|
||||
(yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
|
||||
ui = uData[uvIdx];
|
||||
vi = vData[uvIdx];
|
||||
yi = yData[yIdx];
|
||||
yi = (yi - 16) < 0 ? 0 : (yi - 16);
|
||||
ui -= 128;
|
||||
vi -= 128;
|
||||
a0 = 1192 * yi;
|
||||
ri = (a0 + 1634 * vi) >> 10;
|
||||
gi = (a0 - 833 * vi - 400 * ui) >> 10;
|
||||
bi = (a0 + 2066 * ui) >> 10;
|
||||
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
|
||||
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
|
||||
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
|
||||
outData[wc++] = (ri - normMeanRm255) / normStdRm255;
|
||||
outData[wc++] = (gi - normMeanGm255) / normStdGm255;
|
||||
outData[wc++] = (bi - normMeanBm255) / normStdBm255;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException");
|
||||
jniEnv->ThrowNew(Exception, "Illegal memory format code");
|
||||
}
|
||||
}
|
||||
} // namespace pytorch_vision_jni
|
||||
|
||||
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
|
||||
JNIEnv* env;
|
||||
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
jclass c =
|
||||
env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
|
||||
if (c == nullptr) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
static const JNINativeMethod methods[] = {
|
||||
{"imageYUV420CenterCropToFloatBuffer",
|
||||
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;II)V",
|
||||
(void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
|
||||
};
|
||||
int rc = env->RegisterNatives(
|
||||
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
|
||||
|
||||
if (rc != JNI_OK) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
@ -0,0 +1,398 @@
|
||||
package org.pytorch.torchvision;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.ImageFormat;
|
||||
import android.media.Image;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Locale;
|
||||
import org.pytorch.MemoryFormat;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
/**
|
||||
* Contains utility functions for {@link org.pytorch.Tensor} creation from {@link
|
||||
* android.graphics.Bitmap} or {@link android.media.Image} source.
|
||||
*/
|
||||
public final class TensorImageUtils {
|
||||
|
||||
public static float[] TORCHVISION_NORM_MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f};
|
||||
public static float[] TORCHVISION_NORM_STD_RGB = new float[] {0.229f, 0.224f, 0.225f};
|
||||
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from full {@link android.graphics.Bitmap}, normalized
|
||||
* with specified in parameters mean and std.
|
||||
*
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap,
|
||||
final float[] normMeanRGB,
|
||||
final float normStdRGB[],
|
||||
final MemoryFormat memoryFormat) {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
return bitmapToFloat32Tensor(
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
|
||||
return bitmapToFloat32Tensor(
|
||||
bitmap,
|
||||
0,
|
||||
0,
|
||||
bitmap.getWidth(),
|
||||
bitmap.getHeight(),
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.graphics.Bitmap}, normalized with specified
|
||||
* in parameters mean and std to specified {@link java.nio.FloatBuffer} with specified offset.
|
||||
*
|
||||
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
|
||||
* @param x - x coordinate of top left corner of bitmap's area
|
||||
* @param y - y coordinate of top left corner of bitmap's area
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static void bitmapToFloatBuffer(
|
||||
final Bitmap bitmap,
|
||||
final int x,
|
||||
final int y,
|
||||
final int width,
|
||||
final int height,
|
||||
final float[] normMeanRGB,
|
||||
final float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset,
|
||||
final MemoryFormat memoryFormat) {
|
||||
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
if (memoryFormat != MemoryFormat.CONTIGUOUS && memoryFormat != MemoryFormat.CHANNELS_LAST) {
|
||||
throw new IllegalArgumentException("Unsupported memory format " + memoryFormat);
|
||||
}
|
||||
|
||||
final int pixelsCount = height * width;
|
||||
final int[] pixels = new int[pixelsCount];
|
||||
bitmap.getPixels(pixels, 0, width, x, y, width, height);
|
||||
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
|
||||
final int offset_g = pixelsCount;
|
||||
final int offset_b = 2 * pixelsCount;
|
||||
for (int i = 0; i < pixelsCount; i++) {
|
||||
final int c = pixels[i];
|
||||
float r = ((c >> 16) & 0xff) / 255.0f;
|
||||
float g = ((c >> 8) & 0xff) / 255.0f;
|
||||
float b = ((c) & 0xff) / 255.0f;
|
||||
outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]);
|
||||
outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]);
|
||||
outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < pixelsCount; i++) {
|
||||
final int c = pixels[i];
|
||||
float r = ((c >> 16) & 0xff) / 255.0f;
|
||||
float g = ((c >> 8) & 0xff) / 255.0f;
|
||||
float b = ((c) & 0xff) / 255.0f;
|
||||
outBuffer.put(outBufferOffset + 3 * i + 0, (r - normMeanRGB[0]) / normStdRGB[0]);
|
||||
outBuffer.put(outBufferOffset + 3 * i + 1, (g - normMeanRGB[1]) / normStdRGB[1]);
|
||||
outBuffer.put(outBufferOffset + 3 * i + 2, (b - normMeanRGB[2]) / normStdRGB[2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void bitmapToFloatBuffer(
|
||||
final Bitmap bitmap,
|
||||
final int x,
|
||||
final int y,
|
||||
final int width,
|
||||
final int height,
|
||||
final float[] normMeanRGB,
|
||||
final float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset) {
|
||||
bitmapToFloatBuffer(
|
||||
bitmap,
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
outBuffer,
|
||||
outBufferOffset,
|
||||
MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
|
||||
* normalized with specified in parameters mean and std.
|
||||
*
|
||||
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
|
||||
* @param x - x coordinate of top left corner of bitmap's area
|
||||
* @param y - y coordinate of top left corner of bitmap's area
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB,
|
||||
MemoryFormat memoryFormat) {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
|
||||
bitmapToFloatBuffer(
|
||||
bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0, memoryFormat);
|
||||
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width}, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB) {
|
||||
return bitmapToFloat32Tensor(
|
||||
bitmap, x, y, width, height, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.media.Image},
|
||||
* doing optional rotation, scaling (nearest) and center cropping.
|
||||
*
|
||||
* @param image {@link android.media.Image} as a source for Tensor data
|
||||
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
|
||||
* upright. Range of valid values: 0, 90, 180, 270
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor imageYUV420CenterCropToFloat32Tensor(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB,
|
||||
MemoryFormat memoryFormat) {
|
||||
if (image.getFormat() != ImageFormat.YUV_420_888) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
|
||||
}
|
||||
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
|
||||
imageYUV420CenterCropToFloatBuffer(
|
||||
image,
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
floatBuffer,
|
||||
0,
|
||||
memoryFormat);
|
||||
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth}, memoryFormat);
|
||||
}
|
||||
|
||||
public static Tensor imageYUV420CenterCropToFloat32Tensor(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB) {
|
||||
return imageYUV420CenterCropToFloat32Tensor(
|
||||
image,
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.media.Image}, doing optional rotation,
|
||||
* scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified
|
||||
* offset.
|
||||
*
|
||||
* @param image {@link android.media.Image} as a source for Tensor data
|
||||
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
|
||||
* upright. Range of valid values: 0, 90, 180, 270
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
* @param outBuffer Output buffer, where tensor content will be written
|
||||
* @param outBufferOffset Output buffer offset with which tensor content will be written
|
||||
*/
|
||||
public static void imageYUV420CenterCropToFloatBuffer(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset,
|
||||
final MemoryFormat memoryFormat) {
|
||||
checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
|
||||
|
||||
if (image.getFormat() != ImageFormat.YUV_420_888) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
|
||||
}
|
||||
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
Image.Plane[] planes = image.getPlanes();
|
||||
Image.Plane Y = planes[0];
|
||||
Image.Plane U = planes[1];
|
||||
Image.Plane V = planes[2];
|
||||
|
||||
int memoryFormatJniCode = 0;
|
||||
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
|
||||
memoryFormatJniCode = 1;
|
||||
} else if (MemoryFormat.CHANNELS_LAST == memoryFormat) {
|
||||
memoryFormatJniCode = 2;
|
||||
}
|
||||
|
||||
NativePeer.imageYUV420CenterCropToFloatBuffer(
|
||||
Y.getBuffer(),
|
||||
Y.getRowStride(),
|
||||
Y.getPixelStride(),
|
||||
U.getBuffer(),
|
||||
V.getBuffer(),
|
||||
U.getRowStride(),
|
||||
U.getPixelStride(),
|
||||
image.getWidth(),
|
||||
image.getHeight(),
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
outBuffer,
|
||||
outBufferOffset,
|
||||
memoryFormatJniCode);
|
||||
}
|
||||
|
||||
public static void imageYUV420CenterCropToFloatBuffer(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB,
|
||||
final FloatBuffer outBuffer,
|
||||
final int outBufferOffset) {
|
||||
imageYUV420CenterCropToFloatBuffer(
|
||||
image,
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
outBuffer,
|
||||
outBufferOffset,
|
||||
MemoryFormat.CONTIGUOUS);
|
||||
}
|
||||
|
||||
private static class NativePeer {
|
||||
static {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_vision_jni");
|
||||
}
|
||||
|
||||
private static native void imageYUV420CenterCropToFloatBuffer(
|
||||
ByteBuffer yBuffer,
|
||||
int yRowStride,
|
||||
int yPixelStride,
|
||||
ByteBuffer uBuffer,
|
||||
ByteBuffer vBuffer,
|
||||
int uvRowStride,
|
||||
int uvPixelStride,
|
||||
int imageWidth,
|
||||
int imageHeight,
|
||||
int rotateCWDegrees,
|
||||
int tensorWidth,
|
||||
int tensorHeight,
|
||||
float[] normMeanRgb,
|
||||
float[] normStdRgb,
|
||||
Buffer outBuffer,
|
||||
int outBufferOffset,
|
||||
int memoryFormatJniCode);
|
||||
}
|
||||
|
||||
private static void checkOutBufferCapacity(
|
||||
FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
|
||||
if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
|
||||
throw new IllegalStateException("Buffer underflow");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
|
||||
if (tensorHeight <= 0 || tensorWidth <= 0) {
|
||||
throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkRotateCWDegrees(int rotateCWDegrees) {
|
||||
if (rotateCWDegrees != 0
|
||||
&& rotateCWDegrees != 90
|
||||
&& rotateCWDegrees != 180
|
||||
&& rotateCWDegrees != 270) {
|
||||
throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkNormStdArg(float[] normStdRGB) {
|
||||
if (normStdRGB.length != 3) {
|
||||
throw new IllegalArgumentException("normStdRGB length must be 3");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkNormMeanArg(float[] normMeanRGB) {
|
||||
if (normMeanRGB.length != 3) {
|
||||
throw new IllegalArgumentException("normMeanRGB length must be 3");
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">pytorch_android_torchvision</string>
|
||||
</resources>
|
Reference in New Issue
Block a user