mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add to Tensor symmetric methods getDataAsIntArray, getDataAsByteArray (#25183)
Summary: Tensor has getDataAsFloatArray(), we also support Int and Byte Tensors, adding symmetric methods for Int and Byte, that will throw IllegalStateException if called for not appropriate type Pull Request resolved: https://github.com/pytorch/pytorch/pull/25183 Reviewed By: dreiss Differential Revision: D17052674 Pulled By: IvanKobzarev fbshipit-source-id: 1d44944461ad008e202e382152cd0690c61124f4
This commit is contained in:
committed by
Facebook Github Bot
parent
c2e0383975
commit
c0334015ed
1
android/.gitignore
vendored
1
android/.gitignore
vendored
@ -2,6 +2,7 @@ local.properties
|
||||
**/*.iml
|
||||
.gradle
|
||||
gradlew*
|
||||
gradle/wrapper
|
||||
.idea/*
|
||||
.externalNativeBuild
|
||||
build
|
||||
|
@ -214,6 +214,47 @@ public class PytorchInstrumentedTests {
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorMethods() {
|
||||
long[] dims = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numElements(dims);
|
||||
int[] ints = new int[numel];
|
||||
float[] floats = new float[numel];
|
||||
|
||||
byte[] bytes = new byte[numel];
|
||||
for (int i = 0; i < numel; i++) {
|
||||
bytes[i] = (byte) ((i % 255) - 128);
|
||||
ints[i] = i;
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.newByteTensor(dims, bytes);
|
||||
assertTrue(tensorBytes.isByteTensor());
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.newIntTensor(dims, ints);
|
||||
assertTrue(tensorInts.isIntTensor());
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
|
||||
assertTrue(tensorFloats.isFloatTensor());
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
assertTrue(floats[i] == floatsOut[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testTensorIllegalStateOnWrongType() {
|
||||
long[] dims = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numElements(dims);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
|
||||
assertTrue(tensorFloats.isFloatTensor());
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
private static String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
|
@ -128,9 +128,31 @@ public abstract class Tensor {
|
||||
return result;
|
||||
}
|
||||
|
||||
public byte[] getDataAsByteArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
|
||||
}
|
||||
|
||||
public int[] getDataAsIntArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
|
||||
}
|
||||
|
||||
public float[] getDataAsFloatArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return data as float.");
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
|
||||
}
|
||||
|
||||
public boolean isByteTensor() {
|
||||
return TYPE_CODE_BYTE == getTypeCode();
|
||||
}
|
||||
|
||||
public boolean isIntTensor() {
|
||||
return TYPE_CODE_INT32 == getTypeCode();
|
||||
}
|
||||
|
||||
public boolean isFloatTensor() {
|
||||
return TYPE_CODE_FLOAT32 == getTypeCode();
|
||||
}
|
||||
|
||||
abstract int getTypeCode();
|
||||
@ -200,6 +222,7 @@ public abstract class Tensor {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] getDataAsIntArray() {
|
||||
data.rewind();
|
||||
int[] arr = new int[data.remaining()];
|
||||
@ -233,6 +256,7 @@ public abstract class Tensor {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getDataAsByteArray() {
|
||||
data.rewind();
|
||||
byte[] arr = new byte[data.remaining()];
|
||||
|
Reference in New Issue
Block a user