[android][utils] Support ChannelsLast in TensorImageUtils (#48990)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48990

Introducing TensorImageUtils methods to prepare tensors in channelsLast MemoryFormat.
ChannlesLast is preferred for performance.

Not to introduce api breaking changes, adding additional parameter MemoryFormat which is CONTIGUOUS by default.

Testing by checking test_app that uses this call
```
gradle -p android installMnetLocalBaseDebug -PABI_FILTERS=arm64-v8a
```

Test Plan: Imported from OSS

Reviewed By: jeffxtang

Differential Revision: D27173940

Pulled By: IvanKobzarev

fbshipit-source-id: 27788082d2c8b190323eadcf18de25d2c3b5e1f1
This commit is contained in:
Ivan Kobzarev
2021-03-23 14:53:25 -07:00
committed by Facebook GitHub Bot
parent 792f5ffb83
commit 345b26ca08
4 changed files with 170 additions and 53 deletions

View File

@ -24,7 +24,11 @@ static void imageYUV420CenterCropToFloatBuffer(
jfloatArray jnormMeanRGB,
jfloatArray jnormStdRGB,
jobject outBuffer,
jint outOffset) {
jint outOffset,
jint memoryFormatCode) {
constexpr static int32_t kMemoryFormatContiguous = 1;
constexpr static int32_t kMemoryFormatChannelsLast = 2;
float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
jfloat normMeanRGB[3];
@ -91,32 +95,64 @@ static void imageYUV420CenterCropToFloatBuffer(
int xBeforeRtn, yBeforeRtn;
int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
int channelSize = tensorWidth * tensorHeight;
int wr = outOffset;
int wg = wr + channelSize;
int wb = wg + channelSize;
for (int x = 0; x < tensorWidth; x++) {
// A bit of code duplication to avoid branching in the cycles
if (memoryFormatCode == kMemoryFormatContiguous) {
int wr = outOffset;
int wg = wr + channelSize;
int wb = wg + channelSize;
for (int y = 0; y < tensorHeight; y++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wr++] = (ri - normMeanRm255) / normStdRm255;
outData[wg++] = (gi - normMeanGm255) / normStdGm255;
outData[wb++] = (bi - normMeanBm255) / normStdBm255;
for (int x = 0; x < tensorWidth; x++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wr++] = (ri - normMeanRm255) / normStdRm255;
outData[wg++] = (gi - normMeanGm255) / normStdGm255;
outData[wb++] = (bi - normMeanBm255) / normStdBm255;
}
}
} else if (memoryFormatCode == kMemoryFormatChannelsLast) {
int wc = outOffset;
for (int y = 0; y < tensorHeight; y++) {
for (int x = 0; x < tensorWidth; x++) {
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wc++] = (ri - normMeanRm255) / normStdRm255;
outData[wc++] = (gi - normMeanGm255) / normStdGm255;
outData[wc++] = (bi - normMeanBm255) / normStdBm255;
}
}
} else {
jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException");
jniEnv->ThrowNew(Exception,"Illegal memory format code");
}
}
} // namespace pytorch_vision_jni
@ -135,7 +171,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
static const JNINativeMethod methods[] = {
{"imageYUV420CenterCropToFloatBuffer",
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;I)V",
"(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;II)V",
(void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
};
int rc = env->RegisterNatives(

View File

@ -9,6 +9,7 @@ import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;
/**
@ -29,12 +30,18 @@ public final class TensorImageUtils {
* order
*/
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[], final MemoryFormat memoryFormat) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
return bitmapToFloat32Tensor(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, memoryFormat);
}
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
return bitmapToFloat32Tensor(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
}
/**
@ -59,30 +66,56 @@ public final class TensorImageUtils {
final float[] normMeanRGB,
final float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
final int outBufferOffset,
final MemoryFormat memoryFormat) {
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
if (memoryFormat != MemoryFormat.CONTIGUOUS && memoryFormat != MemoryFormat.CHANNELS_LAST) {
throw new IllegalArgumentException("Unsupported memory format " + memoryFormat);
}
final int pixelsCount = height * width;
final int[] pixels = new int[pixelsCount];
bitmap.getPixels(pixels, 0, width, x, y, width, height);
final int offset_g = pixelsCount;
final int offset_b = 2 * pixelsCount;
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
float rF = (r - normMeanRGB[0]) / normStdRGB[0];
float gF = (g - normMeanRGB[1]) / normStdRGB[1];
float bF = (b - normMeanRGB[2]) / normStdRGB[2];
outBuffer.put(outBufferOffset + i, rF);
outBuffer.put(outBufferOffset + offset_g + i, gF);
outBuffer.put(outBufferOffset + offset_b + i, bF);
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
final int offset_g = pixelsCount;
final int offset_b = 2 * pixelsCount;
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]);
outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]);
outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]);
}
} else {
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
outBuffer.put(outBufferOffset + 3 * i + 0, (r - normMeanRGB[0]) / normStdRGB[0]);
outBuffer.put(outBufferOffset + 3 * i + 1, (g - normMeanRGB[1]) / normStdRGB[1]);
outBuffer.put(outBufferOffset + 3 * i + 2, (b - normMeanRGB[2]) / normStdRGB[2]);
}
}
}
public static void bitmapToFloatBuffer(
final Bitmap bitmap,
final int x,
final int y,
final int width,
final int height,
final float[] normMeanRGB,
final float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, outBuffer, outBufferOffset, MemoryFormat.CONTIGUOUS);
}
/**
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
* normalized with specified in parameters mean and std.
@ -103,13 +136,25 @@ public final class TensorImageUtils {
int width,
int height,
float[] normMeanRGB,
float[] normStdRGB) {
float[] normStdRGB,
MemoryFormat memoryFormat) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width});
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0, memoryFormat);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width}, memoryFormat);
}
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap,
int x,
int y,
int width,
int height,
float[] normMeanRGB,
float[] normStdRGB) {
return bitmapToFloat32Tensor(bitmap, x, y, width, height, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
}
/**
@ -131,7 +176,8 @@ public final class TensorImageUtils {
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB) {
float[] normStdRGB,
MemoryFormat memoryFormat) {
if (image.getFormat() != ImageFormat.YUV_420_888) {
throw new IllegalArgumentException(
String.format(
@ -145,8 +191,18 @@ public final class TensorImageUtils {
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
imageYUV420CenterCropToFloatBuffer(
image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, floatBuffer, 0);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth});
image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, floatBuffer, 0, memoryFormat);
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth}, memoryFormat);
}
public static Tensor imageYUV420CenterCropToFloat32Tensor(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB) {
return imageYUV420CenterCropToFloat32Tensor(image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
}
/**
@ -173,7 +229,8 @@ public final class TensorImageUtils {
float[] normMeanRGB,
float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
final int outBufferOffset,
final MemoryFormat memoryFormat) {
checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
if (image.getFormat() != ImageFormat.YUV_420_888) {
@ -192,6 +249,13 @@ public final class TensorImageUtils {
Image.Plane U = planes[1];
Image.Plane V = planes[2];
int memoryFormatJniCode = 0;
if (MemoryFormat.CONTIGUOUS == memoryFormat) {
memoryFormatJniCode = 1;
} else if (MemoryFormat.CHANNELS_LAST == memoryFormat) {
memoryFormatJniCode = 2;
}
NativePeer.imageYUV420CenterCropToFloatBuffer(
Y.getBuffer(),
Y.getRowStride(),
@ -208,7 +272,20 @@ public final class TensorImageUtils {
normMeanRGB,
normStdRGB,
outBuffer,
outBufferOffset);
outBufferOffset,
memoryFormatJniCode);
}
public static void imageYUV420CenterCropToFloatBuffer(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB,
final FloatBuffer outBuffer,
final int outBufferOffset) {
imageYUV420CenterCropToFloatBuffer(image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, outBuffer, outBufferOffset, MemoryFormat.CONTIGUOUS);
}
private static class NativePeer {
@ -235,7 +312,8 @@ public final class TensorImageUtils {
float[] normMeanRgb,
float[] normStdRgb,
Buffer outBuffer,
int outBufferOffset);
int outBufferOffset,
int memoryFormatJniCode);
}
private static void checkOutBufferCapacity(

View File

@ -25,6 +25,7 @@ import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
@ -184,7 +185,8 @@ public class CameraActivity extends AppCompatActivity {
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
mInputTensorBuffer,
0);
0,
MemoryFormat.CHANNELS_LAST);
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;

View File

@ -19,6 +19,7 @@ import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
@ -125,7 +126,7 @@ public class MainActivity extends AppCompatActivity {
numElements *= shape[i];
}
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
PyTorchAndroid.setNumThreads(1);
mModule =
BuildConfig.USE_VULKAN_DEVICE