Use == operator to test type equivalance in pytorch_jni_common.cpp (#71508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71508

"==" is the more universal way to test type equalities, and also ::get() doesn't incur any refcount overhead now, so we can swtich to == instead of relying on type kinds.
ghstack-source-id: 147353057

Test Plan:
CI
buck test xplat/caffe2/android:pytorch_jni_common_test

Differential Revision: D33672433

fbshipit-source-id: 5973fd40de48b8017b5c3ebaa55bcf5b4b391aa3
(cherry picked from commit db84a4b566d1f2f17cda8785e11bc11739e6f50c)
This commit is contained in:
Zhengxu Chen
2022-01-20 15:40:38 -08:00
committed by PyTorch MergeBot
parent 0df607ce00
commit c92ff47afd
3 changed files with 81 additions and 41 deletions

View File

@ -0,0 +1,18 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#include <gtest/gtest.h>
#include <ATen/core/type_factory.h>
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
using namespace ::testing;
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
auto dict = c10::impl::GenericDict(
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
auto dictCallback = [](auto&&) {
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
};
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
dict, dictCallback, dictCallback));
}

View File

@ -287,8 +287,51 @@ class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
at::Tensor tensor_;
};
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictStringKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::make_jstring(pair.key().toString()->string()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictLongKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::JLong::valueOf(pair.key().toInt()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const at::IValue& ivalue) {
const at::IValue& ivalue,
DictCallback stringDictCallback,
DictCallback intDictCallback) {
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
if (ivalue.isNone()) {
static auto jMethodOptionalNull =
@ -427,49 +470,16 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
"Unknown IValue-Dict key type");
}
const auto keyTypeKind = keyType->kind();
if (c10::TypeKind::StringType == keyTypeKind) {
static auto jMethodDictStringKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<
facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::make_jstring(pair.key().toString()->string()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
} else if (c10::TypeKind::IntType == keyTypeKind) {
static auto jMethodDictLongKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<
facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::JLong::valueOf(pair.key().toInt()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
if (*keyType == *c10::StringType::get()) {
return stringDictCallback(std::move(dict));
} else if (*keyType == *c10::IntType::get()) {
return intDictCallback(std::move(dict));
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue-Dict key type");
"Unsupported IValue-Dict key type: %s",
keyType->str().c_str());
}
facebook::jni::throwNewJavaException(

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/util/FunctionRef.h>
#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>
#include "caffe2/serialize/read_adapter_interface.h"
@ -93,6 +94,9 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
};
class JIValue : public facebook::jni::JavaClass<JIValue> {
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
c10::Dict<c10::IValue, c10::IValue>)>;
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
@ -115,10 +119,18 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
constexpr static int kTypeCodeDictLongKey = 14;
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
const at::IValue& ivalue);
const at::IValue& ivalue,
DictCallback stringDictCallback = newJIValueFromStringDict,
DictCallback intDictCallback = newJIValueFromIntDict);
static at::IValue JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue);
private:
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue>);
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue>);
};
void common_registerNatives();