Refactor android torchvision: not hardcoded mean/std (#26690)

Summary:
- Normalization mean and std specified as parameters instead of hardcode
 - imageYUV420CenterCropToFloat32Tensor before this change worked only with square tensors (width==height) - added generalization to support width != height with all rotations and scalings
- javadocs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26690

Differential Revision: D17556006

Pulled By: IvanKobzarev

fbshipit-source-id: 63f3321ea2e6b46ba5c34f9e92c48d116f7dc5ce
This commit is contained in:
Ivan Kobzarev
2019-09-24 13:36:34 -07:00
committed by facebook-github-bot
parent de3d4686ca
commit c8109058c4
2 changed files with 170 additions and 66 deletions

View File

@ -22,7 +22,11 @@ public class TorchVisionInstrumentedTests {
@Test
public void smokeTest() {
Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888);
Tensor tensor = TensorImageUtils.bitmapToFloatTensorTorchVisionForm(bitmap);
Tensor tensor =
TensorImageUtils.bitmapToFloat32Tensor(
bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape);
}
}

View File

@ -9,61 +9,135 @@ import org.pytorch.Tensor;
import java.nio.ByteBuffer;
import java.util.Locale;
/**
* Contains utility functions for {@link org.pytorch.Tensor} creation from
* {@link android.graphics.Bitmap} or {@link android.media.Image} source.
*/
public final class TensorImageUtils {
private static float NORM_MEAN_R = 0.485f;
private static float NORM_MEAN_G = 0.456f;
private static float NORM_MEAN_B = 0.406f;
private static float NORM_STD_R = 0.229f;
private static float NORM_STD_G = 0.224f;
private static float NORM_STD_B = 0.225f;
public static float[] TORCHVISION_NORM_MEAN_RGB = new float[]{0.485f, 0.456f, 0.406f};
public static float[] TORCHVISION_NORM_STD_RGB = new float[]{0.229f, 0.224f, 0.225f};
public static Tensor bitmapToFloatTensorTorchVisionForm(final Bitmap bitmap) {
return bitmapToFloatTensorTorchVisionForm(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight());
/**
* Creates new {@link org.pytorch.Tensor} from full {@link android.graphics.Bitmap}, normalized
* with specified in parameters mean and std.
*
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
*/
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
return bitmapToFloat32Tensor(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
}
public static Tensor bitmapToFloatTensorTorchVisionForm(
final Bitmap bitmap, int x, int y, int width, int height) {
final int pixelsSize = height * width;
final int[] pixels = new int[pixelsSize];
/**
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
* normalized with specified in parameters mean and std.
*
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
* @param x - x coordinate of top left corner of bitmap's area
* @param y - y coordinate of top left corner of bitmap's area
* @param width - width of bitmap's area
* @param height - height of bitmap's area
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
*/
public static Tensor bitmapToFloat32Tensor(
final Bitmap bitmap,
int x,
int y,
int width,
int height,
float[] normMeanRGB,
float[] normStdRGB) {
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
final int pixelsCount = height * width;
final int[] pixels = new int[pixelsCount];
bitmap.getPixels(pixels, 0, width, x, y, width, height);
final float[] floatArray = new float[3 * pixelsSize];
final int offset_g = pixelsSize;
final int offset_b = 2 * pixelsSize;
for (int i = 0; i < pixelsSize; i++) {
final float[] floatArray = new float[3 * pixelsCount];
final int offset_g = pixelsCount;
final int offset_b = 2 * pixelsCount;
for (int i = 0; i < pixelsCount; i++) {
final int c = pixels[i];
float r = ((c >> 16) & 0xff) / 255.0f;
float g = ((c >> 8) & 0xff) / 255.0f;
float b = ((c) & 0xff) / 255.0f;
floatArray[i] = (r - NORM_MEAN_R) / NORM_STD_R;
floatArray[offset_g + i] = (g - NORM_MEAN_G) / NORM_STD_G;
floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B;
floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0];
floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1];
floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2];
}
final long shape[] = new long[] {1, 3, height, width};
return Tensor.newFloat32Tensor(shape, floatArray);
return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray);
}
public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm(
final Image image, int rotateCWDegrees, final int tensorWidth, final int tensorHeight) {
/**
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.media.Image},
* doing optional rotation, scaling (nearest) and center cropping.
*
* @param image {@link android.media.Image} as a source for Tensor data
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
* upright. Range of valid values: 0, 90, 180, 270
* @param tensorWidth return tensor width, must be positive
* @param tensorHeight return tensor height, must be positive
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
*/
public static Tensor imageYUV420CenterCropToFloat32Tensor(
final Image image,
int rotateCWDegrees,
final int tensorWidth,
final int tensorHeight,
float[] normMeanRGB,
float[] normStdRGB) {
if (image.getFormat() != ImageFormat.YUV_420_888) {
throw new IllegalArgumentException(
String.format(
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
}
final int width = image.getWidth();
final int height = image.getHeight();
int offsetX = 0;
int offsetY = 0;
int centerCropSize;
if (width > height) {
offsetX = (int) Math.floor((width - height) / 2.f);
centerCropSize = height;
} else {
offsetY = (int) Math.floor((height - width) / 2.f);
centerCropSize = width;
checkNormMeanArg(normMeanRGB);
checkNormStdArg(normStdRGB);
checkRotateCWDegrees(rotateCWDegrees);
checkTensorSize(tensorWidth, tensorHeight);
final int widthBeforeRotation = image.getWidth();
final int heightBeforeRotation = image.getHeight();
int widthAfterRotation = widthBeforeRotation;
int heightAfterRotation = heightBeforeRotation;
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
widthAfterRotation = heightBeforeRotation;
heightAfterRotation = widthBeforeRotation;
}
int centerCropWidthAfterRotation = widthAfterRotation;
int centerCropHeightAfterRotation = heightAfterRotation;
if (tensorWidth * heightAfterRotation <= tensorHeight * widthAfterRotation) {
centerCropWidthAfterRotation =
(int) Math.floor((float) tensorWidth * heightAfterRotation / tensorHeight);
} else {
centerCropHeightAfterRotation =
(int) Math.floor((float) tensorHeight * widthAfterRotation / tensorWidth);
}
int centerCropWidthBeforeRotation = centerCropWidthAfterRotation;
int centerCropHeightBeforeRotation = centerCropHeightAfterRotation;
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
centerCropHeightBeforeRotation = centerCropWidthAfterRotation;
centerCropWidthBeforeRotation = centerCropHeightAfterRotation;
}
final int offsetX =
(int) Math.floor((widthBeforeRotation - centerCropWidthBeforeRotation) / 2.f);
final int offsetY =
(int) Math.floor((heightBeforeRotation - centerCropHeightBeforeRotation) / 2.f);
final Image.Plane yPlane = image.getPlanes()[0];
final Image.Plane uPlane = image.getPlanes()[1];
final Image.Plane vPlane = image.getPlanes()[2];
@ -72,44 +146,44 @@ public final class TensorImageUtils {
final ByteBuffer uBuffer = uPlane.getBuffer();
final ByteBuffer vBuffer = vPlane.getBuffer();
int yRowStride = yPlane.getRowStride();
int uRowStride = uPlane.getRowStride();
final int yRowStride = yPlane.getRowStride();
final int uRowStride = uPlane.getRowStride();
int yPixelStride = yPlane.getPixelStride();
int uPixelStride = uPlane.getPixelStride();
final int yPixelStride = yPlane.getPixelStride();
final int uPixelStride = uPlane.getPixelStride();
float tx = (float) centerCropSize / tensorWidth;
float ty = (float) centerCropSize / tensorHeight;
int uvRowStride = uRowStride >> 1;
int cSize = tensorHeight * tensorWidth;
final int tensorInputOffsetG = cSize;
final int tensorInputOffsetB = 2 * centerCropSize;
final float[] floatArray = new float[3 * cSize];
final float scale = (float) centerCropWidthAfterRotation / tensorWidth;
final int uvRowStride = uRowStride >> 1;
final int channelSize = tensorHeight * tensorWidth;
final int tensorInputOffsetG = channelSize;
final int tensorInputOffsetB = 2 * channelSize;
final float[] floatArray = new float[3 * channelSize];
for (int x = 0; x < tensorWidth; x++) {
for (int y = 0; y < tensorHeight; y++) {
// scaling as nearest
final int centerCropX = (int) Math.floor(x * tx);
final int centerCropY = (int) Math.floor(y * ty);
int srcX = centerCropY + offsetX;
int srcY = (centerCropSize - 1) - centerCropX + offsetY;
final int centerCropXAfterRotation = (int) Math.floor(x * scale);
final int centerCropYAfterRotation = (int) Math.floor(y * scale);
int xBeforeRotation = offsetX + centerCropXAfterRotation;
int yBeforeRotation = offsetY + centerCropYAfterRotation;
if (rotateCWDegrees == 90) {
srcX = offsetX + centerCropY;
srcY = offsetY + (centerCropSize - 1) - centerCropX;
xBeforeRotation = offsetX + centerCropYAfterRotation;
yBeforeRotation =
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropXAfterRotation;
} else if (rotateCWDegrees == 180) {
srcX = offsetX + (centerCropSize - 1) - centerCropX;
srcY = offsetY + (centerCropSize - 1) - centerCropY;
xBeforeRotation =
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropXAfterRotation;
yBeforeRotation =
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropYAfterRotation;
} else if (rotateCWDegrees == 270) {
srcX = offsetX + (centerCropSize - 1) - centerCropY;
srcY = offsetY + centerCropX;
xBeforeRotation =
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropYAfterRotation;
yBeforeRotation = offsetY + centerCropXAfterRotation;
}
final int yIdx = srcY * yRowStride + srcX * yPixelStride;
final int uvIdx = (srcY >> 1) * uvRowStride + srcX * uPixelStride;
final int yIdx = yBeforeRotation * yRowStride + xBeforeRotation * yPixelStride;
final int uvIdx = (yBeforeRotation >> 1) * uvRowStride + xBeforeRotation * uPixelStride;
int Yi = yBuffer.get(yIdx) & 0xff;
int Ui = uBuffer.get(uvIdx) & 0xff;
@ -125,16 +199,42 @@ public final class TensorImageUtils {
int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
int b = clamp((a0 + a4) >> 10, 0, 255);
final int offset = y * tensorWidth + x;
floatArray[offset] = ((r / 255.f) - NORM_MEAN_R) / NORM_STD_R;
floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - NORM_MEAN_G) / NORM_STD_G;
floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - NORM_MEAN_B) / NORM_STD_B;
floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
}
}
final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight};
return Tensor.newFloat32Tensor(shape, floatArray);
return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray);
}
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
if (tensorHeight <= 0 || tensorWidth <= 0) {
throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
}
}
private static void checkRotateCWDegrees(int rotateCWDegrees) {
if (rotateCWDegrees != 0
&& rotateCWDegrees != 90
&& rotateCWDegrees != 180
&& rotateCWDegrees != 270) {
throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
}
}
private static final int clamp(int c, int min, int max) {
return c < min ? min : c > max ? max : c;
}
private static void checkNormStdArg(float[] normStdRGB) {
if (normStdRGB.length != 3) {
throw new IllegalArgumentException("normStdRGB length must be 3");
}
}
private static void checkNormMeanArg(float[] normMeanRGB) {
if (normMeanRGB.length != 3) {
throw new IllegalArgumentException("normMeanRGB length must be 3");
}
}
}