Various cleanups to pytorch_android API (#27454)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27454

See detailed discussion at
https://github.com/pytorch/pytorch/issues/27350

Test Plan: Imported from OSS

Reviewed By: IvanKobzarev

Differential Revision: D17800480

Pulled By: dreiss

fbshipit-source-id: bf174e8b16231b89be771de0fa54c41e864a3eb0
This commit is contained in:
David Reiss
2019-10-07 16:58:27 -07:00
committed by Facebook Github Bot
parent b66df47a11
commit 1ffa81d772
6 changed files with 199 additions and 176 deletions

View File

@ -24,7 +24,7 @@ public abstract class PytorchTestBase {
public void testForwardNull() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input =
IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
@ -34,12 +34,12 @@ public abstract class PytorchTestBase {
public void testEqBool() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.bool(value);
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.getBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.getBool());
assertTrue(value == output.toBool());
}
}
@ -47,12 +47,12 @@ public abstract class PytorchTestBase {
public void testEqInt() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.long64(value);
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.getLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.getLong());
assertTrue(value == output.toLong());
}
}
@ -74,12 +74,12 @@ public abstract class PytorchTestBase {
1,
};
for (double value : values) {
final IValue input = IValue.double64(value);
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.getDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.getDouble());
assertTrue(value == output.toDouble());
}
}
@ -91,17 +91,17 @@ public abstract class PytorchTestBase {
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input = IValue.tensor(inputTensor);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.getTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.getTensor();
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
@ -113,22 +113,22 @@ public abstract class PytorchTestBase {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE));
inputMap.put(0l, IValue.long64(0l));
inputMap.put(1l, IValue.long64(-1l));
inputMap.put(-1l, IValue.long64(1l));
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKey(inputMap);
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.getDictLongKey();
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@ -137,22 +137,22 @@ public abstract class PytorchTestBase {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE));
inputMap.put("long_0", IValue.long64(0l));
inputMap.put("long_1", IValue.long64(1l));
inputMap.put("long_-1", IValue.long64(-1l));
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKey(inputMap);
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.getDictStringKey();
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@ -167,19 +167,19 @@ public abstract class PytorchTestBase {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.longList(a);
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.getTuple().length);
assertTrue(2 == output.toTuple().length);
IValue output0 = output.getTuple()[0];
IValue output1 = output.getTuple()[1];
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.getLongList());
assertTrue(sum == output1.getLong());
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@ -187,16 +187,16 @@ public abstract class PytorchTestBase {
public void testOptionalIntIsNone() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool());
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
@ -219,16 +219,16 @@ public abstract class PytorchTestBase {
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
@ -241,8 +241,8 @@ public abstract class PytorchTestBase {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@ -257,12 +257,12 @@ public abstract class PytorchTestBase {
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.string(value);
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.getString()));
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.getString()));
assertTrue(value.equals(output.toStr()));
}
}
@ -276,13 +276,13 @@ public abstract class PytorchTestBase {
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.string(value);
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.getString()));
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.getString()));
assertTrue(expectedOutput.equals(output.toStr()));
}
}

View File

