Tensor prep from image in native (#31426)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31426

Tensor conversion from a YUV image is moved to native code, with optimizations: no branching inside the pixel loop, no variable declarations inside the loop, and fewer operations overall.
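
The branch-elimination idea, as a standalone toy sketch (assumed names, not the PR code): each rotation is encoded once, before the loop, as a linear remap `xBefore = xAdd + xMult * xAfter`, so the inner loop performs one multiply-add per coordinate instead of an if/else per pixel.

```cpp
#include <cstdio>

// Toy sketch (assumed names, not the PR code): the rotation if/else is hoisted
// out of the pixel loop by encoding each rotation as a linear remap
//   xBefore = xAdd + xMult * xAfter,
// so the loop body has no per-pixel branches.
int main() {
  const int w = 4;                 // crop width before rotation (toy value)
  const int rotateCWDegrees = 180; // try 0 or 180
  int xMult = 1, xAdd = 0;         // identity remap for 0 degrees
  if (rotateCWDegrees == 180) {    // 180 degrees: walk the row right-to-left
    xMult = -1;
    xAdd = w - 1;
  }
  for (int x = 0; x < w; ++x) {
    std::printf("%d -> %d\n", x, xAdd + xMult * x); // branch-free remap
  }
  return 0;
}
```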

Perf stats from local devices, measuring conversion of a 320x240 camera image to a 1x3x224x224 tensor.
Legend:
Java - current Java impl
JavaOpt - current Java impl plus the same optimizations (no if/else inside the loop, variables declared outside the loop, inlining, etc.)
C - native (C++/JNI) impl

```
Nexus 5
JavaOpt N:25 avg:119.24 min: 87 max:177 p10:102 p25:105 p50:115 p75:127 p90:150
      C N:25 avg: 17.24 min: 14 max: 39 p10: 14 p25: 15 p50: 15 p75: 16 p90: 23
   Java N:25 avg:139.96 min: 70 max:214 p10: 89 p25:110 p50:139 p75:173 p90:181
avg C vs JavaOpt 6.91x

Pixel 3 XL
JavaOpt N:19 avg: 16.11 min: 12 max: 19 p10: 14 p25: 15 p50: 16 p75: 18 p90: 19
      C N:19 avg:  5.79 min:  3 max: 10 p10:  4 p25:  5 p50:  6 p75:  6 p90:  9
   Java N:19 avg: 16.21 min: 12 max: 20 p10: 14 p25: 15 p50: 16 p75: 18 p90: 20
avg C vs JavaOpt 2.78x

Full build with 4 ABIs included:
Pixel 3 XL
JavaOpt N:25 avg: 18.84 min: 16 max: 24 p10: 16 p25: 17 p50: 18 p75: 20 p90: 22
      C N:25 avg:  7.96 min:  5 max: 10 p10:  7 p25:  7 p50:  8 p75:  9 p90:  9
avg C vs JavaOpt 2.36x
```
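
For context, rows in that format can be produced by a small timing harness along these lines (a hypothetical sketch; the measurement code is not part of this diff, and all names here are made up):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stats helper, not from the PR: times a Runnable n times and
// prints one row in the same format as the table above.
public final class PerfStat {
  public static void report(String name, int n, Runnable work) {
    List<Long> ms = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      long t0 = System.nanoTime();
      work.run();
      ms.add((System.nanoTime() - t0) / 1_000_000); // elapsed millis per run
    }
    Collections.sort(ms);
    long sum = 0;
    for (long t : ms) sum += t;
    System.out.printf(
        "%7s N:%d avg:%6.2f min:%3d max:%3d p10:%3d p25:%3d p50:%3d p75:%3d p90:%3d%n",
        name, n, (double) sum / n, ms.get(0), ms.get(n - 1),
        pct(ms, 10), pct(ms, 25), pct(ms, 50), pct(ms, 75), pct(ms, 90));
  }

  // Nearest-rank percentile over a sorted list (approximate).
  private static long pct(List<Long> sorted, int p) {
    return sorted.get(Math.min(sorted.size() - 1, p * sorted.size() / 100));
  }
}
```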

Test Plan: Imported from OSS

Differential Revision: D19165429

Pulled By: IvanKobzarev

fbshipit-source-id: 3b54e545f6fbecbc5bb43216aca81061e70bd369
Ivan Kobzarev
2020-01-15 17:08:06 -08:00
committed by Facebook Github Bot
parent de5821d291
commit 104b2c610b
6 changed files with 235 additions and 107 deletions

CMakeLists.txt

@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.4.1)
project(pytorch_vision_jni CXX)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
file(GLOB pytorch_vision_SOURCES
${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp
)
add_library(pytorch_vision_jni SHARED
${pytorch_vision_SOURCES}
)
target_compile_options(pytorch_vision_jni PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_link_libraries(pytorch_vision_jni)

build.gradle

@@ -13,7 +13,9 @@ android {
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
}
buildTypes {
@@ -26,6 +28,12 @@ android {
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
@@ -34,7 +42,8 @@ android {
dependencies {
implementation project(':pytorch_android')
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion

pytorch_vision_jni.cpp

@@ -0,0 +1,144 @@
#include <cassert>
#include <cmath>
#include <vector>
#include "jni.h"
#define clamp0255(x) ((x) > 255 ? 255 : (x) < 0 ? 0 : (x))
namespace pytorch_vision_jni {
static void imageYUV420CenterCropToFloatBuffer(
JNIEnv* jniEnv,
jclass,
jobject yBuffer,
jint yRowStride,
jint yPixelStride,
jobject uBuffer,
jobject vBuffer,
jint uRowStride,
jint uvPixelStride,
jint imageWidth,
jint imageHeight,
jint rotateCWDegrees,
jint tensorWidth,
jint tensorHeight,
jfloatArray jnormMeanRGB,
jfloatArray jnormStdRGB,
jobject outBuffer,
jint outOffset) {
float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
jfloat normMeanRGB[3];
jfloat normStdRGB[3];
jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB);
jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB);
int widthAfterRtn = imageWidth;
int heightAfterRtn = imageHeight;
bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270;
if (oddRotation) {
widthAfterRtn = imageHeight;
heightAfterRtn = imageWidth;
}
int cropWidthAfterRtn = widthAfterRtn;
int cropHeightAfterRtn = heightAfterRtn;
if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) {
cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight;
} else {
cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth;
}
int cropWidthBeforeRtn = cropWidthAfterRtn;
int cropHeightBeforeRtn = cropHeightAfterRtn;
if (oddRotation) {
cropWidthBeforeRtn = cropHeightAfterRtn;
cropHeightBeforeRtn = cropWidthAfterRtn;
}
const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f;
const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f;
const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer);
const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer);
const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);
float scale = (float)cropWidthAfterRtn / tensorWidth; // cast to avoid integer division
int uvRowStride = uRowStride >> 1;
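// Fold the crop offset and rotation into a linear remap,
//   xBeforeRtn = cropXAdd + cropXMult * xAfterRtn (and likewise for y),
// chosen once here so the pixel loop below has no per-pixel rotation branch.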
int cropXMult = 1;
int cropYMult = 1;
int cropXAdd = offsetX;
int cropYAdd = offsetY;
if (rotateCWDegrees == 90) {
cropYMult = -1;
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
} else if (rotateCWDegrees == 180) {
cropXMult = -1;
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
cropYMult = -1;
cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
} else if (rotateCWDegrees == 270) {
cropXMult = -1;
cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
}
float normMeanRm255 = 255 * normMeanRGB[0];
float normMeanGm255 = 255 * normMeanRGB[1];
float normMeanBm255 = 255 * normMeanRGB[2];
float normStdRm255 = 255 * normStdRGB[0];
float normStdGm255 = 255 * normStdRGB[1];
float normStdBm255 = 255 * normStdRGB[2];
int xBeforeRtn, yBeforeRtn;
int yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
int channelSize = tensorWidth * tensorHeight;
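// Per-channel write cursors into the planar (CHW) float output.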
int wr = outOffset;
int wg = wr + channelSize;
int wb = wg + channelSize;
for (int x = 0; x < tensorWidth; x++) {
for (int y = 0; y < tensorHeight; y++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + xBeforeRtn * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
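// Fixed-point BT.601 YUV -> RGB: coefficients scaled by 1024, undone by >> 10.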
a0 = 1192 * (yData[yIdx] - 16);
ri = (a0 + 1634 * (vi - 128)) >> 10;
gi = (a0 - 832 * (vi - 128) - 400 * (ui - 128)) >> 10;
bi = (a0 + 2066 * (ui - 128)) >> 10;
outData[wr++] = (clamp0255(ri) - normMeanRm255) / normStdRm255;
outData[wg++] = (clamp0255(gi) - normMeanGm255) / normStdGm255;
outData[wb++] = (clamp0255(bi) - normMeanBm255) / normStdBm255;
}
}
}
} // namespace pytorch_vision_jni
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
jclass c =
env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}
static const JNINativeMethod methods[] = {
{"imageYUV420CenterCropToFloatBuffer",
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;I)V",
(void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
};
int rc = env->RegisterNatives(
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
if (rc != JNI_OK) {
return rc;
}
return JNI_VERSION_1_6;
}
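
For reference, the integer constants in the native conversion above are the standard BT.601 YUV-to-RGB coefficients scaled by 2^10 = 1024, which the `>> 10` undoes: 1.164*1024 ≈ 1192, 1.596*1024 ≈ 1634, 0.813*1024 ≈ 832, 0.391*1024 ≈ 400, 2.018*1024 ≈ 2066. A floating-point equivalent for comparison (a standalone sketch, not part of the diff):

```cpp
#include <cstdio>

// Floating-point BT.601 reference that the fixed-point math approximates.
// Sketch only; the PR uses the integer form for speed.
static int clamp255(int v) { return v < 0 ? 0 : (v > 255 ? 255 : v); }

int main() {
  int y = 120, u = 100, v = 200; // arbitrary sample YUV values
  float a0 = 1.164f * (y - 16);
  int r = clamp255((int)(a0 + 1.596f * (v - 128)));
  int g = clamp255((int)(a0 - 0.813f * (v - 128) - 0.391f * (u - 128)));
  int b = clamp255((int)(a0 + 2.018f * (u - 128)));
  std::printf("rgb = (%d, %d, %d)\n", r, g, b);
  return 0;
}
```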

TensorImageUtils.java

@@ -4,8 +4,12 @@ import android.graphics.Bitmap;
import android.graphics.ImageFormat;
import android.media.Image;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.Tensor;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
@@ -185,108 +189,57 @@ public final class TensorImageUtils {
checkRotateCWDegrees(rotateCWDegrees);
checkTensorSize(tensorWidth, tensorHeight);
final int widthBeforeRotation = image.getWidth();
final int heightBeforeRotation = image.getHeight();
Image.Plane[] planes = image.getPlanes();
Image.Plane Y = planes[0];
Image.Plane U = planes[1];
Image.Plane V = planes[2];
int widthAfterRotation = widthBeforeRotation;
int heightAfterRotation = heightBeforeRotation;
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
widthAfterRotation = heightBeforeRotation;
heightAfterRotation = widthBeforeRotation;
}
NativePeer.imageYUV420CenterCropToFloatBuffer(
Y.getBuffer(),
Y.getRowStride(),
Y.getPixelStride(),
U.getBuffer(),
V.getBuffer(),
U.getRowStride(),
U.getPixelStride(),
image.getWidth(),
image.getHeight(),
rotateCWDegrees,
tensorWidth,
tensorHeight,
normMeanRGB,
normStdRGB,
outBuffer,
outBufferOffset
);
}
int centerCropWidthAfterRotation = widthAfterRotation;
int centerCropHeightAfterRotation = heightAfterRotation;
if (tensorWidth * heightAfterRotation <= tensorHeight * widthAfterRotation) {
centerCropWidthAfterRotation =
(int) Math.floor((float) tensorWidth * heightAfterRotation / tensorHeight);
} else {
centerCropHeightAfterRotation =
(int) Math.floor((float) tensorHeight * widthAfterRotation / tensorWidth);
}
int centerCropWidthBeforeRotation = centerCropWidthAfterRotation;
int centerCropHeightBeforeRotation = centerCropHeightAfterRotation;
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
centerCropHeightBeforeRotation = centerCropWidthAfterRotation;
centerCropWidthBeforeRotation = centerCropHeightAfterRotation;
}
final int offsetX =
(int) Math.floor((widthBeforeRotation - centerCropWidthBeforeRotation) / 2.f);
final int offsetY =
(int) Math.floor((heightBeforeRotation - centerCropHeightBeforeRotation) / 2.f);
final Image.Plane yPlane = image.getPlanes()[0];
final Image.Plane uPlane = image.getPlanes()[1];
final Image.Plane vPlane = image.getPlanes()[2];
final ByteBuffer yBuffer = yPlane.getBuffer();
final ByteBuffer uBuffer = uPlane.getBuffer();
final ByteBuffer vBuffer = vPlane.getBuffer();
final int yRowStride = yPlane.getRowStride();
final int uRowStride = uPlane.getRowStride();
final int yPixelStride = yPlane.getPixelStride();
final int uPixelStride = uPlane.getPixelStride();
final float scale = (float) centerCropWidthAfterRotation / tensorWidth;
final int uvRowStride = uRowStride >> 1;
final int channelSize = tensorHeight * tensorWidth;
final int tensorInputOffsetG = channelSize;
final int tensorInputOffsetB = 2 * channelSize;
for (int x = 0; x < tensorWidth; x++) {
for (int y = 0; y < tensorHeight; y++) {
final int centerCropXAfterRotation = (int) Math.floor(x * scale);
final int centerCropYAfterRotation = (int) Math.floor(y * scale);
int xBeforeRotation = offsetX + centerCropXAfterRotation;
int yBeforeRotation = offsetY + centerCropYAfterRotation;
if (rotateCWDegrees == 90) {
xBeforeRotation = offsetX + centerCropYAfterRotation;
yBeforeRotation =
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropXAfterRotation;
} else if (rotateCWDegrees == 180) {
xBeforeRotation =
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropXAfterRotation;
yBeforeRotation =
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropYAfterRotation;
} else if (rotateCWDegrees == 270) {
xBeforeRotation =
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropYAfterRotation;
yBeforeRotation = offsetY + centerCropXAfterRotation;
}
final int yIdx = yBeforeRotation * yRowStride + xBeforeRotation * yPixelStride;
final int uvIdx = (yBeforeRotation >> 1) * uvRowStride + xBeforeRotation * uPixelStride;
int Yi = yBuffer.get(yIdx) & 0xff;
int Ui = uBuffer.get(uvIdx) & 0xff;
int Vi = vBuffer.get(uvIdx) & 0xff;
int a0 = 1192 * (Yi - 16);
int a1 = 1634 * (Vi - 128);
int a2 = 832 * (Vi - 128);
int a3 = 400 * (Ui - 128);
int a4 = 2066 * (Ui - 128);
int r = clamp((a0 + a1) >> 10, 0, 255);
int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
int b = clamp((a0 + a4) >> 10, 0, 255);
final int offset = outBufferOffset + y * tensorWidth + x;
float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
outBuffer.put(offset, rF);
outBuffer.put(offset + tensorInputOffsetG, gF);
outBuffer.put(offset + tensorInputOffsetB, bF);
private static class NativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_vision_jni");
}
private static native void imageYUV420CenterCropToFloatBuffer(
ByteBuffer yBuffer,
int yRowStride,
int yPixelStride,
ByteBuffer uBuffer,
ByteBuffer vBuffer,
int uvRowStride,
int uvPixelStride,
int imageWidth,
int imageHeight,
int rotateCWDegrees,
int tensorWidth,
int tensorHeight,
float[] normMeanRgb,
float[] normStdRgb,
Buffer outBuffer,
int outBufferOffset
);
}
private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
@@ -310,10 +263,6 @@ public final class TensorImageUtils {
}
}
private static final int clamp(int c, int min, int max) {
return c < min ? min : c > max ? max : c;
}
private static void checkNormStdArg(float[] normStdRGB) {
if (normStdRGB.length != 3) {
throw new IllegalArgumentException("normStdRGB length must be 3");