mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode] Apply google-java-format
Test Plan: Sandcastle. Visual inspection. Reviewed By: scottrice Differential Revision: D19878711 fbshipit-source-id: be56f70b35825140676be511903e5274d1808f25
This commit is contained in:
committed by
Facebook Github Bot
parent
bf16688538
commit
b28a834813
@ -1,7 +1,5 @@
|
||||
package org.pytorch;
|
||||
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
@ -14,7 +12,8 @@ public class PytorchHostTests extends PytorchTestBase {
|
||||
@Override
|
||||
protected String assetFilePath(String assetName) throws IOException {
|
||||
Path tempFile = Files.createTempFile("test", ".pt");
|
||||
try (InputStream resource = Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
try (InputStream resource =
|
||||
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
|
||||
}
|
||||
return tempFile.toAbsolutePath().toString();
|
||||
|
@ -1,26 +1,23 @@
|
||||
package org.pytorch;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
|
||||
public abstract class PytorchTestBase {
|
||||
private static final String TEST_MODULE_ASSET_NAME = "test.pt";
|
||||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input =
|
||||
IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
@ -242,7 +239,6 @@ public abstract class PytorchTestBase {
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
@ -277,7 +273,8 @@ public abstract class PytorchTestBase {
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString();
|
||||
String expectedOutput =
|
||||
new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package org.pytorch;
|
||||
|
||||
/**
|
||||
* Codes representing tensor data types.
|
||||
*/
|
||||
/** Codes representing tensor data types. */
|
||||
public enum DType {
|
||||
// NOTE: "jniCode" must be kept in sync with pytorch_jni_common.cpp.
|
||||
// NOTE: Never serialize "jniCode", because it can change between releases.
|
||||
|
@ -17,7 +17,8 @@ class NativePeer implements INativePeer {
|
||||
private static native HybridData initHybrid(String moduleAbsolutePath);
|
||||
|
||||
@DoNotStrip
|
||||
private static native HybridData initHybridAndroidAsset(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager);
|
||||
private static native HybridData initHybridAndroidAsset(
|
||||
String assetName, /* android.content.res.AssetManager */ Object androidAssetManager);
|
||||
|
||||
NativePeer(String moduleAbsolutePath) {
|
||||
mHybridData = initHybrid(moduleAbsolutePath);
|
||||
|
@ -1,7 +1,6 @@
|
||||
package org.pytorch;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
@ -15,13 +14,14 @@ public final class PyTorchAndroid {
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention:
|
||||
* This is not recommended way of loading production modules, as prepackaged assets increase apk size etc.
|
||||
* For production usage consider using loading from file on the disk {@link org.pytorch.Module#load(String)}.
|
||||
* Attention: This is not recommended way of loading production modules, as prepackaged assets
|
||||
* increase apk size etc. For production usage consider using loading from file on the disk {@link
|
||||
* org.pytorch.Module#load(String)}.
|
||||
*
|
||||
* This method is meant to use in tests and demos.
|
||||
* <p>This method is meant to use in tests and demos.
|
||||
*/
|
||||
public static Module loadModuleFromAsset(final AssetManager assetManager, final String assetName) {
|
||||
public static Module loadModuleFromAsset(
|
||||
final AssetManager assetManager, final String assetName) {
|
||||
return new Module(new NativePeer(assetName, assetManager));
|
||||
}
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
package org.pytorch;
|
||||
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import com.facebook.jni.HybridData;
|
||||
|
||||
import com.facebook.jni.annotations.DoNotStrip;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
@ -643,7 +642,8 @@ public abstract class Tensor {
|
||||
|
||||
// Called from native
|
||||
@DoNotStrip
|
||||
private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
|
||||
private static Tensor nativeNewTensor(
|
||||
ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
|
||||
Tensor tensor = null;
|
||||
if (DType.FLOAT32.jniCode == dtype) {
|
||||
tensor = new Tensor_float32(data.asFloatBuffer(), shape);
|
||||
|
@ -1,16 +1,13 @@
|
||||
package org.pytorch.torchvision;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
import org.junit.Before;
|
||||
import android.graphics.Bitmap;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class TorchVisionInstrumentedTests {
|
||||
|
||||
|
@ -3,20 +3,17 @@ package org.pytorch.torchvision;
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.ImageFormat;
|
||||
import android.media.Image;
|
||||
|
||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Locale;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
/**
|
||||
* Contains utility functions for {@link org.pytorch.Tensor} creation from
|
||||
* {@link android.graphics.Bitmap} or {@link android.media.Image} source.
|
||||
* Contains utility functions for {@link org.pytorch.Tensor} creation from {@link
|
||||
* android.graphics.Bitmap} or {@link android.media.Image} source.
|
||||
*/
|
||||
public final class TensorImageUtils {
|
||||
|
||||
@ -28,7 +25,8 @@ public final class TensorImageUtils {
|
||||
* with specified in parameters mean and std.
|
||||
*
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
|
||||
@ -40,9 +38,8 @@ public final class TensorImageUtils {
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.graphics.Bitmap},
|
||||
* normalized with specified in parameters mean and std to specified {@link java.nio.FloatBuffer}
|
||||
* with specified offset.
|
||||
* Writes tensor content from specified {@link android.graphics.Bitmap}, normalized with specified
|
||||
* in parameters mean and std to specified {@link java.nio.FloatBuffer} with specified offset.
|
||||
*
|
||||
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
|
||||
* @param x - x coordinate of top left corner of bitmap's area
|
||||
@ -50,7 +47,8 @@ public final class TensorImageUtils {
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static void bitmapToFloatBuffer(
|
||||
final Bitmap bitmap,
|
||||
@ -95,7 +93,8 @@ public final class TensorImageUtils {
|
||||
* @param width - width of bitmap's area
|
||||
* @param height - height of bitmap's area
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor bitmapToFloat32Tensor(
|
||||
final Bitmap bitmap,
|
||||
@ -123,7 +122,8 @@ public final class TensorImageUtils {
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
*/
|
||||
public static Tensor imageYUV420CenterCropToFloat32Tensor(
|
||||
final Image image,
|
||||
@ -145,17 +145,14 @@ public final class TensorImageUtils {
|
||||
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight);
|
||||
imageYUV420CenterCropToFloatBuffer(
|
||||
image,
|
||||
rotateCWDegrees,
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, tensorHeight, tensorWidth});
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes tensor content from specified {@link android.media.Image}, doing optional rotation,
|
||||
* scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified offset.
|
||||
* scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified
|
||||
* offset.
|
||||
*
|
||||
* @param image {@link android.media.Image} as a source for Tensor data
|
||||
* @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be
|
||||
@ -163,7 +160,8 @@ public final class TensorImageUtils {
|
||||
* @param tensorWidth return tensor width, must be positive
|
||||
* @param tensorHeight return tensor height, must be positive
|
||||
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order
|
||||
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
|
||||
* order
|
||||
* @param outBuffer Output buffer, where tensor content will be written
|
||||
* @param outBufferOffset Output buffer offset with which tensor content will be written
|
||||
*/
|
||||
@ -210,8 +208,7 @@ public final class TensorImageUtils {
|
||||
normMeanRGB,
|
||||
normStdRGB,
|
||||
outBuffer,
|
||||
outBufferOffset
|
||||
);
|
||||
outBufferOffset);
|
||||
}
|
||||
|
||||
private static class NativePeer {
|
||||
@ -238,11 +235,11 @@ public final class TensorImageUtils {
|
||||
float[] normMeanRgb,
|
||||
float[] normStdRgb,
|
||||
Buffer outBuffer,
|
||||
int outBufferOffset
|
||||
);
|
||||
int outBufferOffset);
|
||||
}
|
||||
|
||||
private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
|
||||
private static void checkOutBufferCapacity(
|
||||
FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
|
||||
if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
|
||||
throw new IllegalStateException("Buffer underflow");
|
||||
}
|
||||
|
@ -12,15 +12,6 @@ import android.view.TextureView;
|
||||
import android.view.ViewStub;
|
||||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import org.pytorch.IValue;
|
||||
import org.pytorch.Module;
|
||||
import org.pytorch.PyTorchAndroid;
|
||||
import org.pytorch.Tensor;
|
||||
import org.pytorch.torchvision.TensorImageUtils;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
|
||||
import androidx.annotation.Nullable;
|
||||
import androidx.annotation.UiThread;
|
||||
import androidx.annotation.WorkerThread;
|
||||
@ -32,6 +23,12 @@ import androidx.camera.core.ImageProxy;
|
||||
import androidx.camera.core.Preview;
|
||||
import androidx.camera.core.PreviewConfig;
|
||||
import androidx.core.app.ActivityCompat;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.pytorch.IValue;
|
||||
import org.pytorch.Module;
|
||||
import org.pytorch.PyTorchAndroid;
|
||||
import org.pytorch.Tensor;
|
||||
import org.pytorch.torchvision.TensorImageUtils;
|
||||
|
||||
public class CameraActivity extends AppCompatActivity {
|
||||
private static final String TAG = BuildConfig.LOGCAT_TAG;
|
||||
@ -59,10 +56,7 @@ public class CameraActivity extends AppCompatActivity {
|
||||
|
||||
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
|
||||
!= PackageManager.PERMISSION_GRANTED) {
|
||||
ActivityCompat.requestPermissions(
|
||||
this,
|
||||
PERMISSIONS,
|
||||
REQUEST_CODE_CAMERA_PERMISSION);
|
||||
ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
|
||||
} else {
|
||||
setupCameraX();
|
||||
}
|
||||
@ -118,12 +112,14 @@ public class CameraActivity extends AppCompatActivity {
|
||||
private static final int TENSOR_HEIGHT = 224;
|
||||
|
||||
private void setupCameraX() {
|
||||
final TextureView textureView = ((ViewStub) findViewById(R.id.camera_texture_view_stub))
|
||||
final TextureView textureView =
|
||||
((ViewStub) findViewById(R.id.camera_texture_view_stub))
|
||||
.inflate()
|
||||
.findViewById(R.id.texture_view);
|
||||
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
|
||||
final Preview preview = new Preview(previewConfig);
|
||||
preview.setOnPreviewOutputUpdateListener(new Preview.OnPreviewOutputUpdateListener() {
|
||||
preview.setOnPreviewOutputUpdateListener(
|
||||
new Preview.OnPreviewOutputUpdateListener() {
|
||||
@Override
|
||||
public void onUpdated(Preview.PreviewOutput output) {
|
||||
textureView.setSurfaceTexture(output.getSurfaceTexture());
|
||||
@ -149,7 +145,8 @@ public class CameraActivity extends AppCompatActivity {
|
||||
|
||||
if (result != null) {
|
||||
mLastAnalysisResultTime = SystemClock.elapsedRealtime();
|
||||
CameraActivity.this.runOnUiThread(new Runnable() {
|
||||
CameraActivity.this.runOnUiThread(
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
CameraActivity.this.handleResult(result);
|
||||
@ -174,17 +171,20 @@ public class CameraActivity extends AppCompatActivity {
|
||||
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
|
||||
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
||||
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT);
|
||||
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, TENSOR_WIDTH,
|
||||
TENSOR_HEIGHT});
|
||||
mInputTensor =
|
||||
Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT});
|
||||
}
|
||||
|
||||
final long startTime = SystemClock.elapsedRealtime();
|
||||
TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
|
||||
image.getImage(), rotationDegrees,
|
||||
TENSOR_WIDTH, TENSOR_HEIGHT,
|
||||
image.getImage(),
|
||||
rotationDegrees,
|
||||
TENSOR_WIDTH,
|
||||
TENSOR_HEIGHT,
|
||||
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
|
||||
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
|
||||
mInputTensorBuffer, 0);
|
||||
mInputTensorBuffer,
|
||||
0);
|
||||
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
|
||||
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
|
||||
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
|
||||
@ -198,9 +198,10 @@ public class CameraActivity extends AppCompatActivity {
|
||||
@UiThread
|
||||
protected void handleResult(Result result) {
|
||||
int ixs[] = Utils.topK(result.scores, 1);
|
||||
String message = String.format("forwardDuration:%d class:%s",
|
||||
result.moduleForwardDuration,
|
||||
Constants.IMAGENET_CLASSES[ixs[0]]);
|
||||
String message =
|
||||
String.format(
|
||||
"forwardDuration:%d class:%s",
|
||||
result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]);
|
||||
Log.i(TAG, message);
|
||||
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
|
||||
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
|
||||
|
@ -3,7 +3,8 @@ package org.pytorch.testapp;
|
||||
public class Constants {
|
||||
public static final String TAG = "PyTorchDemo";
|
||||
|
||||
public static String[] IMAGENET_CLASSES = new String[]{
|
||||
public static String[] IMAGENET_CLASSES =
|
||||
new String[] {
|
||||
"tench, Tinca tinca",
|
||||
"goldfish, Carassius auratus",
|
||||
"great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
|
||||
|
@ -10,13 +10,12 @@ import androidx.annotation.Nullable;
|
||||
import androidx.annotation.UiThread;
|
||||
import androidx.annotation.WorkerThread;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.pytorch.IValue;
|
||||
import org.pytorch.Module;
|
||||
import org.pytorch.PyTorchAndroid;
|
||||
import org.pytorch.Tensor;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
|
||||
private static final String TAG = BuildConfig.LOGCAT_TAG;
|
||||
@ -31,11 +30,13 @@ public class MainActivity extends AppCompatActivity {
|
||||
private Tensor mInputTensor;
|
||||
private StringBuilder mTextViewStringBuilder = new StringBuilder();
|
||||
|
||||
private final Runnable mModuleForwardRunnable = new Runnable() {
|
||||
private final Runnable mModuleForwardRunnable =
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
final Result result = doModuleForward();
|
||||
runOnUiThread(new Runnable() {
|
||||
runOnUiThread(
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
handleResult(result);
|
||||
|
Reference in New Issue
Block a user