mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use == operator to test type equivalance in pytorch_jni_common.cpp (#71508)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71508 "==" is the more universal way to test type equalities, and also ::get() doesn't incur any refcount overhead now, so we can swtich to == instead of relying on type kinds. ghstack-source-id: 147353057 Test Plan: CI buck test xplat/caffe2/android:pytorch_jni_common_test Differential Revision: D33672433 fbshipit-source-id: 5973fd40de48b8017b5c3ebaa55bcf5b4b391aa3 (cherry picked from commit db84a4b566d1f2f17cda8785e11bc11739e6f50c)
This commit is contained in:
committed by
PyTorch MergeBot
parent
0df607ce00
commit
c92ff47afd
@ -0,0 +1,18 @@
|
||||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
|
||||
|
||||
using namespace ::testing;
|
||||
|
||||
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
|
||||
auto dict = c10::impl::GenericDict(
|
||||
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
|
||||
auto dictCallback = [](auto&&) {
|
||||
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
|
||||
};
|
||||
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
|
||||
dict, dictCallback, dictCallback));
|
||||
}
|
@ -287,8 +287,51 @@ class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
|
||||
at::Tensor tensor_;
|
||||
};
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
|
||||
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||
static auto jMethodDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictStringKeyFrom");
|
||||
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::make_jstring(pair.key().toString()->string()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
|
||||
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||
static auto jMethodDictLongKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictLongKeyFrom");
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::JLong::valueOf(pair.key().toInt()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
|
||||
}
|
||||
|
||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||
const at::IValue& ivalue) {
|
||||
const at::IValue& ivalue,
|
||||
DictCallback stringDictCallback,
|
||||
DictCallback intDictCallback) {
|
||||
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
|
||||
if (ivalue.isNone()) {
|
||||
static auto jMethodOptionalNull =
|
||||
@ -427,49 +470,16 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||
"Unknown IValue-Dict key type");
|
||||
}
|
||||
|
||||
const auto keyTypeKind = keyType->kind();
|
||||
if (c10::TypeKind::StringType == keyTypeKind) {
|
||||
static auto jMethodDictStringKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictStringKeyFrom");
|
||||
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::make_jstring(pair.key().toString()->string()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
|
||||
} else if (c10::TypeKind::IntType == keyTypeKind) {
|
||||
static auto jMethodDictLongKey =
|
||||
JIValue::javaClassStatic()
|
||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||
facebook::jni::alias_ref<
|
||||
facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||
"dictLongKeyFrom");
|
||||
auto jmap = JHashMap<
|
||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||
for (auto& pair : dict) {
|
||||
jmap->put(
|
||||
facebook::jni::JLong::valueOf(pair.key().toInt()),
|
||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||
}
|
||||
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
|
||||
if (*keyType == *c10::StringType::get()) {
|
||||
return stringDictCallback(std::move(dict));
|
||||
} else if (*keyType == *c10::IntType::get()) {
|
||||
return intDictCallback(std::move(dict));
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
facebook::jni::gJavaLangIllegalArgumentException,
|
||||
"Unsupported IValue-Dict key type");
|
||||
"Unsupported IValue-Dict key type: %s",
|
||||
keyType->str().c_str());
|
||||
}
|
||||
|
||||
facebook::jni::throwNewJavaException(
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/FunctionRef.h>
|
||||
#include <fbjni/fbjni.h>
|
||||
#include <torch/csrc/api/include/torch/types.h>
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
@ -93,6 +94,9 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
||||
};
|
||||
|
||||
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
|
||||
c10::Dict<c10::IValue, c10::IValue>)>;
|
||||
|
||||
public:
|
||||
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
||||
|
||||
@ -115,10 +119,18 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||
constexpr static int kTypeCodeDictLongKey = 14;
|
||||
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
|
||||
const at::IValue& ivalue);
|
||||
const at::IValue& ivalue,
|
||||
DictCallback stringDictCallback = newJIValueFromStringDict,
|
||||
DictCallback intDictCallback = newJIValueFromIntDict);
|
||||
|
||||
static at::IValue JIValueToAtIValue(
|
||||
facebook::jni::alias_ref<JIValue> jivalue);
|
||||
|
||||
private:
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
|
||||
c10::Dict<c10::IValue, c10::IValue>);
|
||||
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
|
||||
c10::Dict<c10::IValue, c10::IValue>);
|
||||
};
|
||||
|
||||
void common_registerNatives();
|
||||
|
Reference in New Issue
Block a user