mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Dict is a reference type (#20669)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20669 Before, Dict was a value type, i.e. copying it did a deep copy. Unfortunately, this doesn't work well with storing and passing Dicts around in IValues because IValues are reference types. This diff changes Dict to be a reference type. Reviewed By: dzhulgakov Differential Revision: D15404911 fbshipit-source-id: dc990d3eb7cae044b74dd0253f8b704dde6a6c86
This commit is contained in:
committed by
Facebook Github Bot
parent
93d5503f34
commit
d5b7138a2c
@ -2,28 +2,43 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/TypeTraits.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
namespace c10 {
|
||||
struct IValue;
|
||||
template<class Key, class Value> class Dict;
|
||||
template<class Key, class Value> class DictPtr;
|
||||
|
||||
/**
|
||||
* Creates an empty dict.
|
||||
*/
|
||||
template<class Key, class Value>
|
||||
DictPtr<Key, Value> make_dict();
|
||||
|
||||
namespace impl {
|
||||
bool shallowEquals(const IValue& lhs, const IValue& rhs);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct DictHash {
|
||||
struct DictKeyHash {
|
||||
size_t operator()(const IValue& ivalue) const;
|
||||
};
|
||||
|
||||
struct DictEqualTo {
|
||||
struct DictKeyEqualTo {
|
||||
bool operator()(const IValue& lhs, const IValue& rhs) const {
|
||||
return impl::shallowEquals(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
using dict_map_type = ska::flat_hash_map<IValue, IValue, DictHash, DictEqualTo>;
|
||||
struct DictImpl final : public c10::intrusive_ptr_target {
|
||||
using dict_map_type = ska::flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
|
||||
dict_map_type dict;
|
||||
|
||||
intrusive_ptr<DictImpl> copy() const;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
@ -62,7 +77,7 @@ public:
|
||||
private:
|
||||
Iterator iterator_;
|
||||
friend class DictIterator<Key, Value, Iterator>;
|
||||
friend class Dict<Key, Value>;
|
||||
friend class DictPtr<Key, Value>;
|
||||
friend bool operator==<Key, Value, Iterator>(const DictIterator<Key, Value, Iterator>& lhs, const DictIterator<Key, Value, Iterator>& rhs);
|
||||
};
|
||||
|
||||
@ -100,7 +115,7 @@ public:
|
||||
|
||||
// the template automatically disables the operator when we are already a
|
||||
// const_iterator, because that would cause a lot of compiler warnings otherwise.
|
||||
template<class const_iterator_ = typename detail::dict_map_type::const_iterator, class = guts::enable_if_t<!std::is_same<const_iterator_, Iterator>::value>>
|
||||
template<class const_iterator_ = typename detail::DictImpl::dict_map_type::const_iterator, class = guts::enable_if_t<!std::is_same<const_iterator_, Iterator>::value>>
|
||||
/* implicit */ operator DictIterator<Key, Value, const_iterator_>() const
|
||||
{
|
||||
return DictIterator<Key, Value, const_iterator_> { const_iterator_ { entryRef_.iterator_ } };
|
||||
@ -111,8 +126,8 @@ private:
|
||||
|
||||
DictEntryRef<Key, Value, Iterator> entryRef_;
|
||||
|
||||
friend class DictIterator<Key, Value, typename detail::dict_map_type::iterator>;
|
||||
friend class Dict<Key, Value>;
|
||||
friend class DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
|
||||
friend class DictPtr<Key, Value>;
|
||||
friend bool operator==<Key, Value, Iterator>(const DictIterator& lhs, const DictIterator& rhs);
|
||||
};
|
||||
|
||||
@ -126,56 +141,72 @@ inline bool operator!=(const DictIterator<Key, Value, Iterator>& lhs, const Dict
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue>&& dict);
|
||||
template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value>&& dict);
|
||||
template<class Key, class Value> DictPtr<Key, Value> toTypedDict(DictPtr<IValue, IValue> dict);
|
||||
template<class Key, class Value> DictPtr<IValue, IValue> toGenericDict(DictPtr<Key, Value> dict);
|
||||
}
|
||||
|
||||
/**
|
||||
* An object of this class stores a map from Key to Value.
|
||||
* You can create instances using the make_dict<Key, Value>() function.
|
||||
*
|
||||
* We use this class in the PyTorch kernel API instead of
|
||||
* std::unordered_map<Key, Value>, because that allows us
|
||||
* to do optimizations and switch out the underlying map
|
||||
* implementation without breaking backwards compatibility
|
||||
* This is a pointer type. After a copy, both DictPtrs
|
||||
* will share the same storage:
|
||||
*
|
||||
* > DictPtr<int, string> a = make_dict<int, string>();
|
||||
* > DictPtr<int, string> b = a;
|
||||
* > b.insert(3, "three");
|
||||
* > ASSERT("three" == a.at(3));
|
||||
*
|
||||
* We use this class in the PyTorch kernel API because that
|
||||
* allows us to do optimizations and switch out the underlying
|
||||
* map implementation without breaking backwards compatibility
|
||||
* for the kernel API.
|
||||
*
|
||||
* The API of this class is borrowed from std::unordered_map,
|
||||
* but with slight differences and it intentionally does not
|
||||
* support the full std::unordered_map API, because a more
|
||||
* narrow abstraction gives us more freedom to change the internals.
|
||||
*/
|
||||
template<class Key, class Value>
|
||||
class Dict final {
|
||||
class DictPtr final {
|
||||
private:
|
||||
// map_ stores the underlying map as a ska::flat_hash_map.
|
||||
// impl_ stores the underlying map as a ska::flat_hash_map.
|
||||
// We intentionally don't offer conversion from/to
|
||||
// ska::flat_hash_map, return references to it or something like that,
|
||||
// because such operations would get expensive if we switch out
|
||||
// the actual map implementation.
|
||||
detail::dict_map_type map_;
|
||||
// This is an intrusive_ptr because DictPtr is a pointer type.
|
||||
// Invariant: This will never be a nullptr, there will always be a valid
|
||||
// DictImpl.
|
||||
c10::intrusive_ptr<detail::DictImpl> impl_;
|
||||
|
||||
explicit Dict(detail::dict_map_type&& map): map_(std::move(map)) {}
|
||||
template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>&&);
|
||||
template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>&&);
|
||||
explicit DictPtr(c10::intrusive_ptr<detail::DictImpl>&& impl);
|
||||
template<class K, class V> friend DictPtr<K, V> impl::toTypedDict(DictPtr<IValue, IValue>);
|
||||
template<class K, class V> friend DictPtr<IValue, IValue> impl::toGenericDict(DictPtr<K, V>);
|
||||
|
||||
public:
|
||||
using key_type = Key;
|
||||
using mapped_type = Value;
|
||||
using size_type = typename detail::dict_map_type::size_type;
|
||||
using iterator = impl::DictIterator<Key, Value, typename detail::dict_map_type::iterator>;
|
||||
using const_iterator = impl::DictIterator<Key, Value, typename detail::dict_map_type::const_iterator>;
|
||||
using size_type = typename detail::DictImpl::dict_map_type::size_type;
|
||||
using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
|
||||
using const_iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::const_iterator>;
|
||||
|
||||
/**
|
||||
* Creates an empty dict.
|
||||
*/
|
||||
explicit Dict() = default;
|
||||
friend DictPtr make_dict<Key, Value>();
|
||||
|
||||
~Dict() = default;
|
||||
// please use make_dict instead
|
||||
DictPtr() = delete;
|
||||
|
||||
Dict(const Dict&) = default;
|
||||
Dict(Dict&&) noexcept = default;
|
||||
Dict& operator=(const Dict&) = default;
|
||||
Dict& operator=(Dict&&) noexcept = default;
|
||||
~DictPtr() = default;
|
||||
|
||||
DictPtr(const DictPtr&) = default;
|
||||
DictPtr& operator=(const DictPtr&) = default;
|
||||
DictPtr(DictPtr&&) noexcept;
|
||||
DictPtr& operator=(DictPtr&&) noexcept;
|
||||
|
||||
/**
|
||||
* Create a new DictPtr pointing to a deep copy of the same data.
|
||||
* The DictPtr returned is a new dict with separate storage.
|
||||
* Changes in it are not reflected in the original dict or vice versa.
|
||||
*/
|
||||
DictPtr copy() const;
|
||||
|
||||
/**
|
||||
* Returns an iterator to the first element of the container.
|
||||
@ -300,19 +331,23 @@ public:
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// GenericDict is how IValue stores dicts. It is, however, not part of the
|
||||
// GenericDictPtr is how IValue stores dicts. It is, however, not part of the
|
||||
// public API. Kernels should use Dicts with concrete Key, Value types instead
|
||||
// (maybe except for some internal prim ops).
|
||||
using GenericDict = Dict<IValue, IValue>;
|
||||
using GenericDictPtr = DictPtr<IValue, IValue>;
|
||||
|
||||
template<class Key, class Value>
|
||||
Dict<Key, Value> toTypedDict(GenericDict&& dict) {
|
||||
return Dict<Key, Value>(std::move(dict.map_));
|
||||
inline GenericDictPtr make_generic_dict() {
|
||||
return make_dict<IValue, IValue>();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
GenericDict toGenericDict(Dict<Key, Value>&& dict) {
|
||||
return GenericDict(std::move(dict.map_));
|
||||
DictPtr<Key, Value> toTypedDict(GenericDictPtr dict) {
|
||||
return DictPtr<Key, Value>(std::move(dict.impl_));
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
GenericDictPtr toGenericDict(DictPtr<Key, Value> dict) {
|
||||
return GenericDictPtr(std::move(dict.impl_));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ inline bool shallowEquals(const IValue& lhs, const IValue& rhs) {
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline size_t DictHash::operator()(const IValue& ivalue) const {
|
||||
inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
|
||||
if (ivalue.isInt()) {
|
||||
return std::hash<int>()(ivalue.toInt());
|
||||
} else if (ivalue.isString()) {
|
||||
@ -37,59 +37,90 @@ inline size_t DictHash::operator()(const IValue& ivalue) const {
|
||||
}
|
||||
}
|
||||
|
||||
inline intrusive_ptr<DictImpl> DictImpl::copy() const {
|
||||
auto result = make_intrusive<DictImpl>();
|
||||
result->dict = dict;
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() {
|
||||
return iterator{map_.begin()};
|
||||
DictPtr<Key, Value> make_dict() {
|
||||
return DictPtr<Key, Value>(make_intrusive<detail::DictImpl>());
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::begin() const {
|
||||
return const_iterator{map_.begin()};
|
||||
DictPtr<Key, Value>::DictPtr(DictPtr&& rhs) noexcept: impl_(std::move(rhs.impl_)) {
|
||||
rhs.impl_ = make_intrusive<detail::DictImpl>();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::cbegin() const {
|
||||
return const_iterator{map_.cbegin()};
|
||||
DictPtr<Key, Value>::DictPtr(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
|
||||
|
||||
template<class Key, class Value>
|
||||
DictPtr<Key, Value>& DictPtr<Key, Value>::operator=(DictPtr&& rhs) noexcept {
|
||||
impl_ = std::move(rhs.impl_);
|
||||
rhs.impl_ = make_intrusive<detail::DictImpl>();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::iterator Dict<Key, Value>::end() {
|
||||
return iterator{map_.end()};
|
||||
DictPtr<Key, Value> DictPtr<Key, Value>::copy() const {
|
||||
return DictPtr<Key, Value>(impl_->copy());
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::end() const {
|
||||
return const_iterator{map_.end()};
|
||||
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::begin() {
|
||||
return iterator{impl_->dict.begin()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::cend() const {
|
||||
return const_iterator{map_.cend()};
|
||||
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::begin() const {
|
||||
return const_iterator{impl_->dict.begin()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
bool Dict<Key, Value>::empty() const {
|
||||
return map_.empty();
|
||||
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::cbegin() const {
|
||||
return const_iterator{impl_->dict.cbegin()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
|
||||
return map_.size();
|
||||
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::end() {
|
||||
return iterator{impl_->dict.end()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
void Dict<Key, Value>::clear() {
|
||||
map_.clear();
|
||||
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::end() const {
|
||||
return const_iterator{impl_->dict.end()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::cend() const {
|
||||
return const_iterator{impl_->dict.cend()};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
bool DictPtr<Key, Value>::empty() const {
|
||||
return impl_->dict.empty();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename DictPtr<Key, Value>::size_type DictPtr<Key, Value>::size() const {
|
||||
return impl_->dict.size();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
void DictPtr<Key, Value>::clear() {
|
||||
impl_->dict.clear();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
template<class Key_, class Value_>
|
||||
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) {
|
||||
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
|
||||
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");
|
||||
auto inserted = map_.insert(std::pair<IValue, IValue>{
|
||||
std::pair<typename DictPtr<Key, Value>::iterator, bool> DictPtr<Key, Value>::insert(Key_&& key, Value_&& value) {
|
||||
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of DictPtr::insert");
|
||||
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of DictPtr::insert");
|
||||
auto inserted = impl_->dict.insert(std::pair<IValue, IValue>{
|
||||
Key(std::forward<Key_>(key)),
|
||||
Value(std::forward<Value_>(value))});
|
||||
return {iterator{inserted.first}, inserted.second};
|
||||
@ -97,48 +128,48 @@ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Ke
|
||||
|
||||
template<class Key, class Value>
|
||||
template<class Key_, class Value_>
|
||||
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) {
|
||||
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
|
||||
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");
|
||||
auto inserted = map_.insert_or_assign(
|
||||
std::pair<typename DictPtr<Key, Value>::iterator, bool> DictPtr<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) {
|
||||
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of DictPtr::insert_or_assign");
|
||||
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of DictPtr::insert_or_assign");
|
||||
auto inserted = impl_->dict.insert_or_assign(
|
||||
Key(std::forward<Key_>(key)),
|
||||
Value(std::forward<Value_>(value)));
|
||||
return {iterator{inserted.first}, inserted.second};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
void Dict<Key, Value>::erase(const_iterator iter) {
|
||||
map_.erase(iter.entryRef_.iterator_);
|
||||
void DictPtr<Key, Value>::erase(const_iterator iter) {
|
||||
impl_->dict.erase(iter.entryRef_.iterator_);
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) {
|
||||
return map_.erase(key);
|
||||
C10_NODISCARD size_t DictPtr<Key, Value>::erase(const Key& key) {
|
||||
return impl_->dict.erase(key);
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
Value Dict<Key, Value>::at(const Key& key) const {
|
||||
return map_.at(key).template to<Value>();
|
||||
Value DictPtr<Key, Value>::at(const Key& key) const {
|
||||
return impl_->dict.at(key).template to<Value>();
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) {
|
||||
return iterator{map_.find(key)};
|
||||
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::find(const Key& key) {
|
||||
return iterator{impl_->dict.find(key)};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::find(const Key& key) const {
|
||||
return const_iterator{map_.find(key)};
|
||||
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::find(const Key& key) const {
|
||||
return const_iterator{impl_->dict.find(key)};
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
bool Dict<Key, Value>::contains(const Key& key) const {
|
||||
bool DictPtr<Key, Value>::contains(const Key& key) const {
|
||||
return end() != find(key);
|
||||
}
|
||||
|
||||
template<class Key, class Value>
|
||||
void Dict<Key, Value>::reserve(size_type count) {
|
||||
map_.reserve(count);
|
||||
void DictPtr<Key, Value>::reserve(size_type count) {
|
||||
impl_->dict.reserve(count);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -4,33 +4,34 @@
|
||||
#include <string>
|
||||
|
||||
using std::string;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
|
||||
TEST(DictTest, givenEmptyDict_whenCallingEmpty_thenReturnsTrue) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
EXPECT_TRUE(dict.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenNonemptyDict_whenCallingEmpty_thenReturnsFalse) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "value");
|
||||
EXPECT_FALSE(dict.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenEmptyDict_whenCallingSize_thenReturnsZero) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
EXPECT_EQ(0, dict.size());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenNonemptyDict_whenCallingSize_thenReturnsNumberOfElements) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "value");
|
||||
dict.insert(4, "value2");
|
||||
EXPECT_EQ(2, dict.size());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenNonemptyDict_whenCallingClear_thenIsEmpty) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "value");
|
||||
dict.insert(4, "value2");
|
||||
dict.clear();
|
||||
@ -38,24 +39,24 @@ TEST(DictTest, givenNonemptyDict_whenCallingClear_thenIsEmpty) {
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertingNewKey_thenReturnsTrueAndIteratorToNewElement) {
|
||||
Dict<int, string> dict;
|
||||
std::pair<Dict<int, string>::iterator, bool> result = dict.insert(3, "value");
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert(3, "value");
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(3, result.first->key());
|
||||
EXPECT_EQ("value", result.first->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertingExistingKey_thenReturnsFalseAndIteratorToExistingElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "old_value");
|
||||
std::pair<Dict<int, string>::iterator, bool> result = dict.insert(3, "new_value");
|
||||
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert(3, "new_value");
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(3, result.first->key());
|
||||
EXPECT_EQ("old_value", result.first->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertingExistingKey_thenDoesNotModifyDict) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "old_value");
|
||||
dict.insert(3, "new_value");
|
||||
EXPECT_EQ(1, dict.size());
|
||||
@ -64,24 +65,24 @@ TEST(DictTest, whenInsertingExistingKey_thenDoesNotModifyDict) {
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertOrAssigningNewKey_thenReturnsTrueAndIteratorToNewElement) {
|
||||
Dict<int, string> dict;
|
||||
std::pair<Dict<int, string>::iterator, bool> result = dict.insert_or_assign(3, "value");
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert_or_assign(3, "value");
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(3, result.first->key());
|
||||
EXPECT_EQ("value", result.first->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertOrAssigningExistingKey_thenReturnsFalseAndIteratorToChangedElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "old_value");
|
||||
std::pair<Dict<int, string>::iterator, bool> result = dict.insert_or_assign(3, "new_value");
|
||||
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert_or_assign(3, "new_value");
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(3, result.first->key());
|
||||
EXPECT_EQ("new_value", result.first->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, whenInsertOrAssigningExistingKey_thenDoesModifyDict) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "old_value");
|
||||
dict.insert_or_assign(3, "new_value");
|
||||
EXPECT_EQ(1, dict.size());
|
||||
@ -90,8 +91,8 @@ TEST(DictTest, whenInsertOrAssigningExistingKey_thenDoesModifyDict) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenEmptyDict_whenIterating_thenBeginIsEnd) {
|
||||
Dict<int, string> dict{};
|
||||
const Dict<int, string> cdict{};
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
const DictPtr<int, string> cdict = make_dict<int, string>();
|
||||
EXPECT_EQ(dict.begin(), dict.end());
|
||||
EXPECT_EQ(dict.cbegin(), dict.cend());
|
||||
EXPECT_EQ(cdict.begin(), cdict.end());
|
||||
@ -99,12 +100,12 @@ TEST(DictTest, givenEmptyDict_whenIterating_thenBeginIsEnd) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableDict_whenIterating_thenFindsElements) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(5, "5");
|
||||
bool found_first = false;
|
||||
bool found_second = false;
|
||||
for (Dict<int, string>::iterator iter = dict.begin(); iter != dict.end(); ++iter) {
|
||||
for (DictPtr<int, string>::iterator iter = dict.begin(); iter != dict.end(); ++iter) {
|
||||
if (iter->key() == 3) {
|
||||
EXPECT_EQ("3", iter->value());
|
||||
EXPECT_FALSE(found_first);
|
||||
@ -122,7 +123,7 @@ TEST(DictTest, givenMutableDict_whenIterating_thenFindsElements) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableDict_whenIteratingWithForeach_thenFindsElements) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(5, "5");
|
||||
bool found_first = false;
|
||||
@ -145,13 +146,13 @@ TEST(DictTest, givenMutableDict_whenIteratingWithForeach_thenFindsElements) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstDict_whenIterating_thenFindsElements) {
|
||||
Dict<int, string> dict_;
|
||||
DictPtr<int, string> dict_ = make_dict<int, string>();
|
||||
dict_.insert(3, "3");
|
||||
dict_.insert(5, "5");
|
||||
const Dict<int, string>& dict = dict_;
|
||||
const DictPtr<int, string>& dict = dict_;
|
||||
bool found_first = false;
|
||||
bool found_second = false;
|
||||
for (Dict<int, string>::const_iterator iter = dict.begin(); iter != dict.end(); ++iter) {
|
||||
for (DictPtr<int, string>::const_iterator iter = dict.begin(); iter != dict.end(); ++iter) {
|
||||
if (iter->key() == 3) {
|
||||
EXPECT_EQ("3", iter->value());
|
||||
EXPECT_FALSE(found_first);
|
||||
@ -169,10 +170,10 @@ TEST(DictTest, givenConstDict_whenIterating_thenFindsElements) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstDict_whenIteratingWithForeach_thenFindsElements) {
|
||||
Dict<int, string> dict_;
|
||||
DictPtr<int, string> dict_ = make_dict<int, string>();
|
||||
dict_.insert(3, "3");
|
||||
dict_.insert(5, "5");
|
||||
const Dict<int, string>& dict = dict_;
|
||||
const DictPtr<int, string>& dict = dict_;
|
||||
bool found_first = false;
|
||||
bool found_second = false;
|
||||
for (const auto& elem : dict) {
|
||||
@ -193,28 +194,28 @@ TEST(DictTest, givenConstDict_whenIteratingWithForeach_thenFindsElements) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenIterator_thenCanModifyValue) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "old_value");
|
||||
dict.begin()->setValue("new_value");
|
||||
EXPECT_EQ("new_value", dict.begin()->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenOneElementDict_whenErasingByConstIterator_thenDictIsEmpty) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.erase(dict.cbegin());
|
||||
EXPECT_TRUE(dict.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenOneElementDict_whenErasingByIterator_thenDictIsEmpty) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.erase(dict.begin());
|
||||
EXPECT_TRUE(dict.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenOneElementDict_whenErasingByKey_thenReturnsOneAndDictIsEmpty) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
bool result = dict.erase(3);
|
||||
EXPECT_EQ(1, result);
|
||||
@ -222,7 +223,7 @@ TEST(DictTest, givenOneElementDict_whenErasingByKey_thenReturnsOneAndDictIsEmpty
|
||||
}
|
||||
|
||||
TEST(DictTest, givenOneElementDict_whenErasingByNonexistingKey_thenReturnsZeroAndDictIsUnchanged) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
bool result = dict.erase(4);
|
||||
EXPECT_EQ(0, result);
|
||||
@ -230,80 +231,80 @@ TEST(DictTest, givenOneElementDict_whenErasingByNonexistingKey_thenReturnsZeroAn
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCallingAtWithExistingKey_thenReturnsCorrectElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
EXPECT_EQ("4", dict.at(4));
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCallingAtWithNonExistingKey_thenReturnsCorrectElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
EXPECT_THROW(dict.at(5), std::out_of_range);
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableDict_whenCallingFindOnExistingKey_thenFindsCorrectElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
Dict<int, string>::iterator found = dict.find(3);
|
||||
DictPtr<int, string>::iterator found = dict.find(3);
|
||||
EXPECT_EQ(3, found->key());
|
||||
EXPECT_EQ("3", found->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableDict_whenCallingFindOnNonExistingKey_thenReturnsEnd) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
Dict<int, string>::iterator found = dict.find(5);
|
||||
DictPtr<int, string>::iterator found = dict.find(5);
|
||||
EXPECT_EQ(dict.end(), found);
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstDict_whenCallingFindOnExistingKey_thenFindsCorrectElement) {
|
||||
Dict<int, string> dict_;
|
||||
DictPtr<int, string> dict_ = make_dict<int, string>();
|
||||
dict_.insert(3, "3");
|
||||
dict_.insert(4, "4");
|
||||
const Dict<int, string>& dict = dict_;
|
||||
Dict<int, string>::const_iterator found = dict.find(3);
|
||||
const DictPtr<int, string>& dict = dict_;
|
||||
DictPtr<int, string>::const_iterator found = dict.find(3);
|
||||
EXPECT_EQ(3, found->key());
|
||||
EXPECT_EQ("3", found->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstDict_whenCallingFindOnNonExistingKey_thenReturnsEnd) {
|
||||
Dict<int, string> dict_;
|
||||
DictPtr<int, string> dict_ = make_dict<int, string>();
|
||||
dict_.insert(3, "3");
|
||||
dict_.insert(4, "4");
|
||||
const Dict<int, string>& dict = dict_;
|
||||
Dict<int, string>::const_iterator found = dict.find(5);
|
||||
const DictPtr<int, string>& dict = dict_;
|
||||
DictPtr<int, string>::const_iterator found = dict.find(5);
|
||||
EXPECT_EQ(dict.end(), found);
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCallingContainsWithExistingKey_thenReturnsTrue) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
EXPECT_TRUE(dict.contains(3));
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCallingContainsWithNonExistingKey_thenReturnsFalse) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
EXPECT_FALSE(dict.contains(5));
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCallingReserve_thenDoesntCrash) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.reserve(100);
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCopyConstructingDict_thenAreEqual) {
|
||||
Dict<int, string> dict1;
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2(dict1);
|
||||
DictPtr<int, string> dict2(dict1);
|
||||
|
||||
EXPECT_EQ(2, dict2.size());
|
||||
EXPECT_EQ("3", dict2.at(3));
|
||||
@ -311,11 +312,11 @@ TEST(DictTest, whenCopyConstructingDict_thenAreEqual) {
|
||||
}
|
||||
|
||||
TEST(DictTest, whenCopyAssigningDict_thenAreEqual) {
|
||||
Dict<int, string> dict1;
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2;
|
||||
DictPtr<int, string> dict2 = make_dict<int, string>();
|
||||
dict2 = dict1;
|
||||
|
||||
EXPECT_EQ(2, dict2.size());
|
||||
@ -323,12 +324,24 @@ TEST(DictTest, whenCopyAssigningDict_thenAreEqual) {
|
||||
EXPECT_EQ("4", dict2.at(4));
|
||||
}
|
||||
|
||||
TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
|
||||
Dict<int, string> dict1;
|
||||
TEST(DictTest, whenCopyingDict_thenAreEqual) {
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2(std::move(dict1));
|
||||
DictPtr<int, string> dict2 = dict1.copy();
|
||||
|
||||
EXPECT_EQ(2, dict2.size());
|
||||
EXPECT_EQ("3", dict2.at(3));
|
||||
EXPECT_EQ("4", dict2.at(4));
|
||||
}
|
||||
|
||||
TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
DictPtr<int, string> dict2(std::move(dict1));
|
||||
|
||||
EXPECT_EQ(2, dict2.size());
|
||||
EXPECT_EQ("3", dict2.at(3));
|
||||
@ -336,11 +349,11 @@ TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
|
||||
}
|
||||
|
||||
TEST(DictTest, whenMoveAssigningDict_thenNewIsCorrect) {
|
||||
Dict<int, string> dict1;
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2;
|
||||
DictPtr<int, string> dict2 = make_dict<int, string>();
|
||||
dict2 = std::move(dict1);
|
||||
|
||||
EXPECT_EQ(2, dict2.size());
|
||||
@ -349,95 +362,95 @@ TEST(DictTest, whenMoveAssigningDict_thenNewIsCorrect) {
|
||||
}
|
||||
|
||||
TEST(DictTest, whenMoveConstructingDict_thenOldIsEmpty) {
|
||||
Dict<int, string> dict1;
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2(std::move(dict1));
|
||||
DictPtr<int, string> dict2(std::move(dict1));
|
||||
EXPECT_TRUE(dict1.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, whenMoveAssigningDict_thenOldIsEmpty) {
|
||||
Dict<int, string> dict1;
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
dict1.insert(3, "3");
|
||||
dict1.insert(4, "4");
|
||||
|
||||
Dict<int, string> dict2;
|
||||
DictPtr<int, string> dict2 = make_dict<int, string>();
|
||||
dict2 = std::move(dict1);
|
||||
EXPECT_TRUE(dict1.empty());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableIterator_whenAssigningToConstIterator_thenWorks) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
Dict<int, string>::iterator iter = dict.begin();
|
||||
Dict<int, string>::const_iterator const_iter = iter;
|
||||
DictPtr<int, string>::iterator iter = dict.begin();
|
||||
DictPtr<int, string>::const_iterator const_iter = iter;
|
||||
EXPECT_EQ(3, const_iter->key());
|
||||
EXPECT_EQ("3", const_iter->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableIterator_whenPostfixIncrementing_thenMovesToNextAndReturnsOldPosition) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::iterator iter1 = dict.begin();
|
||||
Dict<int, string>::iterator iter2 = iter1++;
|
||||
DictPtr<int, string>::iterator iter1 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter2 = iter1++;
|
||||
EXPECT_NE(dict.begin()->key(), iter1->key());
|
||||
EXPECT_EQ(dict.begin()->key(), iter2->key());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstIterator_whenPostfixIncrementing_thenMovesToNextAndReturnsOldPosition) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
Dict<int, string>::const_iterator iter2 = iter1++;
|
||||
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter2 = iter1++;
|
||||
EXPECT_NE(dict.begin()->key(), iter1->key());
|
||||
EXPECT_EQ(dict.begin()->key(), iter2->key());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableIterator_whenPrefixIncrementing_thenMovesToNextAndReturnsNewPosition) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::iterator iter1 = dict.begin();
|
||||
Dict<int, string>::iterator iter2 = ++iter1;
|
||||
DictPtr<int, string>::iterator iter1 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter2 = ++iter1;
|
||||
EXPECT_NE(dict.begin()->key(), iter1->key());
|
||||
EXPECT_NE(dict.begin()->key(), iter2->key());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstIterator_whenPrefixIncrementing_thenMovesToNextAndReturnsNewPosition) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
Dict<int, string>::const_iterator iter2 = ++iter1;
|
||||
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter2 = ++iter1;
|
||||
EXPECT_NE(dict.begin()->key(), iter1->key());
|
||||
EXPECT_NE(dict.begin()->key(), iter2->key());
|
||||
}
|
||||
|
||||
TEST(DictTest, givenEqualMutableIterators_thenAreEqual) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::iterator iter1 = dict.begin();
|
||||
Dict<int, string>::iterator iter2 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter1 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter2 = dict.begin();
|
||||
EXPECT_TRUE(iter1 == iter2);
|
||||
EXPECT_FALSE(iter1 != iter2);
|
||||
}
|
||||
|
||||
TEST(DictTest, givenDifferentMutableIterators_thenAreNotEqual) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::iterator iter1 = dict.begin();
|
||||
Dict<int, string>::iterator iter2 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter1 = dict.begin();
|
||||
DictPtr<int, string>::iterator iter2 = dict.begin();
|
||||
iter2++;
|
||||
|
||||
EXPECT_FALSE(iter1 == iter2);
|
||||
@ -445,23 +458,23 @@ TEST(DictTest, givenDifferentMutableIterators_thenAreNotEqual) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenEqualConstIterators_thenAreEqual) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
Dict<int, string>::const_iterator iter2 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter2 = dict.cbegin();
|
||||
EXPECT_TRUE(iter1 == iter2);
|
||||
EXPECT_FALSE(iter1 != iter2);
|
||||
}
|
||||
|
||||
TEST(DictTest, givenDifferentConstIterators_thenAreNotEqual) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
dict.insert(4, "4");
|
||||
|
||||
Dict<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
Dict<int, string>::const_iterator iter2 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter2 = dict.cbegin();
|
||||
iter2++;
|
||||
|
||||
EXPECT_FALSE(iter1 == iter2);
|
||||
@ -469,10 +482,10 @@ TEST(DictTest, givenDifferentConstIterators_thenAreNotEqual) {
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableIterator_whenDereferencing_thenPointsToCorrectElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
|
||||
Dict<int, string>::iterator iter = dict.begin();
|
||||
DictPtr<int, string>::iterator iter = dict.begin();
|
||||
EXPECT_EQ(3, (*iter).key());
|
||||
EXPECT_EQ("3", (*iter).value());
|
||||
EXPECT_EQ(3, iter->key());
|
||||
@ -480,10 +493,10 @@ TEST(DictTest, givenMutableIterator_whenDereferencing_thenPointsToCorrectElement
|
||||
}
|
||||
|
||||
TEST(DictTest, givenConstIterator_whenDereferencing_thenPointsToCorrectElement) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
|
||||
Dict<int, string>::const_iterator iter = dict.cbegin();
|
||||
DictPtr<int, string>::const_iterator iter = dict.cbegin();
|
||||
EXPECT_EQ(3, (*iter).key());
|
||||
EXPECT_EQ("3", (*iter).value());
|
||||
EXPECT_EQ(3, iter->key());
|
||||
@ -491,10 +504,10 @@ TEST(DictTest, givenConstIterator_whenDereferencing_thenPointsToCorrectElement)
|
||||
}
|
||||
|
||||
TEST(DictTest, givenMutableIterator_whenWritingToValue_thenWorks) {
|
||||
Dict<int, string> dict;
|
||||
DictPtr<int, string> dict = make_dict<int, string>();
|
||||
dict.insert(3, "3");
|
||||
|
||||
Dict<int, string>::iterator iter = dict.begin();
|
||||
DictPtr<int, string>::iterator iter = dict.begin();
|
||||
|
||||
(*iter).setValue("new_value");
|
||||
EXPECT_EQ("new_value", dict.begin()->value());
|
||||
@ -502,3 +515,27 @@ TEST(DictTest, givenMutableIterator_whenWritingToValue_thenWorks) {
|
||||
iter->setValue("new_value_2");
|
||||
EXPECT_EQ("new_value_2", dict.begin()->value());
|
||||
}
|
||||
|
||||
TEST(DictTest, isReferenceType) {
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
DictPtr<int, string> dict2(dict1);
|
||||
DictPtr<int, string> dict3 = make_dict<int, string>();
|
||||
dict3 = dict1;
|
||||
|
||||
dict1.insert(3, "three");
|
||||
EXPECT_EQ(1, dict1.size());
|
||||
EXPECT_EQ(1, dict2.size());
|
||||
EXPECT_EQ(1, dict3.size());
|
||||
}
|
||||
|
||||
TEST(DictTest, copyHasSeparateStorage) {
|
||||
DictPtr<int, string> dict1 = make_dict<int, string>();
|
||||
DictPtr<int, string> dict2(dict1.copy());
|
||||
DictPtr<int, string> dict3 = make_dict<int, string>();
|
||||
dict3 = dict1.copy();
|
||||
|
||||
dict1.insert(3, "three");
|
||||
EXPECT_EQ(1, dict1.size());
|
||||
EXPECT_EQ(0, dict2.size());
|
||||
EXPECT_EQ(0, dict3.size());
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ struct Function;
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
namespace c10 {
|
||||
template<class Key, class Value> class Dict;
|
||||
template<class Key, class Value> class DictPtr;
|
||||
struct IValue;
|
||||
namespace ivalue {
|
||||
struct Tuple;
|
||||
@ -220,7 +220,7 @@ struct CAFFE2_API IValue final {
|
||||
const std::vector<bool>& toBoolListRef() const;
|
||||
const std::vector<at::Tensor>& toTensorListRef() const;
|
||||
const std::vector<IValue>& toGenericListRef() const;
|
||||
const c10::Dict<IValue, IValue>& toGenericDictRef() const;
|
||||
const c10::DictPtr<IValue, IValue>& toGenericDictRef() const;
|
||||
const std::string& toStringRef() const;
|
||||
|
||||
// ConstantString
|
||||
@ -261,7 +261,7 @@ struct CAFFE2_API IValue final {
|
||||
|
||||
// GenericDict
|
||||
IValue(c10::intrusive_ptr<ivalue::GenericDict> v);
|
||||
IValue(c10::Dict<IValue, IValue> v);
|
||||
IValue(c10::DictPtr<IValue, IValue> v);
|
||||
bool isGenericDict() const { return Tag::GenericDict == tag; }
|
||||
c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() &&;
|
||||
c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() const &;
|
||||
|
@ -401,19 +401,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
||||
|
||||
struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
|
||||
private:
|
||||
c10::impl::GenericDict elements_;
|
||||
c10::impl::GenericDictPtr elements_;
|
||||
|
||||
public:
|
||||
GenericDict(c10::impl::GenericDict elements_)
|
||||
GenericDict(c10::impl::GenericDictPtr elements_)
|
||||
: elements_(std::move(elements_)) {}
|
||||
static c10::intrusive_ptr<GenericDict> create(
|
||||
c10::impl::GenericDict elements_) {
|
||||
c10::impl::GenericDictPtr elements_) {
|
||||
return c10::make_intrusive<GenericDict>(std::move(elements_));
|
||||
}
|
||||
const c10::impl::GenericDict& elements() const & {
|
||||
const c10::impl::GenericDictPtr& elements() const & {
|
||||
return elements_;
|
||||
}
|
||||
c10::impl::GenericDict& elements() & {
|
||||
c10::impl::GenericDictPtr& elements() & {
|
||||
return elements_;
|
||||
}
|
||||
|
||||
@ -570,7 +570,7 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericDict> v)
|
||||
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
|
||||
payload.as_intrusive_ptr = v.release();
|
||||
}
|
||||
inline IValue::IValue(c10::impl::GenericDict v)
|
||||
inline IValue::IValue(c10::impl::GenericDictPtr v)
|
||||
: IValue(ivalue::GenericDict::create(std::move(v))) {}
|
||||
|
||||
inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
|
||||
@ -602,7 +602,7 @@ inline const std::vector<IValue>& IValue::toGenericListRef() const {
|
||||
return toGenericList()->elements();
|
||||
}
|
||||
|
||||
inline const c10::impl::GenericDict& IValue::
|
||||
inline const c10::impl::GenericDictPtr& IValue::
|
||||
toGenericDictRef() const {
|
||||
return toGenericDict()->elements();
|
||||
}
|
||||
|
@ -1280,7 +1280,7 @@ struct getTypePtr_<std::unordered_map<K, V>> final {
|
||||
}
|
||||
};
|
||||
template <class K, class V>
|
||||
struct getTypePtr_<c10::Dict<K, V>> final {
|
||||
struct getTypePtr_<c10::DictPtr<K, V>> final {
|
||||
static TypePtr call() {
|
||||
static auto type =
|
||||
DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());
|
||||
|
@ -25,7 +25,8 @@ using c10::guts::make_unique;
|
||||
using c10::ivalue::TensorList;
|
||||
using c10::ivalue::IntList;
|
||||
using c10::intrusive_ptr;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
using at::Tensor;
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
@ -209,11 +210,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListO
|
||||
EXPECT_EQ(6, result[0].toIntListRef()[2]);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
|
||||
Dict<string, Tensor> dict;
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("first", dummyTensor(TensorType1()));
|
||||
dict.insert("second", dummyTensor(TensorType2()));
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
|
||||
dummyTensor(TensorType2()),
|
||||
5,
|
||||
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
|
||||
@ -516,7 +517,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithStringLi
|
||||
|
||||
int captured_dict_size = 0;
|
||||
|
||||
void kernelWithDictInputWithoutOutput(Dict<string, Tensor> input1) {
|
||||
void kernelWithDictInputWithoutOutput(DictPtr<string, Tensor> input1) {
|
||||
captured_dict_size = input1.size();
|
||||
}
|
||||
|
||||
@ -528,7 +529,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
captured_dict_size = 0;
|
||||
Dict<string, Tensor> dict;
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("key1", dummyTensor(TensorType1()));
|
||||
dict.insert("key2", dummyTensor(TensorType2()));
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -536,7 +537,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
|
||||
EXPECT_EQ(2, captured_dict_size);
|
||||
}
|
||||
|
||||
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
string kernelWithDictInputWithOutput(DictPtr<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
|
||||
@ -547,7 +548,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -555,7 +556,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
|
||||
EXPECT_EQ("value2", outputs[0].toString()->string());
|
||||
}
|
||||
|
||||
Dict<string, string> kernelWithDictOutput(Dict<string, string> input) {
|
||||
DictPtr<string, string> kernelWithDictOutput(DictPtr<string, string> input) {
|
||||
return input;
|
||||
}
|
||||
|
||||
@ -566,7 +567,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictOutp
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
|
@ -13,7 +13,8 @@ using c10::guts::make_unique;
|
||||
using c10::ivalue::TensorList;
|
||||
using c10::ivalue::IntList;
|
||||
using c10::intrusive_ptr;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
using at::Tensor;
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
@ -206,11 +207,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_
|
||||
EXPECT_EQ(6, result[0].toIntListRef()[2]);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
|
||||
Dict<string, Tensor> dict;
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("first", dummyTensor(TensorType1()));
|
||||
dict.insert("second", dummyTensor(TensorType2()));
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
|
||||
dummyTensor(TensorType2()),
|
||||
5,
|
||||
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
|
||||
@ -431,7 +432,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
|
||||
|
||||
int captured_dict_size = 0;
|
||||
|
||||
void kernelWithDictInputWithoutOutput(Dict<string, Tensor> input1) {
|
||||
void kernelWithDictInputWithoutOutput(DictPtr<string, Tensor> input1) {
|
||||
captured_dict_size = input1.size();
|
||||
}
|
||||
|
||||
@ -443,7 +444,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
captured_dict_size = 0;
|
||||
Dict<string, Tensor> dict;
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("key1", dummyTensor(TensorType1()));
|
||||
dict.insert("key2", dummyTensor(TensorType2()));
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -451,7 +452,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
|
||||
EXPECT_EQ(2, captured_dict_size);
|
||||
}
|
||||
|
||||
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
string kernelWithDictInputWithOutput(DictPtr<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
|
||||
@ -462,7 +463,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -470,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
|
||||
EXPECT_EQ("value2", outputs[0].toString()->string());
|
||||
}
|
||||
|
||||
Dict<string, string> kernelWithDictOutput(Dict<string, string> input) {
|
||||
DictPtr<string, string> kernelWithDictOutput(DictPtr<string, string> input) {
|
||||
return input;
|
||||
}
|
||||
|
||||
@ -481,7 +482,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictOutput_whe
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
|
@ -60,10 +60,10 @@ namespace detail {
|
||||
}
|
||||
};
|
||||
template<class Key, class Value>
|
||||
struct ivalue_to_arg_type<Dict<Key, Value>> {
|
||||
static Dict<Key, Value> call(IValue&& v) {
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported argument type: c10::Dict<Key, Value> and Key is not one of the supported primitive types.");
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported argument type: c10::Dict<Key, Value> and Value is not one of the supported primitive types.");
|
||||
struct ivalue_to_arg_type<DictPtr<Key, Value>> {
|
||||
static DictPtr<Key, Value> call(IValue&& v) {
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported argument type: c10::DictPtr<Key, Value> and Key is not one of the supported primitive types.");
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported argument type: c10::DictPtr<Key, Value> and Value is not one of the supported primitive types.");
|
||||
|
||||
auto dict_ptr = std::move(v).toGenericDict();
|
||||
return impl::toTypedDict<Key, Value>(std::move(dict_ptr->elements()));
|
||||
@ -169,17 +169,17 @@ namespace detail {
|
||||
}
|
||||
};
|
||||
template<class Key, class Value>
|
||||
struct return_type_to_ivalue<c10::Dict<Key, Value>> {
|
||||
static IValue call(c10::Dict<Key, Value>&& v) {
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported return type: Dict<Key, Value> and Key is not one of the supported primitive types.");
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported return type: Dict<Key, Value> and Value is not one of the supported primitive types.");
|
||||
struct return_type_to_ivalue<c10::DictPtr<Key, Value>> {
|
||||
static IValue call(c10::DictPtr<Key, Value>&& v) {
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported return type: DictPtr<Key, Value> and Key is not one of the supported primitive types.");
|
||||
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported return type: DictPtr<Key, Value> and Value is not one of the supported primitive types.");
|
||||
return IValue(impl::toGenericDict(std::move(v)));
|
||||
}
|
||||
};
|
||||
template<class Key, class Value>
|
||||
struct return_type_to_ivalue<std::unordered_map<Key, Value>> final {
|
||||
static IValue call(std::unordered_map<Key, Value>&& v) {
|
||||
c10::impl::GenericDict dict;
|
||||
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
|
||||
dict.reserve(v.size());
|
||||
for (auto& element : v) {
|
||||
dict.insert(return_type_to_ivalue<Key>::call(Key{element.first}), return_type_to_ivalue<Value>::call(std::move(element.second)));
|
||||
|
@ -14,7 +14,8 @@ using c10::guts::make_unique;
|
||||
using c10::ivalue::TensorList;
|
||||
using c10::ivalue::IntList;
|
||||
using c10::intrusive_ptr;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
using at::Tensor;
|
||||
using std::unique_ptr;
|
||||
using std::string;
|
||||
@ -226,11 +227,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_w
|
||||
}
|
||||
|
||||
struct KernelWithMultipleOutputs final : OperatorKernel {
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> operator()(Tensor) {
|
||||
Dict<string, Tensor> dict;
|
||||
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> operator()(Tensor) {
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("first", dummyTensor(TensorType1()));
|
||||
dict.insert("second", dummyTensor(TensorType2()));
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
|
||||
dummyTensor(TensorType2()),
|
||||
5,
|
||||
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
|
||||
@ -473,7 +474,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput
|
||||
int captured_dict_size = 0;
|
||||
|
||||
struct KernelWithDictInputWithoutOutput final : OperatorKernel {
|
||||
void operator()(Dict<string, Tensor> input1) {
|
||||
void operator()(DictPtr<string, Tensor> input1) {
|
||||
captured_dict_size = input1.size();
|
||||
}
|
||||
};
|
||||
@ -486,7 +487,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_witho
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
captured_dict_size = 0;
|
||||
Dict<string, Tensor> dict;
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("key1", dummyTensor(TensorType1()));
|
||||
dict.insert("key2", dummyTensor(TensorType2()));
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -495,7 +496,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_witho
|
||||
}
|
||||
|
||||
struct KernelWithDictInputWithOutput final : OperatorKernel {
|
||||
string operator()(Dict<string, string> input1) {
|
||||
string operator()(DictPtr<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
};
|
||||
@ -507,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_withO
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -516,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_withO
|
||||
}
|
||||
|
||||
struct KernelWithDictOutput final : OperatorKernel {
|
||||
Dict<string, string> operator()(Dict<string, string> input) {
|
||||
DictPtr<string, string> operator()(DictPtr<string, string> input) {
|
||||
return input;
|
||||
}
|
||||
};
|
||||
@ -528,7 +529,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictOutput_when
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
|
@ -24,7 +24,8 @@ using c10::guts::make_unique;
|
||||
using c10::ivalue::TensorList;
|
||||
using c10::ivalue::IntList;
|
||||
using c10::intrusive_ptr;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
using at::Tensor;
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
@ -192,11 +193,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOut
|
||||
|
||||
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> {
|
||||
Dict<string, Tensor> dict;
|
||||
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> {
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("first", dummyTensor(TensorType1()));
|
||||
dict.insert("second", dummyTensor(TensorType2()));
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
|
||||
dummyTensor(TensorType2()),
|
||||
5,
|
||||
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
|
||||
@ -468,7 +469,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
|
||||
int captured_dict_size = 0;
|
||||
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", [&] (Dict<string, Tensor> input1) {
|
||||
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", [&] (DictPtr<string, Tensor> input1) {
|
||||
captured_dict_size = input1.size();
|
||||
});
|
||||
|
||||
@ -476,7 +477,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
captured_dict_size = 0;
|
||||
Dict<string, Tensor> dict;
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("key1", dummyTensor(TensorType1()));
|
||||
dict.insert("key2", dummyTensor(TensorType2()));
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -486,14 +487,14 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
|
||||
|
||||
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_withOutput_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_input(Dict(str, str) input) -> str", [&] (Dict<string, string> input1) {
|
||||
.op("_test::dict_input(Dict(str, str) input) -> str", [&] (DictPtr<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
});
|
||||
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -503,14 +504,14 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
|
||||
|
||||
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictOutput_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", [] (Dict<string, string> input) {
|
||||
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", [] (DictPtr<string, string> input) {
|
||||
return input;
|
||||
});
|
||||
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
|
@ -13,7 +13,8 @@ using c10::guts::make_unique;
|
||||
using c10::ivalue::TensorList;
|
||||
using c10::ivalue::IntList;
|
||||
using c10::intrusive_ptr;
|
||||
using c10::Dict;
|
||||
using c10::DictPtr;
|
||||
using c10::make_dict;
|
||||
using at::Tensor;
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
@ -192,11 +193,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_wh
|
||||
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))",
|
||||
RegisterOperators::options().kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> {
|
||||
Dict<string, Tensor> dict;
|
||||
RegisterOperators::options().kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> {
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("first", dummyTensor(TensorType1()));
|
||||
dict.insert("second", dummyTensor(TensorType2()));
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
|
||||
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
|
||||
dummyTensor(TensorType2()),
|
||||
5,
|
||||
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
|
||||
@ -404,7 +405,7 @@ int captured_dict_size = 0;
|
||||
|
||||
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withoutOutput_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", RegisterOperators::options().kernel([] (Dict<string, Tensor> input1) {
|
||||
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", RegisterOperators::options().kernel([] (DictPtr<string, Tensor> input1) {
|
||||
captured_dict_size = input1.size();
|
||||
}));
|
||||
|
||||
@ -412,7 +413,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withou
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
captured_dict_size = 0;
|
||||
Dict<string, Tensor> dict;
|
||||
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
|
||||
dict.insert("key1", dummyTensor(TensorType1()));
|
||||
dict.insert("key2", dummyTensor(TensorType2()));
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -422,14 +423,14 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withou
|
||||
|
||||
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withOutput_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_input(Dict(str, str) input) -> str", RegisterOperators::options().kernel([] (Dict<string, string> input1) {
|
||||
.op("_test::dict_input(Dict(str, str) input) -> str", RegisterOperators::options().kernel([] (DictPtr<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}));
|
||||
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -439,14 +440,14 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withOu
|
||||
|
||||
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictOutput_whenRegistered_thenCanBeCalled) {
|
||||
auto registrar = RegisterOperators()
|
||||
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", RegisterOperators::options().kernel([] (Dict<string, string> input) {
|
||||
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", RegisterOperators::options().kernel([] (DictPtr<string, string> input) {
|
||||
return input;
|
||||
}));
|
||||
|
||||
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
DictPtr<string, string> dict = make_dict<string, string>();
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
|
@ -638,33 +638,33 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
||||
// TODO Do we want to support list of optional ?
|
||||
|
||||
// dict types
|
||||
c10::Dict<std::string, std::string> str_dict;
|
||||
c10::DictPtr<std::string, std::string> str_dict = c10::make_dict<std::string, std::string>();
|
||||
str_dict.insert("key1", "value1");
|
||||
str_dict.insert("key2", "value2");
|
||||
testArgTypes<c10::Dict<std::string, std::string>>::test(
|
||||
str_dict, [] (c10::Dict<std::string, std::string> v) {
|
||||
testArgTypes<c10::DictPtr<std::string, std::string>>::test(
|
||||
str_dict, [] (c10::DictPtr<std::string, std::string> v) {
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ("value1", v.at("key1"));
|
||||
EXPECT_EQ("value2", v.at("key2"));
|
||||
},
|
||||
str_dict, [] (const IValue& v) {
|
||||
c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(std::move(v.toGenericDict()->elements()));
|
||||
c10::DictPtr<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(std::move(v.toGenericDict()->elements()));
|
||||
EXPECT_EQ(2, dict.size());
|
||||
EXPECT_EQ("value1", dict.at("key1"));
|
||||
EXPECT_EQ("value2", dict.at("key2"));
|
||||
},
|
||||
"(Dict(str, str) a) -> Dict(str, str)");
|
||||
c10::Dict<int64_t, Tensor> tensor_dict;
|
||||
c10::DictPtr<int64_t, Tensor> tensor_dict = c10::make_dict<int64_t, Tensor>();
|
||||
tensor_dict.insert(1, dummyTensor(TensorType1()));
|
||||
tensor_dict.insert(2, dummyTensor(TensorType2()));
|
||||
testArgTypes<c10::Dict<int64_t, Tensor>>::test(
|
||||
tensor_dict, [] (c10::Dict<int64_t, Tensor> v) {
|
||||
testArgTypes<c10::DictPtr<int64_t, Tensor>>::test(
|
||||
tensor_dict, [] (c10::DictPtr<int64_t, Tensor> v) {
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ(TensorType1(), v.at(1).type_id());
|
||||
EXPECT_EQ(TensorType2(), v.at(2).type_id());
|
||||
},
|
||||
tensor_dict, [] (const IValue& v) {
|
||||
c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(std::move(v.toGenericDict()->elements()));
|
||||
c10::DictPtr<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(std::move(v.toGenericDict()->elements()));
|
||||
EXPECT_EQ(2, dict.size());
|
||||
EXPECT_EQ(TensorType1(), dict.at(1).type_id());
|
||||
EXPECT_EQ(TensorType2(), dict.at(2).type_id());
|
||||
|
@ -60,7 +60,7 @@ struct InputToIValue<std::vector<std::string>> final {
|
||||
}
|
||||
};
|
||||
template<class Key, class Value>
|
||||
struct InputToIValue<c10::Dict<Key, Value>> final {
|
||||
struct InputToIValue<c10::DictPtr<Key, Value>> final {
|
||||
template<class T_>
|
||||
static c10::IValue call(T_&& v) {
|
||||
return c10::IValue(c10::impl::toGenericDict(std::move(v)));
|
||||
@ -70,7 +70,7 @@ template<class Key, class Value>
|
||||
struct InputToIValue<std::unordered_map<Key, Value>> final {
|
||||
template<class T_>
|
||||
static c10::IValue call(T_&& v) {
|
||||
c10::impl::GenericDict dict;
|
||||
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
|
||||
dict.reserve(v.size());
|
||||
for (auto& element : v) {
|
||||
dict.insert(InputToIValue<Key>::call(element.first), InputToIValue<Value>::call(element.second));
|
||||
|
@ -227,7 +227,7 @@ bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
|
||||
auto dict_type = type->expect<DictType>();
|
||||
const auto& dict = ivalue.toGenericDictRef();
|
||||
return std::all_of(
|
||||
dict.begin(), dict.end(), [=](const c10::impl::GenericDict::const_iterator::value_type& item) {
|
||||
dict.begin(), dict.end(), [=](const c10::impl::GenericDictPtr::const_iterator::value_type& item) {
|
||||
return isSubvalueOf(item.key(), dict_type->getKeyType()) &&
|
||||
isSubvalueOf(item.value(), dict_type->getValueType());
|
||||
});
|
||||
|
@ -93,7 +93,7 @@ TEST(TorchScriptTest, TestDictArgMatching) {
|
||||
def dict_op(a: Dict[str, Tensor], b: str):
|
||||
return a[b]
|
||||
)JIT");
|
||||
c10::impl::GenericDict dict;
|
||||
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
|
||||
dict.insert("hello", torch::ones({2}));
|
||||
auto output = module->run_method("dict_op", dict, std::string("hello"));
|
||||
ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
|
||||
|
@ -667,7 +667,7 @@ OpCode Unpickler::readInstruction() {
|
||||
stack_.emplace_back(IValue(tuple));
|
||||
} break;
|
||||
case OpCode::EMPTY_DICT:
|
||||
stack_.emplace_back(c10::impl::GenericDict());
|
||||
stack_.emplace_back(c10::impl::make_generic_dict());
|
||||
break;
|
||||
case OpCode::APPENDS: {
|
||||
readList();
|
||||
|
@ -115,7 +115,7 @@ inline TypedIValue toTypedIValue(py::handle input) {
|
||||
} else if (PyDict_Check(input.ptr())) {
|
||||
// Check to make sure we can generate useful input/output types
|
||||
auto dict = py::cast<py::dict>(input);
|
||||
c10::impl::GenericDict elems;
|
||||
c10::impl::GenericDictPtr elems = c10::impl::make_generic_dict();
|
||||
|
||||
size_t len = py::len(dict);
|
||||
if (!len) {
|
||||
@ -205,7 +205,7 @@ inline IValue createGenericDict(
|
||||
py::handle obj,
|
||||
const TypePtr& key_type,
|
||||
const TypePtr& value_type) {
|
||||
c10::impl::GenericDict elems;
|
||||
c10::impl::GenericDictPtr elems = c10::impl::make_generic_dict();
|
||||
elems.reserve(py::len(obj));
|
||||
for (auto key : obj) {
|
||||
elems.insert(
|
||||
|
@ -875,7 +875,7 @@ RegisterOperators reg(
|
||||
"DictConstruct must have an even number of inputs");
|
||||
}
|
||||
return [=](Stack& stack) {
|
||||
c10::impl::GenericDict vals;
|
||||
c10::impl::GenericDictPtr vals = c10::impl::make_generic_dict();
|
||||
for (size_t i = 0; i < num_inputs; i += 2) {
|
||||
auto val = pop(stack);
|
||||
auto key = pop(stack);
|
||||
|
Reference in New Issue
Block a user