Fix error messages; tensor creation method names with type (#26219)

Summary:
After offline discussion with dzhulgakov :
 - In the future we will introduce creation of signed-byte and unsigned-byte dtype tensors, but Java has only a signed byte type, so we will have to distinguish them in method names (Java types and tensor dtypes cannot be mapped one-to-one) => include the dtype in the method names

- fixes in error messages

- Add a non-static method Tensor.numel()

- Change Tensor.toString() to be more consistent with Python

Update on Sep 16:

Type renaming on java side to uint8, int8, int32, float32, int64, float64
```
public abstract class Tensor {
  public static final int DTYPE_UINT8 = 1;
  public static final int DTYPE_INT8 = 2;
  public static final int DTYPE_INT32 = 3;
  public static final int DTYPE_FLOAT32 = 4;
  public static final int DTYPE_INT64 = 5;
  public static final int DTYPE_FLOAT64 = 6;
```
```
  public static Tensor newUInt8Tensor(long[] shape, byte[] data)
  public static Tensor newInt8Tensor(long[] shape, byte[] data)
  public static Tensor newInt32Tensor(long[] shape, int[] data)
  public static Tensor newFloat32Tensor(long[] shape, float[] data)
  public static Tensor newInt64Tensor(long[] shape, long[] data)
  public static Tensor newFloat64Tensor(long[] shape, double[] data)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26219

Differential Revision: D17406467

Pulled By: IvanKobzarev

fbshipit-source-id: a0d7d44dc8ce8a562da1a18bd873db762975b184
This commit is contained in:
Ivan Kobzarev
2019-09-16 18:25:17 -07:00
committed by Facebook Github Bot
parent 448c53747a
commit b07991f7f5
4 changed files with 173 additions and 82 deletions

View File

@ -36,7 +36,7 @@ public class PytorchInstrumentedTests {
public void testForwardNull() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input =
IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
@ -103,7 +103,7 @@ public class PytorchInstrumentedTests {
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData);
final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input = IValue.tensor(inputTensor);
@ -231,15 +231,15 @@ public class PytorchInstrumentedTests {
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.newTensor(shape, bytes);
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_BYTE);
Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.newTensor(shape, ints);
Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.newTensor(shape, floats);
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
@ -253,7 +253,7 @@ public class PytorchInstrumentedTests {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.newTensor(shape, floats);
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
tensorFloats.getDataAsByteArray();
}

View File

@ -10,11 +10,12 @@
namespace pytorch_jni {
constexpr static int kTensorDTypeByte = 1;
constexpr static int kTensorDTypeInt32 = 2;
constexpr static int kTensorDTypeFloat32 = 3;
constexpr static int kTensorDTypeLong64 = 4;
constexpr static int kTensorDTypeDouble64 = 5;
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
constexpr static int kTensorDTypeFloat32 = 4;
constexpr static int kTensorDTypeInt64 = 5;
constexpr static int kTensorDTypeFloat64 = 6;
template <typename K = jobject, typename V = jobject>
struct JHashMap
@ -64,15 +65,18 @@ static at::Tensor newAtTensor(
} else if (kTensorDTypeInt32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeByte == jdtype) {
} else if (kTensorDTypeInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeLong64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else if (kTensorDTypeDouble64 == jdtype) {
} else if (kTensorDTypeUInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
} else if (kTensorDTypeFloat64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<double>();
} else if (kTensorDTypeInt64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
@ -123,11 +127,13 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
} else if (at::kInt == scalarType) {
jdtype = kTensorDTypeInt32;
} else if (at::kByte == scalarType) {
jdtype = kTensorDTypeByte;
jdtype = kTensorDTypeUInt8;
} else if (at::kChar == scalarType) {
jdtype = kTensorDTypeInt8;
} else if (at::kLong == scalarType) {
jdtype = kTensorDTypeLong64;
jdtype = kTensorDTypeInt64;
} else if (at::kDouble == scalarType) {
jdtype = kTensorDTypeDouble64;
jdtype = kTensorDTypeFloat64;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,

View File

@ -11,17 +11,18 @@ import java.util.Arrays;
import java.util.Locale;
public abstract class Tensor {
public static final int DTYPE_BYTE = 1;
public static final int DTYPE_INT32 = 2;
public static final int DTYPE_FLOAT32 = 3;
public static final int DTYPE_LONG64 = 4;
public static final int DTYPE_DOUBLE64 = 5;
public static final int DTYPE_UINT8 = 1;
public static final int DTYPE_INT8 = 2;
public static final int DTYPE_INT32 = 3;
public static final int DTYPE_FLOAT32 = 4;
public static final int DTYPE_INT64 = 5;
public static final int DTYPE_FLOAT64 = 6;
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Dims must be not null";
private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Dims must be not empty";
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Dims must be non negative";
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Shape must be not empty";
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative";
private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
"Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
@ -62,17 +63,27 @@ public abstract class Tensor {
.asDoubleBuffer();
}
public static Tensor newTensor(long[] shape, byte[] data) {
public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_byte(byteBuffer, shape);
return new Tensor_uint8(byteBuffer, shape);
}
public static Tensor newTensor(long[] shape, int[] data) {
public static Tensor newInt8Tensor(long[] shape, byte[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
byteBuffer.put(data);
return new Tensor_int8(byteBuffer, shape);
}
public static Tensor newInt32Tensor(long[] shape, int[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -82,7 +93,7 @@ public abstract class Tensor {
return new Tensor_int32(intBuffer, shape);
}
public static Tensor newTensor(long[] shape, float[] data) {
public static Tensor newFloat32Tensor(long[] shape, float[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -92,27 +103,27 @@ public abstract class Tensor {
return new Tensor_float32(floatBuffer, shape);
}
public static Tensor newTensor(long[] shape, long[] data) {
public static Tensor newInt64Tensor(long[] shape, long[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
longBuffer.put(data);
return new Tensor_long64(longBuffer, shape);
return new Tensor_int64(longBuffer, shape);
}
public static Tensor newTensor(long[] shape, double[] data) {
public static Tensor newFloat64Tensor(long[] shape, double[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.length, shape);
final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
doubleBuffer.put(data);
return new Tensor_double64(doubleBuffer, shape);
return new Tensor_float64(doubleBuffer, shape);
}
public static Tensor newTensor(long[] shape, FloatBuffer data) {
public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -121,10 +132,22 @@ public abstract class Tensor {
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float32(data, shape);
return new Tensor_uint8(data, shape);
}
public static Tensor newTensor(long[] shape, IntBuffer data) {
public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int8(data, shape);
}
public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -136,7 +159,7 @@ public abstract class Tensor {
return new Tensor_int32(data, shape);
}
public static Tensor newTensor(long[] shape, ByteBuffer data) {
public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -145,7 +168,31 @@ public abstract class Tensor {
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_byte(data, shape);
return new Tensor_float32(data, shape);
}
public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_int64(data, shape);
}
public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
checkArgument(
(data.order() == ByteOrder.nativeOrder()),
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
return new Tensor_float64(data, shape);
}
private Tensor(long[] shape) {
@ -153,11 +200,15 @@ public abstract class Tensor {
this.shape = Arrays.copyOf(shape, shape.length);
}
public long numel() {
return numel(this.shape);
}
public static long numel(long[] shape) {
checkShape(shape);
int result = 1;
for (long dim : shape) {
result *= dim;
for (long s : shape) {
result *= s;
}
return result;
}
@ -169,6 +220,11 @@ public abstract class Tensor {
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
public byte[] getDataAsUnsignedByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
}
public int[] getDataAsIntArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
@ -194,17 +250,49 @@ public abstract class Tensor {
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
}
static class Tensor_byte extends Tensor {
static class Tensor_uint8 extends Tensor {
private final ByteBuffer data;
private Tensor_byte(ByteBuffer data, long[] dims) {
super(dims);
private Tensor_uint8(ByteBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@Override
public int dtype() {
return DTYPE_BYTE;
return DTYPE_UINT8;
}
@Override
Buffer getRawDataBuffer() {
return data;
}
@Override
public byte[] getDataAsUnsignedByteArray() {
data.rewind();
byte[] arr = new byte[data.remaining()];
data.get(arr);
return arr;
}
@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
}
}
static class Tensor_int8 extends Tensor {
private final ByteBuffer data;
private Tensor_int8(ByteBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@Override
public int dtype() {
return DTYPE_INT8;
}
@Override
@ -222,16 +310,15 @@ public abstract class Tensor {
@Override
public String toString() {
return String.format(
"Tensor_byte{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
}
}
static class Tensor_int32 extends Tensor {
private final IntBuffer data;
private Tensor_int32(IntBuffer data, long[] dims) {
super(dims);
private Tensor_int32(IntBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@ -255,16 +342,15 @@ public abstract class Tensor {
@Override
public String toString() {
return String.format(
"Tensor_int32{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
}
}
static class Tensor_float32 extends Tensor {
private final FloatBuffer data;
Tensor_float32(FloatBuffer data, long[] dims) {
super(dims);
Tensor_float32(FloatBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@ -288,22 +374,21 @@ public abstract class Tensor {
@Override
public String toString() {
return String.format(
"Tensor_float32{shape:%s capacity:%d}", Arrays.toString(shape), data.capacity());
return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
}
}
static class Tensor_long64 extends Tensor {
static class Tensor_int64 extends Tensor {
private final LongBuffer data;
private Tensor_long64(LongBuffer data, long[] dims) {
super(dims);
private Tensor_int64(LongBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@Override
public int dtype() {
return DTYPE_LONG64;
return DTYPE_INT64;
}
@Override
@ -321,22 +406,21 @@ public abstract class Tensor {
@Override
public String toString() {
return String.format(
"Tensor_long64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
}
}
static class Tensor_double64 extends Tensor {
static class Tensor_float64 extends Tensor {
private final DoubleBuffer data;
private Tensor_double64(DoubleBuffer data, long[] shape) {
private Tensor_float64(DoubleBuffer data, long[] shape) {
super(shape);
this.data = data;
}
@Override
public int dtype() {
return DTYPE_DOUBLE64;
return DTYPE_FLOAT64;
}
@Override
@ -354,8 +438,7 @@ public abstract class Tensor {
@Override
public String toString() {
return String.format(
"Tensor_double64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
}
}
@ -374,14 +457,14 @@ public abstract class Tensor {
}
}
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] dims) {
final long numElements = numel(dims);
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
final long numel = numel(shape);
checkArgument(
numElements == dataCapacity,
"Inconsistent data capacity:%d and dims number elements:%d dims:%s",
numel == dataCapacity,
"Inconsistent data capacity:%d and shape number elements:%d shape:%s",
dataCapacity,
numElements,
Arrays.toString(dims));
numel,
Arrays.toString(shape));
}
// endregion checks
@ -391,12 +474,14 @@ public abstract class Tensor {
return new Tensor_float32(data.asFloatBuffer(), shape);
} else if (DTYPE_INT32 == dtype) {
return new Tensor_int32(data.asIntBuffer(), shape);
} else if (DTYPE_LONG64 == dtype) {
return new Tensor_long64(data.asLongBuffer(), shape);
} else if (DTYPE_DOUBLE64 == dtype) {
return new Tensor_double64(data.asDoubleBuffer(), shape);
} else if (DTYPE_BYTE == dtype) {
return new Tensor_byte(data, shape);
} else if (DTYPE_INT64 == dtype) {
return new Tensor_int64(data.asLongBuffer(), shape);
} else if (DTYPE_FLOAT64 == dtype) {
return new Tensor_float64(data.asDoubleBuffer(), shape);
} else if (DTYPE_UINT8 == dtype) {
return new Tensor_uint8(data, shape);
} else if (DTYPE_INT8 == dtype) {
return new Tensor_int8(data, shape);
}
throw new IllegalArgumentException("Unknown Tensor dtype");
}

View File

@ -40,7 +40,7 @@ public final class TensorImageUtils {
floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B;
}
final long shape[] = new long[] {1, 3, height, width};
return Tensor.newTensor(shape, floatArray);
return Tensor.newFloat32Tensor(shape, floatArray);
}
public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm(
@ -131,7 +131,7 @@ public final class TensorImageUtils {
}
}
final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight};
return Tensor.newTensor(shape, floatArray);
return Tensor.newFloat32Tensor(shape, floatArray);
}
private static final int clamp(int c, int min, int max) {