mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Various cleanups to pytorch_android API (#27454)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27454 See detailed discussion at https://github.com/pytorch/pytorch/issues/27350 Test Plan: Imported from OSS Reviewed By: IvanKobzarev Differential Revision: D17800480 Pulled By: dreiss fbshipit-source-id: bf174e8b16231b89be771de0fa54c41e864a3eb0
This commit is contained in:
committed by
Facebook Github Bot
parent
b66df47a11
commit
1ffa81d772
@ -24,7 +24,7 @@ public abstract class PytorchTestBase {
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input =
|
||||
IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
|
||||
IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
assertTrue(output.isNull());
|
||||
@ -34,12 +34,12 @@ public abstract class PytorchTestBase {
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.bool(value);
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
assertTrue(value == input.getBool());
|
||||
assertTrue(value == input.toBool());
|
||||
final IValue output = module.runMethod("eqBool", input);
|
||||
assertTrue(output.isBool());
|
||||
assertTrue(value == output.getBool());
|
||||
assertTrue(value == output.toBool());
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,12 +47,12 @@ public abstract class PytorchTestBase {
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.long64(value);
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
assertTrue(value == input.getLong());
|
||||
assertTrue(value == input.toLong());
|
||||
final IValue output = module.runMethod("eqInt", input);
|
||||
assertTrue(output.isLong());
|
||||
assertTrue(value == output.getLong());
|
||||
assertTrue(value == output.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,12 +74,12 @@ public abstract class PytorchTestBase {
|
||||
1,
|
||||
};
|
||||
for (double value : values) {
|
||||
final IValue input = IValue.double64(value);
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isDouble());
|
||||
assertTrue(value == input.getDouble());
|
||||
assertTrue(value == input.toDouble());
|
||||
final IValue output = module.runMethod("eqFloat", input);
|
||||
assertTrue(output.isDouble());
|
||||
assertTrue(value == output.getDouble());
|
||||
assertTrue(value == output.toDouble());
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,17 +91,17 @@ public abstract class PytorchTestBase {
|
||||
for (int i = 0; i < numElements; ++i) {
|
||||
inputTensorData[i] = i;
|
||||
}
|
||||
final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData);
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final IValue input = IValue.tensor(inputTensor);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.getTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
final IValue output = module.runMethod("eqTensor", input);
|
||||
assertTrue(output.isTensor());
|
||||
final Tensor outputTensor = output.getTensor();
|
||||
final Tensor outputTensor = output.toTensor();
|
||||
assertNotNull(outputTensor);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape);
|
||||
assertArrayEquals(inputTensorShape, outputTensor.shape());
|
||||
float[] outputData = outputTensor.getDataAsFloatArray();
|
||||
for (int i = 0; i < numElements; i++) {
|
||||
assertTrue(inputTensorData[i] == outputData[i]);
|
||||
@ -113,22 +113,22 @@ public abstract class PytorchTestBase {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.long64(0l));
|
||||
inputMap.put(1l, IValue.long64(-1l));
|
||||
inputMap.put(-1l, IValue.long64(1l));
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
|
||||
inputMap.put(0l, IValue.from(0l));
|
||||
inputMap.put(1l, IValue.from(-1l));
|
||||
inputMap.put(-1l, IValue.from(1l));
|
||||
|
||||
final IValue input = IValue.dictLongKey(inputMap);
|
||||
final IValue input = IValue.dictLongKeyFrom(inputMap);
|
||||
assertTrue(input.isDictLongKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
|
||||
assertTrue(output.isDictLongKey());
|
||||
|
||||
final Map<Long, IValue> outputMap = output.getDictLongKey();
|
||||
final Map<Long, IValue> outputMap = output.toDictLongKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@ -137,22 +137,22 @@ public abstract class PytorchTestBase {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.long64(0l));
|
||||
inputMap.put("long_1", IValue.long64(1l));
|
||||
inputMap.put("long_-1", IValue.long64(-1l));
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
|
||||
inputMap.put("long_0", IValue.from(0l));
|
||||
inputMap.put("long_1", IValue.from(1l));
|
||||
inputMap.put("long_-1", IValue.from(-1l));
|
||||
|
||||
final IValue input = IValue.dictStringKey(inputMap);
|
||||
final IValue input = IValue.dictStringKeyFrom(inputMap);
|
||||
assertTrue(input.isDictStringKey());
|
||||
|
||||
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
|
||||
assertTrue(output.isDictStringKey());
|
||||
|
||||
final Map<String, IValue> outputMap = output.getDictStringKey();
|
||||
final Map<String, IValue> outputMap = output.toDictStringKey();
|
||||
assertTrue(inputMap.size() == outputMap.size());
|
||||
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
|
||||
assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong());
|
||||
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@ -167,19 +167,19 @@ public abstract class PytorchTestBase {
|
||||
a[i] = i;
|
||||
sum += a[i];
|
||||
}
|
||||
final IValue input = IValue.longList(a);
|
||||
final IValue input = IValue.listFrom(a);
|
||||
assertTrue(input.isLongList());
|
||||
|
||||
final IValue output = module.runMethod("listIntSumReturnTuple", input);
|
||||
|
||||
assertTrue(output.isTuple());
|
||||
assertTrue(2 == output.getTuple().length);
|
||||
assertTrue(2 == output.toTuple().length);
|
||||
|
||||
IValue output0 = output.getTuple()[0];
|
||||
IValue output1 = output.getTuple()[1];
|
||||
IValue output0 = output.toTuple()[0];
|
||||
IValue output1 = output.toTuple()[1];
|
||||
|
||||
assertArrayEquals(a, output0.getLongList());
|
||||
assertTrue(sum == output1.getLong());
|
||||
assertArrayEquals(a, output0.toLongList());
|
||||
assertTrue(sum == output1.toLong());
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,16 +187,16 @@ public abstract class PytorchTestBase {
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool());
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l);
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@ -219,16 +219,16 @@ public abstract class PytorchTestBase {
|
||||
floats[i] = i / 1000.f;
|
||||
}
|
||||
|
||||
Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes);
|
||||
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8);
|
||||
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
|
||||
assertTrue(tensorBytes.dtype() == DType.INT8);
|
||||
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
|
||||
|
||||
Tensor tensorInts = Tensor.newInt32Tensor(shape, ints);
|
||||
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
|
||||
Tensor tensorInts = Tensor.fromBlob(ints, shape);
|
||||
assertTrue(tensorInts.dtype() == DType.INT32);
|
||||
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
|
||||
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
float[] floatsOut = tensorFloats.getDataAsFloatArray();
|
||||
assertTrue(floatsOut.length == numel);
|
||||
for (int i = 0; i < numel; i++) {
|
||||
@ -241,8 +241,8 @@ public abstract class PytorchTestBase {
|
||||
long[] shape = new long[] {1, 3, 224, 224};
|
||||
final int numel = (int) Tensor.numel(shape);
|
||||
float[] floats = new float[numel];
|
||||
Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats);
|
||||
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
|
||||
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
|
||||
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
|
||||
tensorFloats.getDataAsByteArray();
|
||||
}
|
||||
|
||||
@ -257,12 +257,12 @@ public abstract class PytorchTestBase {
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.string(value);
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.getString()));
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("eqStr", input);
|
||||
assertTrue(output.isString());
|
||||
assertTrue(value.equals(output.getString()));
|
||||
assertTrue(value.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -276,13 +276,13 @@ public abstract class PytorchTestBase {
|
||||
"#@$!@#)($*!@#$)(!@*#$"
|
||||
};
|
||||
for (String value : values) {
|
||||
final IValue input = IValue.string(value);
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isString());
|
||||
assertTrue(value.equals(input.getString()));
|
||||
assertTrue(value.equals(input.toStr()));
|
||||
final IValue output = module.runMethod("str3Concat", input);
|
||||
assertTrue(output.isString());
|
||||
String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString();
|
||||
assertTrue(expectedOutput.equals(output.getString()));
|
||||
assertTrue(expectedOutput.equals(output.toStr()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,8 @@
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
// NOTE: Codes must be kept in sync with DType.java.
|
||||
// NOTE: Never serialize these, because they can change between releases.
|
||||
constexpr static int kTensorDTypeUInt8 = 1;
|
||||
constexpr static int kTensorDTypeInt8 = 2;
|
||||
constexpr static int kTensorDTypeInt32 = 3;
|
||||
@ -164,7 +166,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
|
||||
static at::Tensor newAtTensorFromJTensor(
|
||||
facebook::jni::alias_ref<JTensor> jtensor) {
|
||||
static const auto dtypeMethod =
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
|
||||
JTensor::javaClassStatic()->getMethod<jint()>("dtypeJniCode");
|
||||
jint jdtype = dtypeMethod(jtensor);
|
||||
|
||||
static const auto shapeField =
|
||||
@ -216,7 +218,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::local_ref<JTensor>)>("tensor");
|
||||
facebook::jni::local_ref<JTensor>)>("from");
|
||||
return jMethodTensor(
|
||||
JIValue::javaClassStatic(),
|
||||
JTensor::newJTensorFromAtTensor(ivalue.toTensor()));
|
||||
@ -224,26 +226,26 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBool =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
|
||||
"bool");
|
||||
"from");
|
||||
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
|
||||
} else if (ivalue.isInt()) {
|
||||
static auto jMethodInt =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>(
|
||||
"long64");
|
||||
"from");
|
||||
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
|
||||
} else if (ivalue.isDouble()) {
|
||||
static auto jMethodDouble =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
|
||||
"double64");
|
||||
"from");
|
||||
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
|
||||
} else if (ivalue.isString()) {
|
||||
static auto jMethodString =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>)>("string");
|
||||
facebook::jni::JString::javaobject>)>("from");
|
||||
return jMethodString(
|
||||
JIValue::javaClassStatic(),
|
||||
facebook::jni::make_jstring(ivalue.toStringRef()));
|
||||
@ -253,7 +255,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("tuple");
|
||||
JIValue::javaobject>::javaobject>)>("tupleFrom");
|
||||
auto jElementsArray =
|
||||
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
elementsVec.size());
|
||||
@ -267,7 +269,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodBoolListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("boolList");
|
||||
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_boolean_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -281,7 +283,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethodLongListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jlongArray>)>("longList");
|
||||
facebook::jni::alias_ref<jlongArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_long_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -295,7 +297,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static auto jMethoDoubleListArr =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("doubleList");
|
||||
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
|
||||
size_t n = list.size();
|
||||
auto jArray = facebook::jni::make_double_array(n);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
@ -310,7 +312,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JTensor::javaobject>::javaobject>)>("tensorList");
|
||||
JTensor::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray = facebook::jni::JArrayClass<JTensor::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
@ -324,7 +326,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject>)>("list");
|
||||
JIValue::javaobject>::javaobject>)>("listFrom");
|
||||
auto jArray = facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
|
||||
list.size());
|
||||
auto index = 0;
|
||||
@ -351,7 +353,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictStringKey");
|
||||
"dictStringKeyFrom");
|
||||
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
@ -370,7 +372,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictLongKey");
|
||||
"dictLongKeyFrom");
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
@ -404,32 +406,32 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetTensor =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::alias_ref<JTensor::javaobject>()>(
|
||||
"getTensor");
|
||||
"toTensor");
|
||||
return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
|
||||
} else if (JIValue::kTypeCodeBool == typeCode) {
|
||||
static const auto jMethodGetBool =
|
||||
JIValue::javaClassStatic()->getMethod<jboolean()>("getBool");
|
||||
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
|
||||
// explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
|
||||
// for int will be called for jboolean
|
||||
bool b = jMethodGetBool(jivalue);
|
||||
return at::IValue{b};
|
||||
} else if (JIValue::kTypeCodeLong == typeCode) {
|
||||
static const auto jMethodGetLong =
|
||||
JIValue::javaClassStatic()->getMethod<jlong()>("getLong");
|
||||
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
|
||||
return at::IValue{jMethodGetLong(jivalue)};
|
||||
} else if (JIValue::kTypeCodeDouble == typeCode) {
|
||||
static const auto jMethodGetDouble =
|
||||
JIValue::javaClassStatic()->getMethod<jdouble()>("getDouble");
|
||||
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
|
||||
return at::IValue{jMethodGetDouble(jivalue)};
|
||||
} else if (JIValue::kTypeCodeString == typeCode) {
|
||||
static const auto jMethodGetString =
|
||||
JIValue::javaClassStatic()->getMethod<jstring()>("getString");
|
||||
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
|
||||
return at::IValue{jMethodGetString(jivalue)->toStdString()};
|
||||
} else if (JIValue::kTypeCodeTuple == typeCode) {
|
||||
static const auto jMethodGetTuple =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject()>("getTuple");
|
||||
JIValue::javaobject>::javaobject()>("toTuple");
|
||||
auto jarray = jMethodGetTuple(jivalue);
|
||||
size_t n = jarray->size();
|
||||
|
||||
@ -443,7 +445,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
return c10::ivalue::Tuple::create(std::move(elements));
|
||||
} else if (JIValue::kTypeCodeBoolList == typeCode) {
|
||||
static const auto jMethodGetBoolList =
|
||||
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("getBoolList");
|
||||
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
|
||||
auto jArray = jMethodGetBoolList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -455,7 +457,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
return at::IValue{std::move(list)};
|
||||
} else if (JIValue::kTypeCodeLongList == typeCode) {
|
||||
static const auto jMethodGetLongList =
|
||||
JIValue::javaClassStatic()->getMethod<jlongArray()>("getLongList");
|
||||
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
|
||||
auto jArray = jMethodGetLongList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -468,7 +470,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
|
||||
static const auto jMethodGetDoubleList =
|
||||
JIValue::javaClassStatic()->getMethod<jdoubleArray()>(
|
||||
"getDoubleList");
|
||||
"toDoubleList");
|
||||
auto jArray = jMethodGetDoubleList(jivalue);
|
||||
auto jArrayPinned = jArray->pin();
|
||||
size_t n = jArrayPinned.size();
|
||||
@ -482,7 +484,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetTensorList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JTensor::javaobject>::javaobject()>("getTensorList");
|
||||
JTensor::javaobject>::javaobject()>("toTensorList");
|
||||
auto jArray = jMethodGetTensorList(jivalue);
|
||||
size_t n = jArray->size();
|
||||
c10::List<at::Tensor> list{};
|
||||
@ -495,7 +497,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetList =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JArrayClass<
|
||||
JIValue::javaobject>::javaobject()>("getList");
|
||||
JIValue::javaobject>::javaobject()>("toList");
|
||||
auto jarray = jMethodGetList(jivalue);
|
||||
size_t n = jarray->size();
|
||||
if (n == 0) {
|
||||
@ -518,7 +520,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
static const auto jMethodGetDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
|
||||
javaobject()>("getDictStringKey");
|
||||
javaobject()>("toDictStringKey");
|
||||
auto jmap = jMethodGetDictStringKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
@ -541,7 +543,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
JIValue::javaClassStatic()
|
||||
->getMethod<facebook::jni::JMap<
|
||||
facebook::jni::JLong::javaobject,
|
||||
JIValue::javaobject>::javaobject()>("getDictLongKey");
|
||||
JIValue::javaobject>::javaobject()>("toDictLongKey");
|
||||
auto jmap = jMethodGetDictLongKey(jivalue);
|
||||
auto it = jmap->begin();
|
||||
if (it == jmap->end()) {
|
||||
|
29
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
29
android/pytorch_android/src/main/java/org/pytorch/DType.java
Normal file
@ -0,0 +1,29 @@
|
||||
package org.pytorch;
|
||||
|
||||
/**
|
||||
* Codes representing tensor data types.
|
||||
*/
|
||||
public enum DType {
|
||||
// NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp.
|
||||
// NOTE: Never serialize "jniCode", because it can change between releases.
|
||||
|
||||
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
|
||||
UINT8(1),
|
||||
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
|
||||
INT8(2),
|
||||
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
|
||||
INT32(3),
|
||||
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
|
||||
FLOAT32(4),
|
||||
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
|
||||
INT64(5),
|
||||
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
|
||||
FLOAT64(6),
|
||||
;
|
||||
|
||||
final int jniCode;
|
||||
|
||||
DType(int jniCode) {
|
||||
this.jniCode = jniCode;
|
||||
}
|
||||
}
|
@ -98,7 +98,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Tensor type.
|
||||
*/
|
||||
public static IValue tensor(Tensor tensor) {
|
||||
public static IValue from(Tensor tensor) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR);
|
||||
iv.mData = tensor;
|
||||
return iv;
|
||||
@ -107,7 +107,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript bool type.
|
||||
*/
|
||||
public static IValue bool(boolean value) {
|
||||
public static IValue from(boolean value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
@ -116,7 +116,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript int type.
|
||||
*/
|
||||
public static IValue long64(long value) {
|
||||
public static IValue from(long value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
@ -125,7 +125,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript float type.
|
||||
*/
|
||||
public static IValue double64(double value) {
|
||||
public static IValue from(double value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
@ -134,7 +134,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates new IValue instance of torchscript str type.
|
||||
*/
|
||||
public static IValue string(String value) {
|
||||
public static IValue from(String value) {
|
||||
final IValue iv = new IValue(TYPE_CODE_STRING);
|
||||
iv.mData = value;
|
||||
return iv;
|
||||
@ -143,7 +143,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[bool] type.
|
||||
*/
|
||||
public static IValue boolList(boolean... list) {
|
||||
public static IValue listFrom(boolean... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_BOOL_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
@ -152,7 +152,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[int] type.
|
||||
*/
|
||||
public static IValue longList(long... list) {
|
||||
public static IValue listFrom(long... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_LONG_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
@ -161,7 +161,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[float] type.
|
||||
*/
|
||||
public static IValue doubleList(double... list) {
|
||||
public static IValue listFrom(double... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
@ -170,7 +170,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[Tensor] type.
|
||||
*/
|
||||
public static IValue tensorList(Tensor... list) {
|
||||
public static IValue listFrom(Tensor... list) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST);
|
||||
iv.mData = list;
|
||||
return iv;
|
||||
@ -179,7 +179,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript List[T] type. All elements must have the same type.
|
||||
*/
|
||||
public static IValue list(IValue... array) {
|
||||
public static IValue listFrom(IValue... array) {
|
||||
final int size = array.length;
|
||||
if (size > 0) {
|
||||
final int typeCode0 = array[0].mTypeCode;
|
||||
@ -198,7 +198,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type.
|
||||
*/
|
||||
public static IValue tuple(IValue... array) {
|
||||
public static IValue tupleFrom(IValue... array) {
|
||||
final IValue iv = new IValue(TYPE_CODE_TUPLE);
|
||||
iv.mData = array;
|
||||
return iv;
|
||||
@ -207,7 +207,7 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance oftorchscript Dict[Str, V] type.
|
||||
*/
|
||||
public static IValue dictStringKey(Map<String, IValue> map) {
|
||||
public static IValue dictStringKeyFrom(Map<String, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
@ -216,73 +216,73 @@ public class IValue {
|
||||
/**
|
||||
* Creates a new IValue instance of torchscript Dict[int, V] type.
|
||||
*/
|
||||
public static IValue dictLongKey(Map<Long, IValue> map) {
|
||||
public static IValue dictLongKeyFrom(Map<Long, IValue> map) {
|
||||
final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY);
|
||||
iv.mData = map;
|
||||
return iv;
|
||||
}
|
||||
|
||||
public Tensor getTensor() {
|
||||
public Tensor toTensor() {
|
||||
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
|
||||
return (Tensor) mData;
|
||||
}
|
||||
|
||||
public boolean getBool() {
|
||||
public boolean toBool() {
|
||||
preconditionType(TYPE_CODE_BOOL, mTypeCode);
|
||||
return (boolean) mData;
|
||||
}
|
||||
|
||||
public long getLong() {
|
||||
public long toLong() {
|
||||
preconditionType(TYPE_CODE_LONG, mTypeCode);
|
||||
return (long) mData;
|
||||
}
|
||||
|
||||
public double getDouble() {
|
||||
public double toDouble() {
|
||||
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
|
||||
return (double) mData;
|
||||
}
|
||||
|
||||
public String getString() {
|
||||
public String toStr() {
|
||||
preconditionType(TYPE_CODE_STRING, mTypeCode);
|
||||
return (String) mData;
|
||||
}
|
||||
|
||||
public boolean[] getBoolList() {
|
||||
public boolean[] toBoolList() {
|
||||
preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode);
|
||||
return (boolean[]) mData;
|
||||
}
|
||||
|
||||
public long[] getLongList() {
|
||||
public long[] toLongList() {
|
||||
preconditionType(TYPE_CODE_LONG_LIST, mTypeCode);
|
||||
return (long[]) mData;
|
||||
}
|
||||
|
||||
public double[] getDoubleList() {
|
||||
public double[] toDoubleList() {
|
||||
preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode);
|
||||
return (double[]) mData;
|
||||
}
|
||||
|
||||
public Tensor[] getTensorList() {
|
||||
public Tensor[] toTensorList() {
|
||||
preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode);
|
||||
return (Tensor[]) mData;
|
||||
}
|
||||
|
||||
public IValue[] getList() {
|
||||
public IValue[] toList() {
|
||||
preconditionType(TYPE_CODE_LIST, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
public IValue[] getTuple() {
|
||||
public IValue[] toTuple() {
|
||||
preconditionType(TYPE_CODE_TUPLE, mTypeCode);
|
||||
return (IValue[]) mData;
|
||||
}
|
||||
|
||||
public Map<String, IValue> getDictStringKey() {
|
||||
public Map<String, IValue> toDictStringKey() {
|
||||
preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode);
|
||||
return (Map<String, IValue>) mData;
|
||||
}
|
||||
|
||||
public Map<Long, IValue> getDictLongKey() {
|
||||
public Map<Long, IValue> toDictLongKey() {
|
||||
preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode);
|
||||
return (Map<Long, IValue>) mData;
|
||||
}
|
||||
|
@ -15,20 +15,6 @@ import java.util.Locale;
|
||||
* {@link java.nio.DirectByteBuffer} of one of the supported types.
|
||||
*/
|
||||
public abstract class Tensor {
|
||||
|
||||
/** Code for dtype torch.uint8. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_UINT8 = 1;
|
||||
/** Code for dtype torch.int8. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT8 = 2;
|
||||
/** Code for dtype torch.int32. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT32 = 3;
|
||||
/** Code for dtype torch.float32. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_FLOAT32 = 4;
|
||||
/** Code for dtype torch.int64. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_INT64 = 5;
|
||||
/** Code for dtype torch.float64. {@link Tensor#dtype()} */
|
||||
public static final int DTYPE_FLOAT64 = 6;
|
||||
|
||||
private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
|
||||
private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
|
||||
private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
|
||||
@ -39,8 +25,7 @@ public abstract class Tensor {
|
||||
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
|
||||
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
|
||||
|
||||
/** Shape of current tensor. */
|
||||
public final long[] shape;
|
||||
final long[] shape;
|
||||
|
||||
private static final int INT_SIZE_BYTES = 4;
|
||||
private static final int FLOAT_SIZE_BYTES = 4;
|
||||
@ -49,8 +34,8 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link
|
||||
* Tensor#newUInt8Tensor(long[], ByteBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])},
|
||||
* {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -66,7 +51,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -78,7 +63,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -90,7 +75,7 @@ public abstract class Tensor {
|
||||
|
||||
/**
|
||||
* Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified
|
||||
* capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}.
|
||||
* capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
|
||||
*
|
||||
* @param numElements capacity (number of elements) of result buffer.
|
||||
*/
|
||||
@ -104,10 +89,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newUInt8Tensor(long[] shape, byte[] data) {
|
||||
public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -121,10 +106,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
|
||||
* bytes.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt8Tensor(long[] shape, byte[] data) {
|
||||
public static Tensor fromBlob(byte[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -138,10 +123,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
|
||||
* ints.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt32Tensor(long[] shape, int[] data) {
|
||||
public static Tensor fromBlob(int[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -155,10 +140,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
|
||||
* of floats.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat32Tensor(long[] shape, float[] data) {
|
||||
public static Tensor fromBlob(float[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -172,10 +157,10 @@ public abstract class Tensor {
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
|
||||
* longs.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt64Tensor(long[] shape, long[] data) {
|
||||
public static Tensor fromBlob(long[] data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -192,7 +177,7 @@ public abstract class Tensor {
|
||||
* @param shape Tensor shape
|
||||
* @param data Tensor elements
|
||||
*/
|
||||
public static Tensor newFloat64Tensor(long[] shape, double[] data) {
|
||||
public static Tensor fromBlob(long[] shape, double[] data) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -205,12 +190,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -225,12 +210,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) {
|
||||
public static Tensor fromBlob(ByteBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -245,12 +230,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt32Tensor(long[] shape, IntBuffer data) {
|
||||
public static Tensor fromBlob(IntBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -265,12 +250,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) {
|
||||
public static Tensor fromBlob(FloatBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -285,12 +270,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newInt64Tensor(long[] shape, LongBuffer data) {
|
||||
public static Tensor fromBlob(LongBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -305,12 +290,12 @@ public abstract class Tensor {
|
||||
/**
|
||||
* Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
|
||||
*
|
||||
* @param shape Tensor shape
|
||||
* @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)}
|
||||
* elements. The buffer is used directly without copying, and changes to its content will
|
||||
* change the tensor.
|
||||
* @param shape Tensor shape
|
||||
*/
|
||||
public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) {
|
||||
public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
|
||||
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
|
||||
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
|
||||
checkShape(shape);
|
||||
@ -342,12 +327,19 @@ public abstract class Tensor {
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Shape of current tensor. */
|
||||
public long[] shape() {
|
||||
return Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link
|
||||
* Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link
|
||||
* Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}.
|
||||
* Returns dtype of current tensor.
|
||||
*/
|
||||
public abstract int dtype();
|
||||
public abstract DType dtype();
|
||||
|
||||
int dtypeJniCode() {
|
||||
return dtype().jniCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns newly allocated java byte array that contains a copy of tensor data.
|
||||
@ -423,8 +415,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_UINT8;
|
||||
public DType dtype() {
|
||||
return DType.UINT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -455,8 +447,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT8;
|
||||
public DType dtype() {
|
||||
return DType.INT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -487,8 +479,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT32;
|
||||
public DType dtype() {
|
||||
return DType.INT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -527,8 +519,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_FLOAT32;
|
||||
public DType dtype() {
|
||||
return DType.FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -551,8 +543,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_INT64;
|
||||
public DType dtype() {
|
||||
return DType.INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -583,8 +575,8 @@ public abstract class Tensor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dtype() {
|
||||
return DTYPE_FLOAT64;
|
||||
public DType dtype() {
|
||||
return DType.FLOAT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -634,17 +626,17 @@ public abstract class Tensor {
|
||||
|
||||
// Called from native
|
||||
private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) {
|
||||
if (DTYPE_FLOAT32 == dtype) {
|
||||
if (DType.FLOAT32.jniCode == dtype) {
|
||||
return new Tensor_float32(data.asFloatBuffer(), shape);
|
||||
} else if (DTYPE_INT32 == dtype) {
|
||||
} else if (DType.INT32.jniCode == dtype) {
|
||||
return new Tensor_int32(data.asIntBuffer(), shape);
|
||||
} else if (DTYPE_INT64 == dtype) {
|
||||
} else if (DType.INT64.jniCode == dtype) {
|
||||
return new Tensor_int64(data.asLongBuffer(), shape);
|
||||
} else if (DTYPE_FLOAT64 == dtype) {
|
||||
} else if (DType.FLOAT64.jniCode == dtype) {
|
||||
return new Tensor_float64(data.asDoubleBuffer(), shape);
|
||||
} else if (DTYPE_UINT8 == dtype) {
|
||||
} else if (DType.UINT8.jniCode == dtype) {
|
||||
return new Tensor_uint8(data, shape);
|
||||
} else if (DTYPE_INT8 == dtype) {
|
||||
} else if (DType.INT8.jniCode == dtype) {
|
||||
return new Tensor_int8(data, shape);
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown Tensor dtype");
|
||||
|
@ -106,7 +106,7 @@ public final class TensorImageUtils {
|
||||
|
||||
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
|
||||
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatBuffer);
|
||||
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width});
|
||||
}
|
||||
|
||||
/**
|
||||
@ -146,7 +146,7 @@ public final class TensorImageUtils {
|
||||
tensorWidth,
|
||||
tensorHeight,
|
||||
normMeanRGB, normStdRGB, floatBuffer, 0);
|
||||
return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatBuffer);
|
||||
return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth});
|
||||
}
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user