add to Tensor symmetric methods getDataAsIntArray, getDataAsByteArray (#25183)

Summary:
Tensor has getDataAsFloatArray(), we also support Int and Byte Tensors,
adding symmetric methods for Int and Byte, that will throw
IllegalStateException if called for not appropriate type
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25183

Reviewed By: dreiss

Differential Revision: D17052674

Pulled By: IvanKobzarev

fbshipit-source-id: 1d44944461ad008e202e382152cd0690c61124f4
This commit is contained in:
Ivan Kobzarev
2019-08-26 19:08:45 -07:00
committed by Facebook Github Bot
parent c2e0383975
commit c0334015ed
3 changed files with 67 additions and 1 deletions

1
android/.gitignore vendored
View File

@ -2,6 +2,7 @@ local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.externalNativeBuild
build

View File

@ -214,6 +214,47 @@ public class PytorchInstrumentedTests {
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] dims = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numElements(dims);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.newByteTensor(dims, bytes);
assertTrue(tensorBytes.isByteTensor());
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.newIntTensor(dims, ints);
assertTrue(tensorInts.isIntTensor());
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
assertTrue(tensorFloats.isFloatTensor());
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] dims = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numElements(dims);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
assertTrue(tensorFloats.isFloatTensor());
tensorFloats.getDataAsByteArray();
}
private static String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);

View File

@ -128,9 +128,31 @@ public abstract class Tensor {
return result;
}
public byte[] getDataAsByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
public int[] getDataAsIntArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
}
public float[] getDataAsFloatArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return data as float.");
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
}
public boolean isByteTensor() {
return TYPE_CODE_BYTE == getTypeCode();
}
public boolean isIntTensor() {
return TYPE_CODE_INT32 == getTypeCode();
}
public boolean isFloatTensor() {
return TYPE_CODE_FLOAT32 == getTypeCode();
}
abstract int getTypeCode();
@ -200,6 +222,7 @@ public abstract class Tensor {
return data;
}
@Override
public int[] getDataAsIntArray() {
data.rewind();
int[] arr = new int[data.remaining()];
@ -233,6 +256,7 @@ public abstract class Tensor {
return data;
}
@Override
public byte[] getDataAsByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];