mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Refactor android torchvision: not hardcoded mean/std (#26690)
Summary: - Normalization mean and std specified as parameters instead of hardcode - imageYUV420CenterCropToFloat32Tensor before this change worked only with square tensors (width==height) - added generalization to support width != height with all rotations and scalings - javadocs Pull Request resolved: https://github.com/pytorch/pytorch/pull/26690 Differential Revision: D17556006 Pulled By: IvanKobzarev fbshipit-source-id: 63f3321ea2e6b46ba5c34f9e92c48d116f7dc5ce
This commit is contained in:
committed by
facebook-github-bot
parent
de3d4686ca
commit
c8109058c4
@ -22,7 +22,11 @@ public class TorchVisionInstrumentedTests {
|
||||
@Test
|
||||
public void smokeTest() {
|
||||
Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888);
|
||||
Tensor tensor = TensorImageUtils.bitmapToFloatTensorTorchVisionForm(bitmap);
|
||||
Tensor tensor =
|
||||
TensorImageUtils.bitmapToFloat32Tensor(
|
||||
bitmap,
|
||||
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
|
||||
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
|
||||
assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape);
|
||||
}
|
||||
}
|
||||
|
@ -9,61 +9,135 @@ import org.pytorch.Tensor;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Contains utility functions for {@link org.pytorch.Tensor} creation from
|
||||
* {@link android.graphics.Bitmap} or {@link android.media.Image} source.
|
||||
*/
|
||||
public final class TensorImageUtils {
|
||||
private static float NORM_MEAN_R = 0.485f;
|
||||
private static float NORM_MEAN_G = 0.456f;
|
||||
private static float NORM_MEAN_B = 0.406f;
|
||||
|
||||
private static float NORM_STD_R = 0.229f;
|
||||
private static float NORM_STD_G = 0.224f;
|
||||
private static float NORM_STD_B = 0.225f;
|
||||
public static float[] TORCHVISION_NORM_MEAN_RGB = new float[]{0.485f, 0.456f, 0.406f};
|
||||
public static float[] TORCHVISION_NORM_STD_RGB = new float[]{0.229f, 0.224f, 0.225f};
|
||||
|
||||
public static Tensor bitmapToFloatTensorTorchVisionForm(final Bitmap bitmap) {
|
||||
return bitmapToFloatTensorTorchVisionForm(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight());
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from full {@link android.graphics.Bitmap}, normalized
|
||||
* with specified in parameters mean and std.
|
||||
*
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
return bitmapToFloat32Tensor(
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
|
||||
}
|
||||
|
||||
public static Tensor bitmapToFloatTensorTorchVisionForm(
|
||||
final Bitmap bitmap, int x, int y, int width, int height) {
|
||||
final int pixelsSize = height * width;
|
||||
final int[] pixels = new int[pixelsSize];
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap},
|
||||
* normalized with specified in parameters mean and std.
|
||||
*
|
||||
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
|
||||
* @param x - x coordinate of top left corner of bitmap's area
|
||||
* @param y - y coordinate of top left corner of bitmap's area
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB) {
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
|
||||
final int pixelsCount = height * width;
|
||||
final int[] pixels = new int[pixelsCount];
|
||||
bitmap.getPixels(pixels, 0, width, x, y, width, height);
|
||||
final float[] floatArray = new float[3 * pixelsSize];
|
||||
final int offset_g = pixelsSize;
|
||||
final int offset_b = 2 * pixelsSize;
|
||||
for (int i = 0; i < pixelsSize; i++) {
|
||||
final float[] floatArray = new float[3 * pixelsCount];
|
||||
final int offset_g = pixelsCount;
|
||||
final int offset_b = 2 * pixelsCount;
|
||||
for (int i = 0; i < pixelsCount; i++) {
|
||||
final int c = pixels[i];
|
||||
float r = ((c >> 16) & 0xff) / 255.0f;
|
||||
float g = ((c >> 8) & 0xff) / 255.0f;
|
||||
float b = ((c) & 0xff) / 255.0f;
|
||||
floatArray[i] = (r - NORM_MEAN_R) / NORM_STD_R;
|
||||
floatArray[offset_g + i] = (g - NORM_MEAN_G) / NORM_STD_G;
|
||||
floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B;
|
||||
floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0];
|
||||
floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1];
|
||||
floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2];
|
||||
}
|
||||
final long shape[] = new long[] {1, 3, height, width};
|
||||
return Tensor.newFloat32Tensor(shape, floatArray);
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray);
|
||||
}
|
||||
|
||||
public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm(
|
||||
final Image image, int rotateCWDegrees, final int tensorWidth, final int tensorHeight) {
|
||||
/**
|
||||
* Creates new {@link org.pytorch.Tensor} from specified area of {@link android.media.Image},
|
||||
* doing optional rotation, scaling (nearest) and center cropping.
|
||||
*
|
||||
* @param image {@link android.media.Image} as a source for Tensor data
|
||||
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
|
||||
* upright. Range of valid values: 0, 90, 180, 270
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
*/
|
||||
public static Tensor imageYUV420CenterCropToFloat32Tensor(
|
||||
final Image image,
|
||||
int rotateCWDegrees,
|
||||
final int tensorWidth,
|
||||
final int tensorHeight,
|
||||
float[] normMeanRGB,
|
||||
float[] normStdRGB) {
|
||||
if (image.getFormat() != ImageFormat.YUV_420_888) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat()));
|
||||
}
|
||||
|
||||
final int width = image.getWidth();
|
||||
final int height = image.getHeight();
|
||||
int offsetX = 0;
|
||||
int offsetY = 0;
|
||||
int centerCropSize;
|
||||
if (width > height) {
|
||||
offsetX = (int) Math.floor((width - height) / 2.f);
|
||||
centerCropSize = height;
|
||||
} else {
|
||||
offsetY = (int) Math.floor((height - width) / 2.f);
|
||||
centerCropSize = width;
|
||||
checkNormMeanArg(normMeanRGB);
|
||||
checkNormStdArg(normStdRGB);
|
||||
checkRotateCWDegrees(rotateCWDegrees);
|
||||
checkTensorSize(tensorWidth, tensorHeight);
|
||||
|
||||
final int widthBeforeRotation = image.getWidth();
|
||||
final int heightBeforeRotation = image.getHeight();
|
||||
|
||||
int widthAfterRotation = widthBeforeRotation;
|
||||
int heightAfterRotation = heightBeforeRotation;
|
||||
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
|
||||
widthAfterRotation = heightBeforeRotation;
|
||||
heightAfterRotation = widthBeforeRotation;
|
||||
}
|
||||
|
||||
int centerCropWidthAfterRotation = widthAfterRotation;
|
||||
int centerCropHeightAfterRotation = heightAfterRotation;
|
||||
|
||||
if (tensorWidth * heightAfterRotation <= tensorHeight * widthAfterRotation) {
|
||||
centerCropWidthAfterRotation =
|
||||
(int) Math.floor((float) tensorWidth * heightAfterRotation / tensorHeight);
|
||||
} else {
|
||||
centerCropHeightAfterRotation =
|
||||
(int) Math.floor((float) tensorHeight * widthAfterRotation / tensorWidth);
|
||||
}
|
||||
|
||||
int centerCropWidthBeforeRotation = centerCropWidthAfterRotation;
|
||||
int centerCropHeightBeforeRotation = centerCropHeightAfterRotation;
|
||||
if (rotateCWDegrees == 90 || rotateCWDegrees == 270) {
|
||||
centerCropHeightBeforeRotation = centerCropWidthAfterRotation;
|
||||
centerCropWidthBeforeRotation = centerCropHeightAfterRotation;
|
||||
}
|
||||
|
||||
final int offsetX =
|
||||
(int) Math.floor((widthBeforeRotation - centerCropWidthBeforeRotation) / 2.f);
|
||||
final int offsetY =
|
||||
(int) Math.floor((heightBeforeRotation - centerCropHeightBeforeRotation) / 2.f);
|
||||
|
||||
final Image.Plane yPlane = image.getPlanes()[0];
|
||||
final Image.Plane uPlane = image.getPlanes()[1];
|
||||
final Image.Plane vPlane = image.getPlanes()[2];
|
||||
@ -72,44 +146,44 @@ public final class TensorImageUtils {
|
||||
final ByteBuffer uBuffer = uPlane.getBuffer();
|
||||
final ByteBuffer vBuffer = vPlane.getBuffer();
|
||||
|
||||
int yRowStride = yPlane.getRowStride();
|
||||
int uRowStride = uPlane.getRowStride();
|
||||
final int yRowStride = yPlane.getRowStride();
|
||||
final int uRowStride = uPlane.getRowStride();
|
||||
|
||||
int yPixelStride = yPlane.getPixelStride();
|
||||
int uPixelStride = uPlane.getPixelStride();
|
||||
final int yPixelStride = yPlane.getPixelStride();
|
||||
final int uPixelStride = uPlane.getPixelStride();
|
||||
|
||||
float tx = (float) centerCropSize / tensorWidth;
|
||||
float ty = (float) centerCropSize / tensorHeight;
|
||||
int uvRowStride = uRowStride >> 1;
|
||||
|
||||
int cSize = tensorHeight * tensorWidth;
|
||||
final int tensorInputOffsetG = cSize;
|
||||
final int tensorInputOffsetB = 2 * centerCropSize;
|
||||
final float[] floatArray = new float[3 * cSize];
|
||||
final float scale = (float) centerCropWidthAfterRotation / tensorWidth;
|
||||
final int uvRowStride = uRowStride >> 1;
|
||||
|
||||
final int channelSize = tensorHeight * tensorWidth;
|
||||
final int tensorInputOffsetG = channelSize;
|
||||
final int tensorInputOffsetB = 2 * channelSize;
|
||||
final float[] floatArray = new float[3 * channelSize];
|
||||
for (int x = 0; x < tensorWidth; x++) {
|
||||
for (int y = 0; y < tensorHeight; y++) {
|
||||
|
||||
// scaling as nearest
|
||||
final int centerCropX = (int) Math.floor(x * tx);
|
||||
final int centerCropY = (int) Math.floor(y * ty);
|
||||
|
||||
int srcX = centerCropY + offsetX;
|
||||
int srcY = (centerCropSize - 1) - centerCropX + offsetY;
|
||||
final int centerCropXAfterRotation = (int) Math.floor(x * scale);
|
||||
final int centerCropYAfterRotation = (int) Math.floor(y * scale);
|
||||
|
||||
int xBeforeRotation = offsetX + centerCropXAfterRotation;
|
||||
int yBeforeRotation = offsetY + centerCropYAfterRotation;
|
||||
if (rotateCWDegrees == 90) {
|
||||
srcX = offsetX + centerCropY;
|
||||
srcY = offsetY + (centerCropSize - 1) - centerCropX;
|
||||
xBeforeRotation = offsetX + centerCropYAfterRotation;
|
||||
yBeforeRotation =
|
||||
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropXAfterRotation;
|
||||
} else if (rotateCWDegrees == 180) {
|
||||
srcX = offsetX + (centerCropSize - 1) - centerCropX;
|
||||
srcY = offsetY + (centerCropSize - 1) - centerCropY;
|
||||
xBeforeRotation =
|
||||
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropXAfterRotation;
|
||||
yBeforeRotation =
|
||||
offsetY + (centerCropHeightBeforeRotation - 1) - centerCropYAfterRotation;
|
||||
} else if (rotateCWDegrees == 270) {
|
||||
srcX = offsetX + (centerCropSize - 1) - centerCropY;
|
||||
srcY = offsetY + centerCropX;
|
||||
xBeforeRotation =
|
||||
offsetX + (centerCropWidthBeforeRotation - 1) - centerCropYAfterRotation;
|
||||
yBeforeRotation = offsetY + centerCropXAfterRotation;
|
||||
}
|
||||
|
||||
final int yIdx = srcY * yRowStride + srcX * yPixelStride;
|
||||
final int uvIdx = (srcY >> 1) * uvRowStride + srcX * uPixelStride;
|
||||
final int yIdx = yBeforeRotation * yRowStride + xBeforeRotation * yPixelStride;
|
||||
final int uvIdx = (yBeforeRotation >> 1) * uvRowStride + xBeforeRotation * uPixelStride;
|
||||
|
||||
int Yi = yBuffer.get(yIdx) & 0xff;
|
||||
int Ui = uBuffer.get(uvIdx) & 0xff;
|
||||
@ -125,16 +199,42 @@ public final class TensorImageUtils {
|
||||
int g = clamp((a0 - a2 - a3) >> 10, 0, 255);
|
||||
int b = clamp((a0 + a4) >> 10, 0, 255);
|
||||
final int offset = y * tensorWidth + x;
|
||||
floatArray[offset] = ((r / 255.f) - NORM_MEAN_R) / NORM_STD_R;
|
||||
floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - NORM_MEAN_G) / NORM_STD_G;
|
||||
floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - NORM_MEAN_B) / NORM_STD_B;
|
||||
floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0];
|
||||
floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1];
|
||||
floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2];
|
||||
}
|
||||
}
|
||||
final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight};
|
||||
return Tensor.newFloat32Tensor(shape, floatArray);
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray);
|
||||
}
|
||||
|
||||
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
|
||||
if (tensorHeight <= 0 || tensorWidth <= 0) {
|
||||
throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkRotateCWDegrees(int rotateCWDegrees) {
|
||||
if (rotateCWDegrees != 0
|
||||
&& rotateCWDegrees != 90
|
||||
&& rotateCWDegrees != 180
|
||||
&& rotateCWDegrees != 270) {
|
||||
throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
|
||||
}
|
||||
}
|
||||
|
||||
private static final int clamp(int c, int min, int max) {
|
||||
return c < min ? min : c > max ? max : c;
|
||||
}
|
||||
|
||||
private static void checkNormStdArg(float[] normStdRGB) {
|
||||
if (normStdRGB.length != 3) {
|
||||
throw new IllegalArgumentException("normStdRGB length must be 3");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkNormMeanArg(float[] normMeanRGB) {
|
||||
if (normMeanRGB.length != 3) {
|
||||
throw new IllegalArgumentException("normMeanRGB length must be 3");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user