mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Tensor prep from image in native (#31426)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31426 Tensor convertion from YUV image is moved to native with optimizations to eliminate branching inside loop, no variables declaration, less ops. Perf stat from local devices - measuring converting 320x240 image from camera to 1,3,224,224 tensor; Legend: Java - current java impl JavaOpt - current java impl + the same optimizations with no if/else in for, declare variables outside of for, inlining etc. C - C impl ``` Nexus 5 JavaOpt N:25 avg:119.24 min: 87 max:177 p10:102 p25:105 p50:115 p75:127 p90:150 C N:25 avg: 17.24 min: 14 max: 39 p10: 14 p25: 15 p50: 15 p75: 16 p90: 23 Java N:25 avg:139.96 min: 70 max:214 p10: 89 p25:110 p50:139 p75:173 p90:181 avg C vs JavaOpt 6.91x Pixel 3 XL JavaOpt N:19 avg: 16.11 min: 12 max: 19 p10: 14 p25: 15 p50: 16 p75: 18 p90: 19 C N:19 avg: 5.79 min: 3 max: 10 p10: 4 p25: 5 p50: 6 p75: 6 p90: 9 Java N:19 avg: 16.21 min: 12 max: 20 p10: 14 p25: 15 p50: 16 p75: 18 p90: 20 avg C vs JavaOpt 2.78x Full build with 4 abis inside: Pixel 3 XL JavaOpt N:25 avg: 18.84 min: 16 max: 24 p10: 16 p25: 17 p50: 18 p75: 20 p90: 22 C N:25 avg: 7.96 min: 5 max: 10 p10: 7 p25: 7 p50: 8 p75: 9 p90: 9 avg C vs JavaOpt 2.36x ``` Test Plan: Imported from OSS Differential Revision: D19165429 Pulled By: IvanKobzarev fbshipit-source-id: 3b54e545f6fbecbc5bb43216aca81061e70bd369
This commit is contained in:
committed by
Facebook Github Bot
parent
de5821d291
commit
104b2c610b
@ -11,6 +11,10 @@ allprojects {
|
||||
runnerVersion = "1.2.0"
|
||||
rulesVersion = "1.2.0"
|
||||
junitVersion = "4.12"
|
||||
|
||||
androidSupportAppCompatV7Version = "28.0.0"
|
||||
fbjniJavaOnlyVersion = "0.0.3"
|
||||
soLoaderNativeLoaderVersion = "0.8.0"
|
||||
}
|
||||
|
||||
repositories {
|
||||
|
@ -58,9 +58,9 @@ android {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
|
||||
implementation 'com.android.support:appcompat-v7:28.0.0'
|
||||
implementation 'com.facebook.soloader:nativeloader:0.8.0'
|
||||
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
|
||||
implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version
|
||||
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
|
||||
|
||||
testImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
testImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
|
22
android/pytorch_android_torchvision/CMakeLists.txt
Normal file
22
android/pytorch_android_torchvision/CMakeLists.txt
Normal file
@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.4.1)
|
||||
project(pytorch_vision_jni CXX)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
|
||||
set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
||||
|
||||
file(GLOB pytorch_vision_SOURCES
|
||||
${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp
|
||||
)
|
||||
|
||||
add_library(pytorch_vision_jni SHARED
|
||||
${pytorch_vision_SOURCES}
|
||||
)
|
||||
|
||||
target_compile_options(pytorch_vision_jni PRIVATE
|
||||
-fexceptions
|
||||
)
|
||||
|
||||
set(BUILD_SUBDIR ${ANDROID_ABI})
|
||||
|
||||
target_link_libraries(pytorch_vision_jni)
|
@ -13,7 +13,9 @@ android {
|
||||
versionName "0.1"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
|
||||
ndk {
|
||||
abiFilters ABI_FILTERS.split(",")
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
@ -26,6 +28,12 @@ android {
|
||||
}
|
||||
}
|
||||
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
|
||||
useLibrary 'android.test.runner'
|
||||
useLibrary 'android.test.base'
|
||||
useLibrary 'android.test.mock'
|
||||
@ -34,7 +42,8 @@ android {
|
||||
dependencies {
|
||||
implementation project(':pytorch_android')
|
||||
|
||||
implementation 'com.android.support:appcompat-v7:28.0.0'
|
||||
implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version
|
||||
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
|
||||
|
||||
testImplementation 'junit:junit:' + rootProject.junitVersion
|
||||
testImplementation 'androidx.test:core:' + rootProject.coreVersion
|
||||
|
@ -0,0 +1,144 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "jni.h"
|
||||
|
||||
#define clamp0255(x) x > 255 ? 255 : x < 0 ? 0 : x
|
||||
|
||||
namespace pytorch_vision_jni {
|
||||
|
||||
static void imageYUV420CenterCropToFloatBuffer(
|
||||
JNIEnv* jniEnv,
|
||||
jclass,
|
||||
jobject yBuffer,
|
||||
jint yRowStride,
|
||||
jint yPixelStride,
|
||||
jobject uBuffer,
|
||||
jobject vBuffer,
|
||||
jint uRowStride,
|
||||
jint uvPixelStride,
|
||||
jint imageWidth,
|
||||
jint imageHeight,
|
||||
jint rotateCWDegrees,
|
||||
jint tensorWidth,
|
||||
jint tensorHeight,
|
||||
jfloatArray jnormMeanRGB,
|
||||
jfloatArray jnormStdRGB,
|
||||
jobject outBuffer,
|
||||
jint outOffset) {
|
||||
float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
|
||||
|
||||
jfloat normMeanRGB[3];
|
||||
jfloat normStdRGB[3];
|
||||
jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB);
|
||||
jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB);
|
||||
int widthAfterRtn = imageWidth;
|
||||
int heightAfterRtn = imageHeight;
|
||||
bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270;
|
||||
if (oddRotation) {
|
||||
widthAfterRtn = imageHeight;
|
||||
heightAfterRtn = imageWidth;
|
||||
}
|
||||
|
||||
int cropWidthAfterRtn = widthAfterRtn;
|
||||
int cropHeightAfterRtn = heightAfterRtn;
|
||||
|
||||
if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) {
|
||||
cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight;
|
||||
} else {
|
||||
cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth;
|
||||
}
|
||||
|
||||
int cropWidthBeforeRtn = cropWidthAfterRtn;
|
||||
int cropHeightBeforeRtn = cropHeightAfterRtn;
|
||||
if (oddRotation) {
|
||||
cropWidthBeforeRtn = cropHeightAfterRtn;
|
||||
cropHeightBeforeRtn = cropWidthAfterRtn;
|
||||
}
|
||||
|
||||
const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f;
|
||||
const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f;
|
||||
|
||||
const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer);
|
||||
const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer);
|
||||
const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);
|
||||
|
||||
float scale = cropWidthAfterRtn / tensorWidth;
|
||||
int uvRowStride = uRowStride >> 1;
|
||||
int cropXMult = 1;
|
||||
int cropYMult = 1;
|
||||
int cropXAdd = offsetX;
|
||||
int cropYAdd = offsetY;
|
||||
if (rotateCWDegrees == 90) {
|
||||
cropYMult = -1;
|
||||
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
|
||||
} else if (rotateCWDegrees == 180) {
|
||||
cropXMult = -1;
|
||||
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
|
||||
cropYMult = -1;
|
||||
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
|
||||
} else if (rotateCWDegrees == 270) {
|
||||
cropXMult = -1;
|
||||
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
|
||||
}
|
||||
|
||||
float normMeanRm255 = 255 * normMeanRGB[0];
|
||||
float normMeanGm255 = 255 * normMeanRGB[1];
|
||||
float normMeanBm255 = 255 * normMeanRGB[2];
|
||||
float normStdRm255 = 255 * normStdRGB[0];
|
||||
float normStdGm255 = 255 * normStdRGB[1];
|
||||
float normStdBm255 = 255 * normStdRGB[2];
|
||||
|
||||
int xBeforeRtn, yBeforeRtn;
|
||||
int yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
|
||||
int channelSize = tensorWidth * tensorHeight;
|
||||
int wr = outOffset;
|
||||
int wg = wr + channelSize;
|
||||
int wb = wg + channelSize;
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
|
||||
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
|
||||
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
|
||||
uvIdx = (yBeforeRtn >> 1) * uvRowStride + xBeforeRtn * uvPixelStride;
|
||||
ui = uData[uvIdx];
|
||||
vi = vData[uvIdx];
|
||||
a0 = 1192 * (yData[yIdx] - 16);
|
||||
ri = (a0 + 1634 * (vi - 128)) >> 10;
|
||||
gi = (a0 - 832 * (vi - 128) - 400 * (ui - 128)) >> 10;
|
||||
bi = (a0 + 2066 * (ui - 128)) >> 10;
|
||||
outData[wr++] = (clamp0255(ri) - normMeanRm255) / normStdRm255;
|
||||
outData[wg++] = (clamp0255(gi) - normMeanGm255) / normStdGm255;
|
||||
outData[wb++] = (clamp0255(bi) - normMeanBm255) / normStdBm255;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace pytorch_vision_jni
|
||||
|
||||
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
|
||||
JNIEnv* env;
|
||||
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
jclass c =
|
||||
env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
|
||||
if (c == nullptr) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
static const JNINativeMethod methods[] = {
|
||||
{"imageYUV420CenterCropToFloatBuffer",
|
||||
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;I)V",
|
||||
(void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
|
||||
};
|
||||
int rc = env->RegisterNatives(
|
||||
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
|
||||
|
||||
if (rc != JNI_OK) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
@ -4,8 +4,12 @@ import android.graphics.Bitmap;
|
||||
import android.graphics.ImageFormat;
|
||||
import android.media.Image;
|
||||
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Locale;
|
||||
@ -185,108 +189,57 @@ public final class TensorImageUtils {
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
final int widthBeforeRotation = image.getWidth();
|
||||
final int heightBeforeRotation = image.getHeight();
|
||||
Image.Plane[] planes = image.getPlanes();
|
||||
Image.Plane Y = planes[0];
|
||||
Image.Plane U = planes[1];
|
||||
Image.Plane V = planes[2];
|
||||
|
||||
int widthAfterRotation = widthBeforeRotation;
|
||||
int heightAfterRotation = heightBeforeRotation;
|
||||
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
|
||||
widthAfterRotation = heightBeforeRotation;
|
||||
heightAfterRotation = widthBeforeRotation;
|
||||
}
|
||||
NativePeer.imageYUV420CenterCropToFloatBuffer(
|
||||
Y.getBuffer(),
|
||||
Y.getRowStride(),
|
||||
Y.getPixelStride(),
|
||||
U.getBuffer(),
|
||||
V.getBuffer(),
|
||||
U.getRowStride(),
|
||||
U.getPixelStride(),
|
||||
image.getWidth(),
|
||||
image.getHeight(),
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
outBuffer,
|
||||
outBufferOffset
|
||||
);
|
||||
}
|
||||
|
||||
int centerCropWidthAfterRotation = widthAfterRotation;
|
||||
int centerCropHeightAfterRotation = heightAfterRotation;
|
||||
|
||||
if (tensorWidth * heightAfterRotation <= tensorHeight * widthAfterRotation) {
|
||||
centerCropWidthAfterRotation =
|
||||
(int) Math.floor((float) tensorWidth * heightAfterRotation / tensorHeight);
|
||||
} else {
|
||||
centerCropHeightAfterRotation =
|
||||
(int) Math.floor((float) tensorHeight * widthAfterRotation / tensorWidth);
|
||||
}
|
||||
|
||||
int centerCropWidthBeforeRotation = centerCropWidthAfterRotation;
|
||||
int centerCropHeightBeforeRotation = centerCropHeightAfterRotation;
|
||||
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
|
||||
centerCropHeightBeforeRotation = centerCropWidthAfterRotation;
|
||||
centerCropWidthBeforeRotation = centerCropHeightAfterRotation;
|
||||
}
|
||||
|
||||
final int offsetX =
|
||||
(int) Math.floor((widthBeforeRotation - centerCropWidthBeforeRotation) / 2.f);
|
||||
final int offsetY =
|
||||
(int) Math.floor((heightBeforeRotation - centerCropHeightBeforeRotation) / 2.f);
|
||||
|
||||
final Image.Plane yPlane = image.getPlanes()[0];
|
||||
final Image.Plane uPlane = image.getPlanes()[1];
|
||||
final Image.Plane vPlane = image.getPlanes()[2];
|
||||
|
||||
final ByteBuffer yBuffer = yPlane.getBuffer();
|
||||
final ByteBuffer uBuffer = uPlane.getBuffer();
|
||||
final ByteBuffer vBuffer = vPlane.getBuffer();
|
||||
|
||||
final int yRowStride = yPlane.getRowStride();
|
||||
final int uRowStride = uPlane.getRowStride();
|
||||
|
||||
final int yPixelStride = yPlane.getPixelStride();
|
||||
final int uPixelStride = uPlane.getPixelStride();
|
||||
|
||||
final float scale = (float) centerCropWidthAfterRotation / tensorWidth;
|
||||
final int uvRowStride = uRowStride >> 1;
|
||||
|
||||
final int channelSize = tensorHeight * tensorWidth;
|
||||
final int tensorInputOffsetG = channelSize;
|
||||
final int tensorInputOffsetB = 2 * channelSize;
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
|
||||
final int centerCropXAfterRotation = (int) Math.floor(x * scale);
|
||||
final int centerCropYAfterRotation = (int) Math.floor(y * scale);
|
||||
|
||||
int xBeforeRotation = offsetX + centerCropXAfterRotation;
|
||||
int yBeforeRotation = offsetY + centerCropYAfterRotation;
|
||||
if (rotateCWDegrees == 90) {
|
||||
xBeforeRotation = offsetX + centerCropYAfterRotation;
|
||||
yBeforeRotation =
|
||||
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropXAfterRotation;
|
||||
} else if (rotateCWDegrees == 180) {
|
||||
xBeforeRotation =
|
||||
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropXAfterRotation;
|
||||
yBeforeRotation =
|
||||
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropYAfterRotation;
|
||||
} else if (rotateCWDegrees == 270) {
|
||||
xBeforeRotation =
|
||||
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropYAfterRotation;
|
||||
yBeforeRotation = offsetY + centerCropXAfterRotation;
|
||||
}
|
||||
|
||||
final int yIdx = yBeforeRotation * yRowStride + xBeforeRotation * yPixelStride;
|
||||
final int uvIdx = (yBeforeRotation >> 1) * uvRowStride + xBeforeRotation * uPixelStride;
|
||||
|
||||
int Yi = yBuffer.get(yIdx) & 0xff;
|
||||
int Ui = uBuffer.get(uvIdx) & 0xff;
|
||||
int Vi = vBuffer.get(uvIdx) & 0xff;
|
||||
|
||||
int a0 = 1192 * (Yi - 16);
|
||||
int a1 = 1634 * (Vi - 128);
|
||||
int a2 = 832 * (Vi - 128);
|
||||
int a3 = 400 * (Ui - 128);
|
||||
int a4 = 2066 * (Ui - 128);
|
||||
|
||||
int r = clamp((a0 + a1) >> 10, 0, 255);
|
||||
int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
|
||||
int b = clamp((a0 + a4) >> 10, 0, 255);
|
||||
final int offset = outBufferOffset + y * tensorWidth + x;
|
||||
float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
|
||||
float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
|
||||
float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
|
||||
|
||||
outBuffer.put(offset, rF);
|
||||
outBuffer.put(offset + tensorInputOffsetG, gF);
|
||||
outBuffer.put(offset + tensorInputOffsetB, bF);
|
||||
private static class NativePeer {
|
||||
static {
|
||||
if (!NativeLoader.isInitialized()) {
|
||||
NativeLoader.init(new SystemDelegate());
|
||||
}
|
||||
NativeLoader.loadLibrary("pytorch_vision_jni");
|
||||
}
|
||||
|
||||
private static native void imageYUV420CenterCropToFloatBuffer(
|
||||
ByteBuffer yBuffer,
|
||||
int yRowStride,
|
||||
int yPixelStride,
|
||||
ByteBuffer uBuffer,
|
||||
ByteBuffer vBuffer,
|
||||
int uvRowStride,
|
||||
int uvPixelStride,
|
||||
int imageWidth,
|
||||
int imageHeight,
|
||||
int rotateCWDegrees,
|
||||
int tensorWidth,
|
||||
int tensorHeight,
|
||||
float[] normMeanRgb,
|
||||
float[] normStdRgb,
|
||||
Buffer outBuffer,
|
||||
int outBufferOffset
|
||||
);
|
||||
}
|
||||
|
||||
private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
|
||||
@ -310,10 +263,6 @@ public final class TensorImageUtils {
|
||||
}
|
||||
}
|
||||
|
||||
private static final int clamp(int c, int min, int max) {
|
||||
return c < min ? min : c > max ? max : c;
|
||||
}
|
||||
|
||||
private static void checkNormStdArg(float[] normStdRGB) {
|
||||
if (normStdRGB.length != 3) {
|
||||
throw new IllegalArgumentException("normStdRGB length must be 3");
|
||||
|
Reference in New Issue
Block a user