@ -10,6 +10,8 @@
namespace pytorch_jni {
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
@ -164,7 +166,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
static at::Tensor newAtTensorFromJTensor(
facebook::jni::alias_ref<JTensor> jtensor) {
static const auto dtypeMethod =
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
JTensor::javaClassStatic()->getMethod<jint()>("dtypeJniCode");
jint jdtype = dtypeMethod(jtensor);
static const auto shapeField =
@ -216,7 +218,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static auto jMethodTensor =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::local_ref<JTensor>)>("tensor");
facebook::jni::local_ref<JTensor>)>("from");
return jMethodTensor(
JIValue::javaClassStatic(),
JTensor::newJTensorFromAtTensor(ivalue.toTensor()));
@ -224,26 +226,26 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static auto jMethodBool =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
"bool");
"from");
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
} else if (ivalue.isInt()) {
static auto jMethodInt =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>(
"long64");
"from");
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
} else if (ivalue.isDouble()) {
static auto jMethodDouble =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
"double64");
"from");
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
} else if (ivalue.isString()) {
static auto jMethodString =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<
facebook::jni::JString::javaobject>)>("string");
facebook::jni::JString::javaobject>)>("from");
return jMethodString(
JIValue::javaClassStatic(),
facebook::jni::make_jstring(ivalue.toStringRef()));
@ -253,7 +255,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("tuple");
JIValue::javaobject>::javaobject>)>("tupleFrom");
auto jElementsArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
elementsVec.size());
@ -267,7 +269,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static auto jMethodBoolListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jbooleanArray>)>("boolList");
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_boolean_array(n);
auto jArrayPinned = jArray->pin();
@ -281,7 +283,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static auto jMethodLongListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jlongArray>)>("longList");
facebook::jni::alias_ref<jlongArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_long_array(n);
auto jArrayPinned = jArray->pin();
@ -295,7 +297,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static auto jMethoDoubleListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jdoubleArray>)>("doubleList");
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_double_array(n);
auto jArrayPinned = jArray->pin();
@ -310,7 +312,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JTensor::javaobject>::javaobject>)>("tensorList");
JTensor::javaobject>::javaobject>)>("listFrom");
auto jArray = facebook::jni::JArrayClass<JTensor::javaobject>::newArray(
list.size());
auto index = 0;
@ -324,7 +326,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("list");
JIValue::javaobject>::javaobject>)>("listFrom");
auto jArray = facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
list.size());
auto index = 0;
@ -351,7 +353,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
facebook::jni::alias_ref<
facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKey");
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
@ -370,7 +372,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
facebook::jni::alias_ref<
facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKey");
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
@ -404,32 +406,32 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static const auto jMethodGetTensor =
JIValue::javaClassStatic()
->getMethod<facebook::jni::alias_ref<JTensor::javaobject>()>(
"getTensor");
"toTensor");
return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
} else if (JIValue::kTypeCodeBool == typeCode) {
static const auto jMethodGetBool =
JIValue::javaClassStatic()->getMethod<jboolean()>("getBool");
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
// explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
// for int will be called for jboolean
bool b = jMethodGetBool(jivalue);
return at::IValue{b};
} else if (JIValue::kTypeCodeLong == typeCode) {
static const auto jMethodGetLong =
JIValue::javaClassStatic()->getMethod<jlong()>("getLong");
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
return at::IValue{jMethodGetLong(jivalue)};
} else if (JIValue::kTypeCodeDouble == typeCode) {
static const auto jMethodGetDouble =
JIValue::javaClassStatic()->getMethod<jdouble()>("getDouble");
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
return at::IValue{jMethodGetDouble(jivalue)};
} else if (JIValue::kTypeCodeString == typeCode) {
static const auto jMethodGetString =
JIValue::javaClassStatic()->getMethod<jstring()>("getString");
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
return at::IValue{jMethodGetString(jivalue)->toStdString()};
} else if (JIValue::kTypeCodeTuple == typeCode) {
static const auto jMethodGetTuple =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject()>("getTuple");
JIValue::javaobject>::javaobject()>("toTuple");
auto jarray = jMethodGetTuple(jivalue);
size_t n = jarray->size();
@ -443,7 +445,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
return c10::ivalue::Tuple::create(std::move(elements));
} else if (JIValue::kTypeCodeBoolList == typeCode) {
static const auto jMethodGetBoolList =
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("getBoolList");
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
auto jArray = jMethodGetBoolList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
@ -455,7 +457,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeLongList == typeCode) {
static const auto jMethodGetLongList =
JIValue::javaClassStatic()->getMethod<jlongArray()>("getLongList");
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
auto jArray = jMethodGetLongList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
@ -468,7 +470,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
static const auto jMethodGetDoubleList =
JIValue::javaClassStatic()->getMethod<jdoubleArray()>(
"getDoubleList");
"toDoubleList");
auto jArray = jMethodGetDoubleList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
@ -482,7 +484,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static const auto jMethodGetTensorList =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
JTensor::javaobject>::javaobject()>("getTensorList");
JTensor::javaobject>::javaobject()>("toTensorList");
auto jArray = jMethodGetTensorList(jivalue);
size_t n = jArray->size();
c10::List<at::Tensor> list{};
@ -495,7 +497,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static const auto jMethodGetList =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject()>("getList");
JIValue::javaobject>::javaobject()>("toList");
auto jarray = jMethodGetList(jivalue);
size_t n = jarray->size();
if (n == 0) {
@ -518,7 +520,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
static const auto jMethodGetDictStringKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
javaobject()>("getDictStringKey");
javaobject()>("toDictStringKey");
auto jmap = jMethodGetDictStringKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
@ -541,7 +543,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<
facebook::jni::JLong::javaobject,
JIValue::javaobject>::javaobject()>("getDictLongKey");
JIValue::javaobject>::javaobject()>("toDictLongKey");
auto jmap = jMethodGetDictLongKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {

View File

@ -0,0 +1,29 @@
package org.pytorch;
/**
* Codes representing tensor data types.
*/
public enum DType {
// NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp.
// NOTE: Never serialize "jniCode", because it can change between releases.
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
UINT8(1),
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
INT8(2),
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
INT32(3),
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
FLOAT32(4),
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
INT64(5),
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
FLOAT64(6),
;
final int jniCode;
DType(int jniCode) {
this.jniCode = jniCode;
}
}

View File

@ -98,7 +98,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript Tensor type.
*/
public static IValue tensor(Tensor tensor) {
public static IValue from(Tensor tensor) {
final IValue iv = new IValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
return iv;
@ -107,7 +107,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript bool type.
*/
public static IValue bool(boolean value) {
public static IValue from(boolean value) {
final IValue iv = new IValue(TYPE_CODE_BOOL);
iv.mData = value;
return iv;
@ -116,7 +116,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript int type.
*/
public static IValue long64(long value) {
public static IValue from(long value) {
final IValue iv = new IValue(TYPE_CODE_LONG);
iv.mData = value;
return iv;
@ -125,7 +125,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript float type.
*/
public static IValue double64(double value) {
public static IValue from(double value) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
iv.mData = value;
return iv;
@ -134,7 +134,7 @@ public class IValue {
/**
* Creates new IValue instance of torchscript str type.
*/
public static IValue string(String value) {
public static IValue from(String value) {
final IValue iv = new IValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
@ -143,7 +143,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript List[bool] type.
*/
public static IValue boolList(boolean... list) {
public static IValue listFrom(boolean... list) {
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
iv.mData = list;
return iv;
@ -152,7 +152,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript List[int] type.
*/
public static IValue longList(long... list) {
public static IValue listFrom(long... list) {
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
iv.mData = list;
return iv;
@ -161,7 +161,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript List[float] type.
*/
public static IValue doubleList(double... list) {
public static IValue listFrom(double... list) {
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
iv.mData = list;
return iv;
@ -170,7 +170,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript List[Tensor] type.
*/
public static IValue tensorList(Tensor... list) {
public static IValue listFrom(Tensor... list) {
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
iv.mData = list;
return iv;
@ -179,7 +179,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript List[T] type. All elements must have the same type.
*/
public static IValue list(IValue... array) {
public static IValue listFrom(IValue... array) {
final int size = array.length;
if (size > 0) {
final int typeCode0 = array[0].mTypeCode;
@ -198,7 +198,7 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type.
*/
public static IValue tuple(IValue... array) {
public static IValue tupleFrom(IValue... array) {
final IValue iv = new IValue(TYPE_CODE_TUPLE);
iv.mData = array;
return iv;
@ -207,7 +207,7 @@ public class IValue {
/**
* Creates a new IValue instance oftorchscript Dict[Str, V] type.
*/
public static IValue dictStringKey(Map<String, IValue> map) {
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
iv.mData = map;
return iv;
@ -216,73 +216,73 @@ public class IValue {
/**
* Creates a new IValue instance of torchscript Dict[int, V] type.
*/
public static IValue dictLongKey(Map<Long, IValue> map) {
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
iv.mData = map;
return iv;
}
public Tensor getTensor() {
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}
public boolean getBool() {
public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}
public long getLong() {
public long toLong() {
preconditionType(TYPE_CODE_LONG, mTypeCode);
return (long) mData;
}
public double getDouble() {
public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}
public String getString() {
public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
}
public boolean[] getBoolList() {
public boolean[] toBoolList() {
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
return (boolean[]) mData;
}
public long[] getLongList() {
public long[] toLongList() {
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
return (long[]) mData;
}
public double[] getDoubleList() {
public double[] toDoubleList() {
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
return (double[]) mData;
}
public Tensor[] getTensorList() {
public Tensor[] toTensorList() {
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
return (Tensor[]) mData;
}
public IValue[] getList() {
public IValue[] toList() {
preconditionType(TYPE_CODE_LIST, mTypeCode);
return (IValue[]) mData;
}
public IValue[] getTuple() {
public IValue[] toTuple() {
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
return (IValue[]) mData;
}
public Map<String, IValue> getDictStringKey() {
public Map<String, IValue> toDictStringKey() {
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
return (Map<String, IValue>) mData;
}
public Map<Long, IValue> getDictLongKey() {
public Map<Long, IValue> toDictLongKey() {
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
return (Map<Long, IValue>) mData;
}

View File

@ -15,20 +15,6 @@ import java.util.Locale;
* {@link java.nio.DirectByteBuffer} of one of the supported types.
*/
public abstract class Tensor {
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
public static final int DTYPE_UINT8 = 1;
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
public static final int DTYPE_INT8 = 2;
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
public static final int DTYPE_INT32 = 3;
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
public static final int DTYPE_FLOAT32 = 4;
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
public static final int DTYPE_INT64 = 5;
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
public static final int DTYPE_FLOAT64 = 6;
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
@ -39,8 +25,7 @@ public abstract class Tensor {
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
/** Shape of current tensor. */
public final long[] shape;
final long[] shape;
private static final int INT_SIZE_BYTES = 4;
private static final int FLOAT_SIZE_BYTES = 4;
@ -49,8 +34,8 @@ public abstract class Tensor {
/**
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link
* Tensor#newUInt8Tensor(long[], ByteBuffer)}.
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])},
* {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
@ -66,7 +51,7 @@ public abstract class Tensor {
/**
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}.
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
@ -78,7 +63,7 @@ public abstract class Tensor {
/**
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}.
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
@ -90,7 +75,7 @@ public abstract class Tensor {
/**
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
* capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}.
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
*
* @param numElements capacity (number of elements) of result buffer.
*/
@ -104,10 +89,10 @@ public abstract class Tensor {
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
* bytes.
*
* @param shape Tensor shape
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -121,10 +106,10 @@ public abstract class Tensor {
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
* bytes.
*
* @param shape Tensor shape
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor newInt8Tensor(long[] shape, byte[] data) {
public static Tensor fromBlob(byte[] data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -138,10 +123,10 @@ public abstract class Tensor {
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
* ints.
*
* @param shape Tensor shape
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor newInt32Tensor(long[] shape, int[] data) {
public static Tensor fromBlob(int[] data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -155,10 +140,10 @@ public abstract class Tensor {
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
* of floats.
*
* @param shape Tensor shape
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor newFloat32Tensor(long[] shape, float[] data) {
public static Tensor fromBlob(float[] data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -172,10 +157,10 @@ public abstract class Tensor {
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
* longs.
*
* @param shape Tensor shape
* @param data Tensor elements
* @param shape Tensor shape
*/
public static Tensor newInt64Tensor(long[] shape, long[] data) {
public static Tensor fromBlob(long[] data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -192,7 +177,7 @@ public abstract class Tensor {
* @param shape Tensor shape
* @param data Tensor elements
*/
public static Tensor newFloat64Tensor(long[] shape, double[] data) {
public static Tensor fromBlob(long[] shape, double[] data) {
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -205,12 +190,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -225,12 +210,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -245,12 +230,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
public static Tensor fromBlob(IntBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -265,12 +250,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -285,12 +270,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
public static Tensor fromBlob(LongBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -305,12 +290,12 @@ public abstract class Tensor {
/**
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
*
* @param shape Tensor shape
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
* elements. The buffer is used directly without copying, and changes to its content will
* change the tensor.
* @param shape Tensor shape
*/
public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
@ -342,12 +327,19 @@ public abstract class Tensor {
return result;
}
/** Shape of current tensor. */
public long[] shape() {
return Arrays.copyOf(shape, shape.length);
}
/**
* Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link
* Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link
* Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}.
* Returns dtype of current tensor.
*/
public abstract int dtype();
public abstract DType dtype();
int dtypeJniCode() {
return dtype().jniCode;
}
/**
* Returns newly allocated java byte array that contains a copy of tensor data.
@ -423,8 +415,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_UINT8;
public DType dtype() {
return DType.UINT8;
}
@Override
@ -455,8 +447,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_INT8;
public DType dtype() {
return DType.INT8;
}
@Override
@ -487,8 +479,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_INT32;
public DType dtype() {
return DType.INT32;
}
@Override
@ -527,8 +519,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_FLOAT32;
public DType dtype() {
return DType.FLOAT32;
}
@Override
@ -551,8 +543,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_INT64;
public DType dtype() {
return DType.INT64;
}
@Override
@ -583,8 +575,8 @@ public abstract class Tensor {
}
@Override
public int dtype() {
return DTYPE_FLOAT64;
public DType dtype() {
return DType.FLOAT64;
}
@Override
@ -634,17 +626,17 @@ public abstract class Tensor {
// Called from native
private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) {
if (DTYPE_FLOAT32 == dtype) {
if (DType.FLOAT32.jniCode == dtype) {
return new Tensor_float32(data.asFloatBuffer(), shape);
} else if (DTYPE_INT32 == dtype) {
} else if (DType.INT32.jniCode == dtype) {
return new Tensor_int32(data.asIntBuffer(), shape);
} else if (DTYPE_INT64 == dtype) {
} else if (DType.INT64.jniCode == dtype) {
return new Tensor_int64(data.asLongBuffer(), shape);
} else if (DTYPE_FLOAT64 == dtype) {
} else if (DType.FLOAT64.jniCode == dtype) {
return new Tensor_float64(data.asDoubleBuffer(), shape);
} else if (DTYPE_UINT8 == dtype) {
} else if (DType.UINT8.jniCode == dtype) {
return new Tensor_uint8(data, shape);
} else if (DTYPE_INT8 == dtype) {
} else if (DType.INT8.jniCode == dtype) {
return new Tensor_int8(data, shape);
}
throw new IllegalArgumentException("Unknown Tensor dtype");

View File

@ -106,7 +106,7 @@ public final class TensorImageUtils {
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatBuffer);
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width});
}
/**
@ -146,7 +146,7 @@ public final class TensorImageUtils {
tensorWidth,
tensorHeight,
normMeanRGB, normStdRGB, floatBuffer, 0);
return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatBuffer);
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth});
}
/**