mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix error messages; tensor creation method names with type (#26219)
Summary: After offline discussion with dzhulgakov : - In future we will introduce creation of byte signed and byte unsigned dtype tensors, but java has only signed byte - we will have to add some separation for it in method names ( java types and tensor types can not be clearly mapped) => Returning type in method names - fixes in error messages - non-static method Tensor.numel() - Change Tensor toString() to be more consistent with python Update on Sep 16: Type renaming on java side to uint8, int8, int32, float32, int64, float64 ``` public abstract class Tensor { public static final int DTYPE_UINT8 = 1; public static final int DTYPE_INT8 = 2; public static final int DTYPE_INT32 = 3; public static final int DTYPE_FLOAT32 = 4; public static final int DTYPE_INT64 = 5; public static final int DTYPE_FLOAT64 = 6; ``` ``` public static Tensor newUInt8Tensor(long[] shape, byte[] data) public static Tensor newInt8Tensor(long[] shape, byte[] data) public static Tensor newInt32Tensor(long[] shape, int[] data) public static Tensor newFloat32Tensor(long[] shape, float[] data) public static Tensor newInt64Tensor(long[] shape, long[] data) public static Tensor newFloat64Tensor(long[] shape, double[] data) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/26219 Differential Revision: D17406467 Pulled By: IvanKobzarev fbshipit-source-id: a0d7d44dc8ce8a562da1a18bd873db762975b184
This commit is contained in:
committed by
Facebook Github Bot
parent
448c53747a
commit
b07991f7f5
@ -36,7 +36,7 @@ public class PytorchInstrumentedTests {
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input =
|
||||
IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
|
||||
IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
@ -103,7 +103,7 @@ public class PytorchInstrumentedTests {
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData);
|
||||
final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input = IValue.tensor(inputTensor);
|
||||
@ -231,15 +231,15 @@ public class PytorchInstrumentedTests {
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.newTensor(shape, bytes);
|
||||
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_BYTE);
|
||||
Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
|
||||
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.newTensor(shape, ints);
|
||||
Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
|
||||
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.newTensor(shape, floats);
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
@ -253,7 +253,7 @@ public class PytorchInstrumentedTests {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.newTensor(shape, floats);
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
@ -10,11 +10,12 @@
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
constexpr static int kTensorDTypeByte = 1;
|
||||
constexpr static int kTensorDTypeInt32 = 2;
|
||||
constexpr static int kTensorDTypeFloat32 = 3;
|
||||
constexpr static int kTensorDTypeLong64 = 4;
|
||||
constexpr static int kTensorDTypeDouble64 = 5;
|
||||
constexpr static int kTensorDTypeUInt8 = 1;
|
||||
constexpr static int kTensorDTypeInt8 = 2;
|
||||
constexpr static int kTensorDTypeInt32 = 3;
|
||||
constexpr static int kTensorDTypeFloat32 = 4;
|
||||
constexpr static int kTensorDTypeInt64 = 5;
|
||||
constexpr static int kTensorDTypeFloat64 = 6;
|
||||
|
||||
template <typename K = jobject, typename V = jobject>
|
||||
struct JHashMap
|
||||
@ -64,15 +65,18 @@ static at::Tensor newAtTensor(
|
||||
} else if (kTensorDTypeInt32 == jdtype) {
|
||||
dataElementSizeBytes = 4;
|
||||
typeMeta = caffe2::TypeMeta::Make<int32_t>();
|
||||
} else if (kTensorDTypeByte == jdtype) {
|
||||
} else if (kTensorDTypeInt8 == jdtype) {
|
||||
dataElementSizeBytes = 1;
|
||||
typeMeta = caffe2::TypeMeta::Make<int8_t>();
|
||||
} else if (kTensorDTypeLong64 == jdtype) {
|
||||
dataElementSizeBytes = 8;
|
||||
typeMeta = caffe2::TypeMeta::Make<int64_t>();
|
||||
} else if (kTensorDTypeDouble64 == jdtype) {
|
||||
} else if (kTensorDTypeUInt8 == jdtype) {
|
||||
dataElementSizeBytes = 1;
|
||||
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
|
||||
} else if (kTensorDTypeFloat64 == jdtype) {
|
||||
dataElementSizeBytes = 8;
|
||||
typeMeta = caffe2::TypeMeta::Make<double>();
|
||||
} else if (kTensorDTypeInt64 == jdtype) {
|
||||
dataElementSizeBytes = 8;
|
||||
typeMeta = caffe2::TypeMeta::Make<int64_t>();
|
||||
} else {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
@ -123,11 +127,13 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
|
||||
} else if (at::kInt == scalarType) {
|
||||
jdtype = kTensorDTypeInt32;
|
||||
} else if (at::kByte == scalarType) {
|
||||
jdtype = kTensorDTypeByte;
|
||||
jdtype = kTensorDTypeUInt8;
|
||||
} else if (at::kChar == scalarType) {
|
||||
jdtype = kTensorDTypeInt8;
|
||||
} else if (at::kLong == scalarType) {
|
||||
jdtype = kTensorDTypeLong64;
|
||||
jdtype = kTensorDTypeInt64;
|
||||
} else if (at::kDouble == scalarType) {
|
||||
jdtype = kTensorDTypeDouble64;
|
||||
jdtype = kTensorDTypeFloat64;
|
||||
} else {
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
|
@ -11,17 +11,18 @@ import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
|
||||
public abstract class Tensor {
|
||||
public static final int DTYPE_BYTE = 1;
|
||||
public static final int DTYPE_INT32 = 2;
|
||||
public static final int DTYPE_FLOAT32 = 3;
|
||||
public static final int DTYPE_LONG64 = 4;
|
||||
public static final int DTYPE_DOUBLE64 = 5;
|
||||
public static final int DTYPE_UINT8 = 1;
|
||||
public static final int DTYPE_INT8 = 2;
|
||||
public static final int DTYPE_INT32 = 3;
|
||||
public static final int DTYPE_FLOAT32 = 4;
|
||||
public static final int DTYPE_INT64 = 5;
|
||||
public static final int DTYPE_FLOAT64 = 6;
|
||||
|
||||
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
|
||||
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Dims must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Dims must be not empty";
|
||||
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Dims must be non negative";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Shape must be not empty";
|
||||
private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative";
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
|
||||
"Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
|
||||
@ -62,17 +63,27 @@ public abstract class Tensor {
|
||||
.asDoubleBuffer();
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, byte[] data) {
|
||||
public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
|
||||
byteBuffer.put(data);
|
||||
return new Tensor_byte(byteBuffer, shape);
|
||||
return new Tensor_uint8(byteBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, int[] data) {
|
||||
public static Tensor newInt8Tensor(long[] shape, byte[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
|
||||
byteBuffer.put(data);
|
||||
return new Tensor_int8(byteBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newInt32Tensor(long[] shape, int[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -82,7 +93,7 @@ public abstract class Tensor {
|
||||
return new Tensor_int32(intBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, float[] data) {
|
||||
public static Tensor newFloat32Tensor(long[] shape, float[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -92,27 +103,27 @@ public abstract class Tensor {
|
||||
return new Tensor_float32(floatBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, long[] data) {
|
||||
public static Tensor newInt64Tensor(long[] shape, long[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
|
||||
longBuffer.put(data);
|
||||
return new Tensor_long64(longBuffer, shape);
|
||||
return new Tensor_int64(longBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, double[] data) {
|
||||
public static Tensor newFloat64Tensor(long[] shape, double[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.length, shape);
|
||||
final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
|
||||
doubleBuffer.put(data);
|
||||
return new Tensor_double64(doubleBuffer, shape);
|
||||
return new Tensor_float64(doubleBuffer, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, FloatBuffer data) {
|
||||
public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -121,10 +132,22 @@ public abstract class Tensor {
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_float32(data, shape);
|
||||
return new Tensor_uint8(data, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, IntBuffer data) {
|
||||
public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_int8(data, shape);
|
||||
}
|
||||
|
||||
public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -136,7 +159,7 @@ public abstract class Tensor {
|
||||
return new Tensor_int32(data, shape);
|
||||
}
|
||||
|
||||
public static Tensor newTensor(long[] shape, ByteBuffer data) {
|
||||
public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -145,7 +168,31 @@ public abstract class Tensor {
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_byte(data, shape);
|
||||
return new Tensor_float32(data, shape);
|
||||
}
|
||||
|
||||
public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_int64(data, shape);
|
||||
}
|
||||
|
||||
public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
|
||||
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
|
||||
checkArgument(
|
||||
(data.order() == ByteOrder.nativeOrder()),
|
||||
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
|
||||
return new Tensor_float64(data, shape);
|
||||
}
|
||||
|
||||
private Tensor(long[] shape) {
|
||||
@ -153,11 +200,15 @@ public abstract class Tensor {
|
||||
this.shape = Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
public long numel() {
|
||||
return numel(this.shape);
|
||||
}
|
||||
|
||||
public static long numel(long[] shape) {
|
||||
checkShape(shape);
|
||||
int result = 1;
|
||||
for (long dim : shape) {
|
||||
result *= dim;
|
||||
for (long s : shape) {
|
||||
result *= s;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -169,6 +220,11 @@ public abstract class Tensor {
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
|
||||
}
|
||||
|
||||
public byte[] getDataAsUnsignedByteArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
|
||||
}
|
||||
|
||||
public int[] getDataAsIntArray() {
|
||||
throw new IllegalStateException(
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
|
||||
@ -194,17 +250,49 @@ public abstract class Tensor {
|
||||
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
|
||||
}
|
||||
|
||||
static class Tensor_byte extends Tensor {
|
||||
static class Tensor_uint8 extends Tensor {
|
||||
private final ByteBuffer data;
|
||||
|
||||
private Tensor_byte(ByteBuffer data, long[] dims) {
|
||||
super(dims);
|
||||
private Tensor_uint8(ByteBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_BYTE;
|
||||
return DTYPE_UINT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
Buffer getRawDataBuffer() {
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getDataAsUnsignedByteArray() {
|
||||
data.rewind();
|
||||
byte[] arr = new byte[data.remaining()];
|
||||
data.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_int8 extends Tensor {
|
||||
private final ByteBuffer data;
|
||||
|
||||
private Tensor_int8(ByteBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -222,16 +310,15 @@ public abstract class Tensor {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"Tensor_byte{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
|
||||
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_int32 extends Tensor {
|
||||
private final IntBuffer data;
|
||||
|
||||
private Tensor_int32(IntBuffer data, long[] dims) {
|
||||
super(dims);
|
||||
private Tensor_int32(IntBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@ -255,16 +342,15 @@ public abstract class Tensor {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"Tensor_int32{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
|
||||
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_float32 extends Tensor {
|
||||
private final FloatBuffer data;
|
||||
|
||||
Tensor_float32(FloatBuffer data, long[] dims) {
|
||||
super(dims);
|
||||
Tensor_float32(FloatBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@ -288,22 +374,21 @@ public abstract class Tensor {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"Tensor_float32{shape:%s capacity:%d}", Arrays.toString(shape), data.capacity());
|
||||
return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_long64 extends Tensor {
|
||||
static class Tensor_int64 extends Tensor {
|
||||
private final LongBuffer data;
|
||||
|
||||
private Tensor_long64(LongBuffer data, long[] dims) {
|
||||
super(dims);
|
||||
private Tensor_int64(LongBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_LONG64;
|
||||
return DTYPE_INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -321,22 +406,21 @@ public abstract class Tensor {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"Tensor_long64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
|
||||
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
static class Tensor_double64 extends Tensor {
|
||||
static class Tensor_float64 extends Tensor {
|
||||
private final DoubleBuffer data;
|
||||
|
||||
private Tensor_double64(DoubleBuffer data, long[] shape) {
|
||||
private Tensor_float64(DoubleBuffer data, long[] shape) {
|
||||
super(shape);
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_DOUBLE64;
|
||||
return DTYPE_FLOAT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -354,8 +438,7 @@ public abstract class Tensor {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"Tensor_double64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity());
|
||||
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
|
||||
}
|
||||
}
|
||||
|
||||
@ -374,14 +457,14 @@ public abstract class Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] dims) {
|
||||
final long numElements = numel(dims);
|
||||
private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
|
||||
final long numel = numel(shape);
|
||||
checkArgument(
|
||||
numElements == dataCapacity,
|
||||
"Inconsistent data capacity:%d and dims number elements:%d dims:%s",
|
||||
numel == dataCapacity,
|
||||
"Inconsistent data capacity:%d and shape number elements:%d shape:%s",
|
||||
dataCapacity,
|
||||
numElements,
|
||||
Arrays.toString(dims));
|
||||
numel,
|
||||
Arrays.toString(shape));
|
||||
}
|
||||
// endregion checks
|
||||
|
||||
@ -391,12 +474,14 @@ public abstract class Tensor {
|
||||
return new Tensor_float32(data.asFloatBuffer(), shape);
|
||||
} else if (DTYPE_INT32 == dtype) {
|
||||
return new Tensor_int32(data.asIntBuffer(), shape);
|
||||
} else if (DTYPE_LONG64 == dtype) {
|
||||
return new Tensor_long64(data.asLongBuffer(), shape);
|
||||
} else if (DTYPE_DOUBLE64 == dtype) {
|
||||
return new Tensor_double64(data.asDoubleBuffer(), shape);
|
||||
} else if (DTYPE_BYTE == dtype) {
|
||||
return new Tensor_byte(data, shape);
|
||||
} else if (DTYPE_INT64 == dtype) {
|
||||
return new Tensor_int64(data.asLongBuffer(), shape);
|
||||
} else if (DTYPE_FLOAT64 == dtype) {
|
||||
return new Tensor_float64(data.asDoubleBuffer(), shape);
|
||||
} else if (DTYPE_UINT8 == dtype) {
|
||||
return new Tensor_uint8(data, shape);
|
||||
} else if (DTYPE_INT8 == dtype) {
|
||||
return new Tensor_int8(data, shape);
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown Tensor dtype");
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ public final class TensorImageUtils {
|
||||
floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B;
|
||||
}
|
||||
final long shape[] = new long[] {1, 3, height, width};
|
||||
return Tensor.newTensor(shape, floatArray);
|
||||
return Tensor.newFloat32Tensor(shape, floatArray);
|
||||
}
|
||||
|
||||
public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm(
|
||||
@ -131,7 +131,7 @@ public final class TensorImageUtils {
|
||||
}
|
||||
}
|
||||
final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight};
|
||||
return Tensor.newTensor(shape, floatArray);
|
||||
return Tensor.newFloat32Tensor(shape, floatArray);
|
||||
}
|
||||
|
||||
private static final int clamp(int c, int min, int max) {
|
||||
|
Reference in New Issue
Block a user