Dict is a reference type (#20669)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20669

Before, Dict was a value type, i.e. copying it did a deep copy.
Unfortunately, this doesn't work well with storing and passing Dicts around in IValues because IValues are reference types.
This diff changes Dict to be a reference type.

Reviewed By: dzhulgakov

Differential Revision: D15404911

fbshipit-source-id: dc990d3eb7cae044b74dd0253f8b704dde6a6c86
This commit is contained in:
Sebastian Messmer
2019-05-23 15:18:32 -07:00
committed by Facebook Github Bot
parent 93d5503f34
commit d5b7138a2c
19 changed files with 367 additions and 259 deletions

View File

@ -2,28 +2,43 @@
#include <c10/macros/Macros.h>
#include <c10/util/TypeTraits.h>
#include <c10/util/TypeList.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/intrusive_ptr.h>
namespace c10 {
struct IValue;
template<class Key, class Value> class Dict;
template<class Key, class Value> class DictPtr;
/**
* Creates an empty dict.
*/
template<class Key, class Value>
DictPtr<Key, Value> make_dict();
namespace impl {
bool shallowEquals(const IValue& lhs, const IValue& rhs);
}
namespace detail {
struct DictHash {
struct DictKeyHash {
size_t operator()(const IValue& ivalue) const;
};
struct DictEqualTo {
struct DictKeyEqualTo {
bool operator()(const IValue& lhs, const IValue& rhs) const {
return impl::shallowEquals(lhs, rhs);
}
};
using dict_map_type = ska::flat_hash_map<IValue, IValue, DictHash, DictEqualTo>;
struct DictImpl final : public c10::intrusive_ptr_target {
using dict_map_type = ska::flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
dict_map_type dict;
intrusive_ptr<DictImpl> copy() const;
};
}
namespace impl {
@ -62,7 +77,7 @@ public:
private:
Iterator iterator_;
friend class DictIterator<Key, Value, Iterator>;
friend class Dict<Key, Value>;
friend class DictPtr<Key, Value>;
friend bool operator==<Key, Value, Iterator>(const DictIterator<Key, Value, Iterator>& lhs, const DictIterator<Key, Value, Iterator>& rhs);
};
@ -100,7 +115,7 @@ public:
// the template automatically disables the operator when we are already a
// const_iterator, because that would cause a lot of compiler warnings otherwise.
template<class const_iterator_ = typename detail::dict_map_type::const_iterator, class = guts::enable_if_t<!std::is_same<const_iterator_, Iterator>::value>>
template<class const_iterator_ = typename detail::DictImpl::dict_map_type::const_iterator, class = guts::enable_if_t<!std::is_same<const_iterator_, Iterator>::value>>
/* implicit */ operator DictIterator<Key, Value, const_iterator_>() const
{
return DictIterator<Key, Value, const_iterator_> { const_iterator_ { entryRef_.iterator_ } };
@ -111,8 +126,8 @@ private:
DictEntryRef<Key, Value, Iterator> entryRef_;
friend class DictIterator<Key, Value, typename detail::dict_map_type::iterator>;
friend class Dict<Key, Value>;
friend class DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
friend class DictPtr<Key, Value>;
friend bool operator==<Key, Value, Iterator>(const DictIterator& lhs, const DictIterator& rhs);
};
@ -126,56 +141,72 @@ inline bool operator!=(const DictIterator<Key, Value, Iterator>& lhs, const Dict
return !(lhs == rhs);
}
template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue>&& dict);
template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value>&& dict);
template<class Key, class Value> DictPtr<Key, Value> toTypedDict(DictPtr<IValue, IValue> dict);
template<class Key, class Value> DictPtr<IValue, IValue> toGenericDict(DictPtr<Key, Value> dict);
}
/**
* An object of this class stores a map from Key to Value.
* You can create instances using the make_dict<Key, Value>() function.
*
* We use this class in the PyTorch kernel API instead of
* std::unordered_map<Key, Value>, because that allows us
* to do optimizations and switch out the underlying map
* implementation without breaking backwards compatibility
* This is a pointer type. After a copy, both DictPtrs
* will share the same storage:
*
* > DictPtr<int, string> a = make_dict<int, string>();
* > DictPtr<int, string> b = a;
* > b.insert(3, "three");
* > ASSERT("three" == a.at(3));
*
* We use this class in the PyTorch kernel API because that
* allows us to do optimizations and switch out the underlying
* map implementation without breaking backwards compatibility
* for the kernel API.
*
* The API of this class is borrowed from std::unordered_map,
* but with slight differences and it intentionally does not
* support the full std::unordered_map API, because a more
* narrow abstraction gives us more freedom to change the internals.
*/
template<class Key, class Value>
class Dict final {
class DictPtr final {
private:
// map_ stores the underlying map as a ska::flat_hash_map.
// impl_ stores the underlying map as a ska::flat_hash_map.
// We intentionally don't offer conversion from/to
// ska::flat_hash_map, return references to it or something like that,
// because such operations would get expensive if we switch out
// the actual map implementation.
detail::dict_map_type map_;
// This is an intrusive_ptr because DictPtr is a pointer type.
// Invariant: This will never be a nullptr, there will always be a valid
// DictImpl.
c10::intrusive_ptr<detail::DictImpl> impl_;
explicit Dict(detail::dict_map_type&& map): map_(std::move(map)) {}
template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>&&);
template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>&&);
explicit DictPtr(c10::intrusive_ptr<detail::DictImpl>&& impl);
template<class K, class V> friend DictPtr<K, V> impl::toTypedDict(DictPtr<IValue, IValue>);
template<class K, class V> friend DictPtr<IValue, IValue> impl::toGenericDict(DictPtr<K, V>);
public:
using key_type = Key;
using mapped_type = Value;
using size_type = typename detail::dict_map_type::size_type;
using iterator = impl::DictIterator<Key, Value, typename detail::dict_map_type::iterator>;
using const_iterator = impl::DictIterator<Key, Value, typename detail::dict_map_type::const_iterator>;
using size_type = typename detail::DictImpl::dict_map_type::size_type;
using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
using const_iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::const_iterator>;
/**
* Creates an empty dict.
*/
explicit Dict() = default;
friend DictPtr make_dict<Key, Value>();
~Dict() = default;
// please use make_dict instead
DictPtr() = delete;
Dict(const Dict&) = default;
Dict(Dict&&) noexcept = default;
Dict& operator=(const Dict&) = default;
Dict& operator=(Dict&&) noexcept = default;
~DictPtr() = default;
DictPtr(const DictPtr&) = default;
DictPtr& operator=(const DictPtr&) = default;
DictPtr(DictPtr&&) noexcept;
DictPtr& operator=(DictPtr&&) noexcept;
/**
* Create a new DictPtr pointing to a deep copy of the same data.
* The DictPtr returned is a new dict with separate storage.
* Changes in it are not reflected in the original dict or vice versa.
*/
DictPtr copy() const;
/**
* Returns an iterator to the first element of the container.
@ -300,19 +331,23 @@ public:
};
namespace impl {
// GenericDict is how IValue stores dicts. It is, however, not part of the
// GenericDictPtr is how IValue stores dicts. It is, however, not part of the
// public API. Kernels should use Dicts with concrete Key, Value types instead
// (maybe except for some internal prim ops).
using GenericDict = Dict<IValue, IValue>;
using GenericDictPtr = DictPtr<IValue, IValue>;
template<class Key, class Value>
Dict<Key, Value> toTypedDict(GenericDict&& dict) {
return Dict<Key, Value>(std::move(dict.map_));
inline GenericDictPtr make_generic_dict() {
return make_dict<IValue, IValue>();
}
template<class Key, class Value>
GenericDict toGenericDict(Dict<Key, Value>&& dict) {
return GenericDict(std::move(dict.map_));
DictPtr<Key, Value> toTypedDict(GenericDictPtr dict) {
return DictPtr<Key, Value>(std::move(dict.impl_));
}
template<class Key, class Value>
GenericDictPtr toGenericDict(DictPtr<Key, Value> dict) {
return GenericDictPtr(std::move(dict.impl_));
}
}

View File

@ -23,7 +23,7 @@ inline bool shallowEquals(const IValue& lhs, const IValue& rhs) {
namespace detail {
inline size_t DictHash::operator()(const IValue& ivalue) const {
inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
if (ivalue.isInt()) {
return std::hash<int>()(ivalue.toInt());
} else if (ivalue.isString()) {
@ -37,59 +37,90 @@ inline size_t DictHash::operator()(const IValue& ivalue) const {
}
}
inline intrusive_ptr<DictImpl> DictImpl::copy() const {
auto result = make_intrusive<DictImpl>();
result->dict = dict;
return result;
}
}
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() {
return iterator{map_.begin()};
DictPtr<Key, Value> make_dict() {
return DictPtr<Key, Value>(make_intrusive<detail::DictImpl>());
}
template<class Key, class Value>
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::begin() const {
return const_iterator{map_.begin()};
DictPtr<Key, Value>::DictPtr(DictPtr&& rhs) noexcept: impl_(std::move(rhs.impl_)) {
rhs.impl_ = make_intrusive<detail::DictImpl>();
}
template<class Key, class Value>
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::cbegin() const {
return const_iterator{map_.cbegin()};
DictPtr<Key, Value>::DictPtr(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
template<class Key, class Value>
DictPtr<Key, Value>& DictPtr<Key, Value>::operator=(DictPtr&& rhs) noexcept {
impl_ = std::move(rhs.impl_);
rhs.impl_ = make_intrusive<detail::DictImpl>();
return *this;
}
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::end() {
return iterator{map_.end()};
DictPtr<Key, Value> DictPtr<Key, Value>::copy() const {
return DictPtr<Key, Value>(impl_->copy());
}
template<class Key, class Value>
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::end() const {
return const_iterator{map_.end()};
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::begin() {
return iterator{impl_->dict.begin()};
}
template<class Key, class Value>
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::cend() const {
return const_iterator{map_.cend()};
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::begin() const {
return const_iterator{impl_->dict.begin()};
}
template<class Key, class Value>
bool Dict<Key, Value>::empty() const {
return map_.empty();
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::cbegin() const {
return const_iterator{impl_->dict.cbegin()};
}
template<class Key, class Value>
typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
return map_.size();
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::end() {
return iterator{impl_->dict.end()};
}
template<class Key, class Value>
void Dict<Key, Value>::clear() {
map_.clear();
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::end() const {
return const_iterator{impl_->dict.end()};
}
template<class Key, class Value>
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::cend() const {
return const_iterator{impl_->dict.cend()};
}
template<class Key, class Value>
bool DictPtr<Key, Value>::empty() const {
return impl_->dict.empty();
}
template<class Key, class Value>
typename DictPtr<Key, Value>::size_type DictPtr<Key, Value>::size() const {
return impl_->dict.size();
}
template<class Key, class Value>
void DictPtr<Key, Value>::clear() {
impl_->dict.clear();
}
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");
auto inserted = map_.insert(std::pair<IValue, IValue>{
std::pair<typename DictPtr<Key, Value>::iterator, bool> DictPtr<Key, Value>::insert(Key_&& key, Value_&& value) {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of DictPtr::insert");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of DictPtr::insert");
auto inserted = impl_->dict.insert(std::pair<IValue, IValue>{
Key(std::forward<Key_>(key)),
Value(std::forward<Value_>(value))});
return {iterator{inserted.first}, inserted.second};
@ -97,48 +128,48 @@ std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Ke
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");
auto inserted = map_.insert_or_assign(
std::pair<typename DictPtr<Key, Value>::iterator, bool> DictPtr<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of DictPtr::insert_or_assign");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of DictPtr::insert_or_assign");
auto inserted = impl_->dict.insert_or_assign(
Key(std::forward<Key_>(key)),
Value(std::forward<Value_>(value)));
return {iterator{inserted.first}, inserted.second};
}
template<class Key, class Value>
void Dict<Key, Value>::erase(const_iterator iter) {
map_.erase(iter.entryRef_.iterator_);
void DictPtr<Key, Value>::erase(const_iterator iter) {
impl_->dict.erase(iter.entryRef_.iterator_);
}
template<class Key, class Value>
C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) {
return map_.erase(key);
C10_NODISCARD size_t DictPtr<Key, Value>::erase(const Key& key) {
return impl_->dict.erase(key);
}
template<class Key, class Value>
Value Dict<Key, Value>::at(const Key& key) const {
return map_.at(key).template to<Value>();
Value DictPtr<Key, Value>::at(const Key& key) const {
return impl_->dict.at(key).template to<Value>();
}
template<class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) {
return iterator{map_.find(key)};
typename DictPtr<Key, Value>::iterator DictPtr<Key, Value>::find(const Key& key) {
return iterator{impl_->dict.find(key)};
}
template<class Key, class Value>
typename Dict<Key, Value>::const_iterator Dict<Key, Value>::find(const Key& key) const {
return const_iterator{map_.find(key)};
typename DictPtr<Key, Value>::const_iterator DictPtr<Key, Value>::find(const Key& key) const {
return const_iterator{impl_->dict.find(key)};
}
template<class Key, class Value>
bool Dict<Key, Value>::contains(const Key& key) const {
bool DictPtr<Key, Value>::contains(const Key& key) const {
return end() != find(key);
}
template<class Key, class Value>
void Dict<Key, Value>::reserve(size_type count) {
map_.reserve(count);
void DictPtr<Key, Value>::reserve(size_type count) {
impl_->dict.reserve(count);
}
}

View File

@ -4,33 +4,34 @@
#include <string>
using std::string;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
TEST(DictTest, givenEmptyDict_whenCallingEmpty_thenReturnsTrue) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
EXPECT_TRUE(dict.empty());
}
TEST(DictTest, givenNonemptyDict_whenCallingEmpty_thenReturnsFalse) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "value");
EXPECT_FALSE(dict.empty());
}
TEST(DictTest, givenEmptyDict_whenCallingSize_thenReturnsZero) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
EXPECT_EQ(0, dict.size());
}
TEST(DictTest, givenNonemptyDict_whenCallingSize_thenReturnsNumberOfElements) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "value");
dict.insert(4, "value2");
EXPECT_EQ(2, dict.size());
}
TEST(DictTest, givenNonemptyDict_whenCallingClear_thenIsEmpty) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "value");
dict.insert(4, "value2");
dict.clear();
@ -38,24 +39,24 @@ TEST(DictTest, givenNonemptyDict_whenCallingClear_thenIsEmpty) {
}
TEST(DictTest, whenInsertingNewKey_thenReturnsTrueAndIteratorToNewElement) {
Dict<int, string> dict;
std::pair<Dict<int, string>::iterator, bool> result = dict.insert(3, "value");
DictPtr<int, string> dict = make_dict<int, string>();
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert(3, "value");
EXPECT_TRUE(result.second);
EXPECT_EQ(3, result.first->key());
EXPECT_EQ("value", result.first->value());
}
TEST(DictTest, whenInsertingExistingKey_thenReturnsFalseAndIteratorToExistingElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "old_value");
std::pair<Dict<int, string>::iterator, bool> result = dict.insert(3, "new_value");
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert(3, "new_value");
EXPECT_FALSE(result.second);
EXPECT_EQ(3, result.first->key());
EXPECT_EQ("old_value", result.first->value());
}
TEST(DictTest, whenInsertingExistingKey_thenDoesNotModifyDict) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "old_value");
dict.insert(3, "new_value");
EXPECT_EQ(1, dict.size());
@ -64,24 +65,24 @@ TEST(DictTest, whenInsertingExistingKey_thenDoesNotModifyDict) {
}
TEST(DictTest, whenInsertOrAssigningNewKey_thenReturnsTrueAndIteratorToNewElement) {
Dict<int, string> dict;
std::pair<Dict<int, string>::iterator, bool> result = dict.insert_or_assign(3, "value");
DictPtr<int, string> dict = make_dict<int, string>();
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert_or_assign(3, "value");
EXPECT_TRUE(result.second);
EXPECT_EQ(3, result.first->key());
EXPECT_EQ("value", result.first->value());
}
TEST(DictTest, whenInsertOrAssigningExistingKey_thenReturnsFalseAndIteratorToChangedElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "old_value");
std::pair<Dict<int, string>::iterator, bool> result = dict.insert_or_assign(3, "new_value");
std::pair<DictPtr<int, string>::iterator, bool> result = dict.insert_or_assign(3, "new_value");
EXPECT_FALSE(result.second);
EXPECT_EQ(3, result.first->key());
EXPECT_EQ("new_value", result.first->value());
}
TEST(DictTest, whenInsertOrAssigningExistingKey_thenDoesModifyDict) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "old_value");
dict.insert_or_assign(3, "new_value");
EXPECT_EQ(1, dict.size());
@ -90,8 +91,8 @@ TEST(DictTest, whenInsertOrAssigningExistingKey_thenDoesModifyDict) {
}
TEST(DictTest, givenEmptyDict_whenIterating_thenBeginIsEnd) {
Dict<int, string> dict{};
const Dict<int, string> cdict{};
DictPtr<int, string> dict = make_dict<int, string>();
const DictPtr<int, string> cdict = make_dict<int, string>();
EXPECT_EQ(dict.begin(), dict.end());
EXPECT_EQ(dict.cbegin(), dict.cend());
EXPECT_EQ(cdict.begin(), cdict.end());
@ -99,12 +100,12 @@ TEST(DictTest, givenEmptyDict_whenIterating_thenBeginIsEnd) {
}
TEST(DictTest, givenMutableDict_whenIterating_thenFindsElements) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(5, "5");
bool found_first = false;
bool found_second = false;
for (Dict<int, string>::iterator iter = dict.begin(); iter != dict.end(); ++iter) {
for (DictPtr<int, string>::iterator iter = dict.begin(); iter != dict.end(); ++iter) {
if (iter->key() == 3) {
EXPECT_EQ("3", iter->value());
EXPECT_FALSE(found_first);
@ -122,7 +123,7 @@ TEST(DictTest, givenMutableDict_whenIterating_thenFindsElements) {
}
TEST(DictTest, givenMutableDict_whenIteratingWithForeach_thenFindsElements) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(5, "5");
bool found_first = false;
@ -145,13 +146,13 @@ TEST(DictTest, givenMutableDict_whenIteratingWithForeach_thenFindsElements) {
}
TEST(DictTest, givenConstDict_whenIterating_thenFindsElements) {
Dict<int, string> dict_;
DictPtr<int, string> dict_ = make_dict<int, string>();
dict_.insert(3, "3");
dict_.insert(5, "5");
const Dict<int, string>& dict = dict_;
const DictPtr<int, string>& dict = dict_;
bool found_first = false;
bool found_second = false;
for (Dict<int, string>::const_iterator iter = dict.begin(); iter != dict.end(); ++iter) {
for (DictPtr<int, string>::const_iterator iter = dict.begin(); iter != dict.end(); ++iter) {
if (iter->key() == 3) {
EXPECT_EQ("3", iter->value());
EXPECT_FALSE(found_first);
@ -169,10 +170,10 @@ TEST(DictTest, givenConstDict_whenIterating_thenFindsElements) {
}
TEST(DictTest, givenConstDict_whenIteratingWithForeach_thenFindsElements) {
Dict<int, string> dict_;
DictPtr<int, string> dict_ = make_dict<int, string>();
dict_.insert(3, "3");
dict_.insert(5, "5");
const Dict<int, string>& dict = dict_;
const DictPtr<int, string>& dict = dict_;
bool found_first = false;
bool found_second = false;
for (const auto& elem : dict) {
@ -193,28 +194,28 @@ TEST(DictTest, givenConstDict_whenIteratingWithForeach_thenFindsElements) {
}
TEST(DictTest, givenIterator_thenCanModifyValue) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "old_value");
dict.begin()->setValue("new_value");
EXPECT_EQ("new_value", dict.begin()->value());
}
TEST(DictTest, givenOneElementDict_whenErasingByConstIterator_thenDictIsEmpty) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.erase(dict.cbegin());
EXPECT_TRUE(dict.empty());
}
TEST(DictTest, givenOneElementDict_whenErasingByIterator_thenDictIsEmpty) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.erase(dict.begin());
EXPECT_TRUE(dict.empty());
}
TEST(DictTest, givenOneElementDict_whenErasingByKey_thenReturnsOneAndDictIsEmpty) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
bool result = dict.erase(3);
EXPECT_EQ(1, result);
@ -222,7 +223,7 @@ TEST(DictTest, givenOneElementDict_whenErasingByKey_thenReturnsOneAndDictIsEmpty
}
TEST(DictTest, givenOneElementDict_whenErasingByNonexistingKey_thenReturnsZeroAndDictIsUnchanged) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
bool result = dict.erase(4);
EXPECT_EQ(0, result);
@ -230,80 +231,80 @@ TEST(DictTest, givenOneElementDict_whenErasingByNonexistingKey_thenReturnsZeroAn
}
TEST(DictTest, whenCallingAtWithExistingKey_thenReturnsCorrectElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
EXPECT_EQ("4", dict.at(4));
}
TEST(DictTest, whenCallingAtWithNonExistingKey_thenReturnsCorrectElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
EXPECT_THROW(dict.at(5), std::out_of_range);
}
TEST(DictTest, givenMutableDict_whenCallingFindOnExistingKey_thenFindsCorrectElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator found = dict.find(3);
DictPtr<int, string>::iterator found = dict.find(3);
EXPECT_EQ(3, found->key());
EXPECT_EQ("3", found->value());
}
TEST(DictTest, givenMutableDict_whenCallingFindOnNonExistingKey_thenReturnsEnd) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator found = dict.find(5);
DictPtr<int, string>::iterator found = dict.find(5);
EXPECT_EQ(dict.end(), found);
}
TEST(DictTest, givenConstDict_whenCallingFindOnExistingKey_thenFindsCorrectElement) {
Dict<int, string> dict_;
DictPtr<int, string> dict_ = make_dict<int, string>();
dict_.insert(3, "3");
dict_.insert(4, "4");
const Dict<int, string>& dict = dict_;
Dict<int, string>::const_iterator found = dict.find(3);
const DictPtr<int, string>& dict = dict_;
DictPtr<int, string>::const_iterator found = dict.find(3);
EXPECT_EQ(3, found->key());
EXPECT_EQ("3", found->value());
}
TEST(DictTest, givenConstDict_whenCallingFindOnNonExistingKey_thenReturnsEnd) {
Dict<int, string> dict_;
DictPtr<int, string> dict_ = make_dict<int, string>();
dict_.insert(3, "3");
dict_.insert(4, "4");
const Dict<int, string>& dict = dict_;
Dict<int, string>::const_iterator found = dict.find(5);
const DictPtr<int, string>& dict = dict_;
DictPtr<int, string>::const_iterator found = dict.find(5);
EXPECT_EQ(dict.end(), found);
}
TEST(DictTest, whenCallingContainsWithExistingKey_thenReturnsTrue) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
EXPECT_TRUE(dict.contains(3));
}
TEST(DictTest, whenCallingContainsWithNonExistingKey_thenReturnsFalse) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
EXPECT_FALSE(dict.contains(5));
}
TEST(DictTest, whenCallingReserve_thenDoesntCrash) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.reserve(100);
}
TEST(DictTest, whenCopyConstructingDict_thenAreEqual) {
Dict<int, string> dict1;
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2(dict1);
DictPtr<int, string> dict2(dict1);
EXPECT_EQ(2, dict2.size());
EXPECT_EQ("3", dict2.at(3));
@ -311,11 +312,11 @@ TEST(DictTest, whenCopyConstructingDict_thenAreEqual) {
}
TEST(DictTest, whenCopyAssigningDict_thenAreEqual) {
Dict<int, string> dict1;
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2;
DictPtr<int, string> dict2 = make_dict<int, string>();
dict2 = dict1;
EXPECT_EQ(2, dict2.size());
@ -323,12 +324,24 @@ TEST(DictTest, whenCopyAssigningDict_thenAreEqual) {
EXPECT_EQ("4", dict2.at(4));
}
TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
Dict<int, string> dict1;
TEST(DictTest, whenCopyingDict_thenAreEqual) {
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2(std::move(dict1));
DictPtr<int, string> dict2 = dict1.copy();
EXPECT_EQ(2, dict2.size());
EXPECT_EQ("3", dict2.at(3));
EXPECT_EQ("4", dict2.at(4));
}
TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
DictPtr<int, string> dict2(std::move(dict1));
EXPECT_EQ(2, dict2.size());
EXPECT_EQ("3", dict2.at(3));
@ -336,11 +349,11 @@ TEST(DictTest, whenMoveConstructingDict_thenNewIsCorrect) {
}
TEST(DictTest, whenMoveAssigningDict_thenNewIsCorrect) {
Dict<int, string> dict1;
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2;
DictPtr<int, string> dict2 = make_dict<int, string>();
dict2 = std::move(dict1);
EXPECT_EQ(2, dict2.size());
@ -349,95 +362,95 @@ TEST(DictTest, whenMoveAssigningDict_thenNewIsCorrect) {
}
TEST(DictTest, whenMoveConstructingDict_thenOldIsEmpty) {
Dict<int, string> dict1;
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2(std::move(dict1));
DictPtr<int, string> dict2(std::move(dict1));
EXPECT_TRUE(dict1.empty());
}
TEST(DictTest, whenMoveAssigningDict_thenOldIsEmpty) {
Dict<int, string> dict1;
DictPtr<int, string> dict1 = make_dict<int, string>();
dict1.insert(3, "3");
dict1.insert(4, "4");
Dict<int, string> dict2;
DictPtr<int, string> dict2 = make_dict<int, string>();
dict2 = std::move(dict1);
EXPECT_TRUE(dict1.empty());
}
TEST(DictTest, givenMutableIterator_whenAssigningToConstIterator_thenWorks) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
Dict<int, string>::iterator iter = dict.begin();
Dict<int, string>::const_iterator const_iter = iter;
DictPtr<int, string>::iterator iter = dict.begin();
DictPtr<int, string>::const_iterator const_iter = iter;
EXPECT_EQ(3, const_iter->key());
EXPECT_EQ("3", const_iter->value());
}
TEST(DictTest, givenMutableIterator_whenPostfixIncrementing_thenMovesToNextAndReturnsOldPosition) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator iter1 = dict.begin();
Dict<int, string>::iterator iter2 = iter1++;
DictPtr<int, string>::iterator iter1 = dict.begin();
DictPtr<int, string>::iterator iter2 = iter1++;
EXPECT_NE(dict.begin()->key(), iter1->key());
EXPECT_EQ(dict.begin()->key(), iter2->key());
}
TEST(DictTest, givenConstIterator_whenPostfixIncrementing_thenMovesToNextAndReturnsOldPosition) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::const_iterator iter1 = dict.cbegin();
Dict<int, string>::const_iterator iter2 = iter1++;
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
DictPtr<int, string>::const_iterator iter2 = iter1++;
EXPECT_NE(dict.begin()->key(), iter1->key());
EXPECT_EQ(dict.begin()->key(), iter2->key());
}
TEST(DictTest, givenMutableIterator_whenPrefixIncrementing_thenMovesToNextAndReturnsNewPosition) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator iter1 = dict.begin();
Dict<int, string>::iterator iter2 = ++iter1;
DictPtr<int, string>::iterator iter1 = dict.begin();
DictPtr<int, string>::iterator iter2 = ++iter1;
EXPECT_NE(dict.begin()->key(), iter1->key());
EXPECT_NE(dict.begin()->key(), iter2->key());
}
TEST(DictTest, givenConstIterator_whenPrefixIncrementing_thenMovesToNextAndReturnsNewPosition) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::const_iterator iter1 = dict.cbegin();
Dict<int, string>::const_iterator iter2 = ++iter1;
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
DictPtr<int, string>::const_iterator iter2 = ++iter1;
EXPECT_NE(dict.begin()->key(), iter1->key());
EXPECT_NE(dict.begin()->key(), iter2->key());
}
TEST(DictTest, givenEqualMutableIterators_thenAreEqual) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator iter1 = dict.begin();
Dict<int, string>::iterator iter2 = dict.begin();
DictPtr<int, string>::iterator iter1 = dict.begin();
DictPtr<int, string>::iterator iter2 = dict.begin();
EXPECT_TRUE(iter1 == iter2);
EXPECT_FALSE(iter1 != iter2);
}
TEST(DictTest, givenDifferentMutableIterators_thenAreNotEqual) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::iterator iter1 = dict.begin();
Dict<int, string>::iterator iter2 = dict.begin();
DictPtr<int, string>::iterator iter1 = dict.begin();
DictPtr<int, string>::iterator iter2 = dict.begin();
iter2++;
EXPECT_FALSE(iter1 == iter2);
@ -445,23 +458,23 @@ TEST(DictTest, givenDifferentMutableIterators_thenAreNotEqual) {
}
TEST(DictTest, givenEqualConstIterators_thenAreEqual) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::const_iterator iter1 = dict.cbegin();
Dict<int, string>::const_iterator iter2 = dict.cbegin();
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
DictPtr<int, string>::const_iterator iter2 = dict.cbegin();
EXPECT_TRUE(iter1 == iter2);
EXPECT_FALSE(iter1 != iter2);
}
TEST(DictTest, givenDifferentConstIterators_thenAreNotEqual) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
dict.insert(4, "4");
Dict<int, string>::const_iterator iter1 = dict.cbegin();
Dict<int, string>::const_iterator iter2 = dict.cbegin();
DictPtr<int, string>::const_iterator iter1 = dict.cbegin();
DictPtr<int, string>::const_iterator iter2 = dict.cbegin();
iter2++;
EXPECT_FALSE(iter1 == iter2);
@ -469,10 +482,10 @@ TEST(DictTest, givenDifferentConstIterators_thenAreNotEqual) {
}
TEST(DictTest, givenMutableIterator_whenDereferencing_thenPointsToCorrectElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
Dict<int, string>::iterator iter = dict.begin();
DictPtr<int, string>::iterator iter = dict.begin();
EXPECT_EQ(3, (*iter).key());
EXPECT_EQ("3", (*iter).value());
EXPECT_EQ(3, iter->key());
@ -480,10 +493,10 @@ TEST(DictTest, givenMutableIterator_whenDereferencing_thenPointsToCorrectElement
}
TEST(DictTest, givenConstIterator_whenDereferencing_thenPointsToCorrectElement) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
Dict<int, string>::const_iterator iter = dict.cbegin();
DictPtr<int, string>::const_iterator iter = dict.cbegin();
EXPECT_EQ(3, (*iter).key());
EXPECT_EQ("3", (*iter).value());
EXPECT_EQ(3, iter->key());
@ -491,10 +504,10 @@ TEST(DictTest, givenConstIterator_whenDereferencing_thenPointsToCorrectElement)
}
TEST(DictTest, givenMutableIterator_whenWritingToValue_thenWorks) {
Dict<int, string> dict;
DictPtr<int, string> dict = make_dict<int, string>();
dict.insert(3, "3");
Dict<int, string>::iterator iter = dict.begin();
DictPtr<int, string>::iterator iter = dict.begin();
(*iter).setValue("new_value");
EXPECT_EQ("new_value", dict.begin()->value());
@ -502,3 +515,27 @@ TEST(DictTest, givenMutableIterator_whenWritingToValue_thenWorks) {
iter->setValue("new_value_2");
EXPECT_EQ("new_value_2", dict.begin()->value());
}
TEST(DictTest, isReferenceType) {
DictPtr<int, string> dict1 = make_dict<int, string>();
DictPtr<int, string> dict2(dict1);
DictPtr<int, string> dict3 = make_dict<int, string>();
dict3 = dict1;
dict1.insert(3, "three");
EXPECT_EQ(1, dict1.size());
EXPECT_EQ(1, dict2.size());
EXPECT_EQ(1, dict3.size());
}
TEST(DictTest, copyHasSeparateStorage) {
DictPtr<int, string> dict1 = make_dict<int, string>();
DictPtr<int, string> dict2(dict1.copy());
DictPtr<int, string> dict3 = make_dict<int, string>();
dict3 = dict1.copy();
dict1.insert(3, "three");
EXPECT_EQ(1, dict1.size());
EXPECT_EQ(0, dict2.size());
EXPECT_EQ(0, dict3.size());
}

View File

@ -12,7 +12,7 @@ struct Function;
} // namespace jit
} // namespace torch
namespace c10 {
template<class Key, class Value> class Dict;
template<class Key, class Value> class DictPtr;
struct IValue;
namespace ivalue {
struct Tuple;
@ -220,7 +220,7 @@ struct CAFFE2_API IValue final {
const std::vector<bool>& toBoolListRef() const;
const std::vector<at::Tensor>& toTensorListRef() const;
const std::vector<IValue>& toGenericListRef() const;
const c10::Dict<IValue, IValue>& toGenericDictRef() const;
const c10::DictPtr<IValue, IValue>& toGenericDictRef() const;
const std::string& toStringRef() const;
// ConstantString
@ -261,7 +261,7 @@ struct CAFFE2_API IValue final {
// GenericDict
IValue(c10::intrusive_ptr<ivalue::GenericDict> v);
IValue(c10::Dict<IValue, IValue> v);
IValue(c10::DictPtr<IValue, IValue> v);
bool isGenericDict() const { return Tag::GenericDict == tag; }
c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() &&;
c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() const &;

View File

@ -401,19 +401,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
private:
c10::impl::GenericDict elements_;
c10::impl::GenericDictPtr elements_;
public:
GenericDict(c10::impl::GenericDict elements_)
GenericDict(c10::impl::GenericDictPtr elements_)
: elements_(std::move(elements_)) {}
static c10::intrusive_ptr<GenericDict> create(
c10::impl::GenericDict elements_) {
c10::impl::GenericDictPtr elements_) {
return c10::make_intrusive<GenericDict>(std::move(elements_));
}
const c10::impl::GenericDict& elements() const & {
const c10::impl::GenericDictPtr& elements() const & {
return elements_;
}
c10::impl::GenericDict& elements() & {
c10::impl::GenericDictPtr& elements() & {
return elements_;
}
@ -570,7 +570,7 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericDict> v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::impl::GenericDict v)
inline IValue::IValue(c10::impl::GenericDictPtr v)
: IValue(ivalue::GenericDict::create(std::move(v))) {}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
@ -602,7 +602,7 @@ inline const std::vector<IValue>& IValue::toGenericListRef() const {
return toGenericList()->elements();
}
inline const c10::impl::GenericDict& IValue::
inline const c10::impl::GenericDictPtr& IValue::
toGenericDictRef() const {
return toGenericDict()->elements();
}

View File

@ -1280,7 +1280,7 @@ struct getTypePtr_<std::unordered_map<K, V>> final {
}
};
template <class K, class V>
struct getTypePtr_<c10::Dict<K, V>> final {
struct getTypePtr_<c10::DictPtr<K, V>> final {
static TypePtr call() {
static auto type =
DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());

View File

@ -25,7 +25,8 @@ using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
using at::Tensor;
using std::string;
using std::unique_ptr;
@ -209,11 +210,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListO
EXPECT_EQ(6, result[0].toIntListRef()[2]);
}
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
Dict<string, Tensor> dict;
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("first", dummyTensor(TensorType1()));
dict.insert("second", dummyTensor(TensorType2()));
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
dummyTensor(TensorType2()),
5,
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
@ -516,7 +517,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithStringLi
int captured_dict_size = 0;
void kernelWithDictInputWithoutOutput(Dict<string, Tensor> input1) {
void kernelWithDictInputWithoutOutput(DictPtr<string, Tensor> input1) {
captured_dict_size = input1.size();
}
@ -528,7 +529,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
ASSERT_TRUE(op.has_value());
captured_dict_size = 0;
Dict<string, Tensor> dict;
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("key1", dummyTensor(TensorType1()));
dict.insert("key2", dummyTensor(TensorType2()));
auto outputs = callOp(*op, dict);
@ -536,7 +537,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
EXPECT_EQ(2, captured_dict_size);
}
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
string kernelWithDictInputWithOutput(DictPtr<string, string> input1) {
return input1.at("key2");
}
@ -547,7 +548,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);
@ -555,7 +556,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu
EXPECT_EQ("value2", outputs[0].toString()->string());
}
Dict<string, string> kernelWithDictOutput(Dict<string, string> input) {
DictPtr<string, string> kernelWithDictOutput(DictPtr<string, string> input) {
return input;
}
@ -566,7 +567,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictOutp
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);

View File

@ -13,7 +13,8 @@ using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
using at::Tensor;
using std::string;
using std::unique_ptr;
@ -206,11 +207,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_
EXPECT_EQ(6, result[0].toIntListRef()[2]);
}
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
Dict<string, Tensor> dict;
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> kernelWithMultipleOutputs(Tensor) {
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("first", dummyTensor(TensorType1()));
dict.insert("second", dummyTensor(TensorType2()));
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
dummyTensor(TensorType2()),
5,
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
@ -431,7 +432,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
int captured_dict_size = 0;
void kernelWithDictInputWithoutOutput(Dict<string, Tensor> input1) {
void kernelWithDictInputWithoutOutput(DictPtr<string, Tensor> input1) {
captured_dict_size = input1.size();
}
@ -443,7 +444,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
ASSERT_TRUE(op.has_value());
captured_dict_size = 0;
Dict<string, Tensor> dict;
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("key1", dummyTensor(TensorType1()));
dict.insert("key2", dummyTensor(TensorType2()));
auto outputs = callOp(*op, dict);
@ -451,7 +452,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
EXPECT_EQ(2, captured_dict_size);
}
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
string kernelWithDictInputWithOutput(DictPtr<string, string> input1) {
return input1.at("key2");
}
@ -462,7 +463,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);
@ -470,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with
EXPECT_EQ("value2", outputs[0].toString()->string());
}
Dict<string, string> kernelWithDictOutput(Dict<string, string> input) {
DictPtr<string, string> kernelWithDictOutput(DictPtr<string, string> input) {
return input;
}
@ -481,7 +482,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictOutput_whe
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);

View File

@ -60,10 +60,10 @@ namespace detail {
}
};
template<class Key, class Value>
struct ivalue_to_arg_type<Dict<Key, Value>> {
static Dict<Key, Value> call(IValue&& v) {
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported argument type: c10::Dict<Key, Value> and Key is not one of the supported primitive types.");
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported argument type: c10::Dict<Key, Value> and Value is not one of the supported primitive types.");
struct ivalue_to_arg_type<DictPtr<Key, Value>> {
static DictPtr<Key, Value> call(IValue&& v) {
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported argument type: c10::DictPtr<Key, Value> and Key is not one of the supported primitive types.");
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported argument type: c10::DictPtr<Key, Value> and Value is not one of the supported primitive types.");
auto dict_ptr = std::move(v).toGenericDict();
return impl::toTypedDict<Key, Value>(std::move(dict_ptr->elements()));
@ -169,17 +169,17 @@ namespace detail {
}
};
template<class Key, class Value>
struct return_type_to_ivalue<c10::Dict<Key, Value>> {
static IValue call(c10::Dict<Key, Value>&& v) {
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported return type: Dict<Key, Value> and Key is not one of the supported primitive types.");
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported return type: Dict<Key, Value> and Value is not one of the supported primitive types.");
struct return_type_to_ivalue<c10::DictPtr<Key, Value>> {
static IValue call(c10::DictPtr<Key, Value>&& v) {
static_assert(guts::typelist::contains<supported_primitive_arg_types, Key>::value, "You tried to register a kernel with an unsupported return type: DictPtr<Key, Value> and Key is not one of the supported primitive types.");
static_assert(guts::typelist::contains<supported_primitive_arg_types, Value>::value, "You tried to register a kernel with an unsupported return type: DictPtr<Key, Value> and Value is not one of the supported primitive types.");
return IValue(impl::toGenericDict(std::move(v)));
}
};
template<class Key, class Value>
struct return_type_to_ivalue<std::unordered_map<Key, Value>> final {
static IValue call(std::unordered_map<Key, Value>&& v) {
c10::impl::GenericDict dict;
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
dict.reserve(v.size());
for (auto& element : v) {
dict.insert(return_type_to_ivalue<Key>::call(Key{element.first}), return_type_to_ivalue<Value>::call(std::move(element.second)));

View File

@ -14,7 +14,8 @@ using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
using at::Tensor;
using std::unique_ptr;
using std::string;
@ -226,11 +227,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_w
}
struct KernelWithMultipleOutputs final : OperatorKernel {
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> operator()(Tensor) {
Dict<string, Tensor> dict;
std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> operator()(Tensor) {
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("first", dummyTensor(TensorType1()));
dict.insert("second", dummyTensor(TensorType2()));
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
dummyTensor(TensorType2()),
5,
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
@ -473,7 +474,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput
int captured_dict_size = 0;
struct KernelWithDictInputWithoutOutput final : OperatorKernel {
void operator()(Dict<string, Tensor> input1) {
void operator()(DictPtr<string, Tensor> input1) {
captured_dict_size = input1.size();
}
};
@ -486,7 +487,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_witho
ASSERT_TRUE(op.has_value());
captured_dict_size = 0;
Dict<string, Tensor> dict;
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("key1", dummyTensor(TensorType1()));
dict.insert("key2", dummyTensor(TensorType2()));
auto outputs = callOp(*op, dict);
@ -495,7 +496,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_witho
}
struct KernelWithDictInputWithOutput final : OperatorKernel {
string operator()(Dict<string, string> input1) {
string operator()(DictPtr<string, string> input1) {
return input1.at("key2");
}
};
@ -507,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_withO
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);
@ -516,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_withO
}
struct KernelWithDictOutput final : OperatorKernel {
Dict<string, string> operator()(Dict<string, string> input) {
DictPtr<string, string> operator()(DictPtr<string, string> input) {
return input;
}
};
@ -528,7 +529,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictOutput_when
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);

View File

@ -24,7 +24,8 @@ using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
using at::Tensor;
using std::string;
using std::unique_ptr;
@ -192,11 +193,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOut
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> {
Dict<string, Tensor> dict;
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> {
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("first", dummyTensor(TensorType1()));
dict.insert("second", dummyTensor(TensorType2()));
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
dummyTensor(TensorType2()),
5,
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
@ -468,7 +469,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
int captured_dict_size = 0;
auto registrar = RegisterOperators()
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", [&] (Dict<string, Tensor> input1) {
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", [&] (DictPtr<string, Tensor> input1) {
captured_dict_size = input1.size();
});
@ -476,7 +477,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
ASSERT_TRUE(op.has_value());
captured_dict_size = 0;
Dict<string, Tensor> dict;
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("key1", dummyTensor(TensorType1()));
dict.insert("key2", dummyTensor(TensorType2()));
auto outputs = callOp(*op, dict);
@ -486,14 +487,14 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::dict_input(Dict(str, str) input) -> str", [&] (Dict<string, string> input1) {
.op("_test::dict_input(Dict(str, str) input) -> str", [&] (DictPtr<string, string> input1) {
return input1.at("key2");
});
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);
@ -503,14 +504,14 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", [] (Dict<string, string> input) {
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", [] (DictPtr<string, string> input) {
return input;
});
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);

View File

@ -13,7 +13,8 @@ using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::Dict;
using c10::DictPtr;
using c10::make_dict;
using at::Tensor;
using std::string;
using std::unique_ptr;
@ -192,11 +193,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_wh
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))",
RegisterOperators::options().kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>> {
Dict<string, Tensor> dict;
RegisterOperators::options().kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>> {
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("first", dummyTensor(TensorType1()));
dict.insert("second", dummyTensor(TensorType2()));
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, Dict<string, Tensor>>(
return std::tuple<Tensor, int64_t, std::vector<Tensor>, c10::optional<int64_t>, DictPtr<string, Tensor>>(
dummyTensor(TensorType2()),
5,
{dummyTensor(TensorType1()), dummyTensor(TensorType2())},
@ -404,7 +405,7 @@ int captured_dict_size = 0;
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", RegisterOperators::options().kernel([] (Dict<string, Tensor> input1) {
.op("_test::dict_input(Dict(str, Tensor) input) -> ()", RegisterOperators::options().kernel([] (DictPtr<string, Tensor> input1) {
captured_dict_size = input1.size();
}));
@ -412,7 +413,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withou
ASSERT_TRUE(op.has_value());
captured_dict_size = 0;
Dict<string, Tensor> dict;
DictPtr<string, Tensor> dict = make_dict<string, Tensor>();
dict.insert("key1", dummyTensor(TensorType1()));
dict.insert("key2", dummyTensor(TensorType2()));
auto outputs = callOp(*op, dict);
@ -422,14 +423,14 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withou
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::dict_input(Dict(str, str) input) -> str", RegisterOperators::options().kernel([] (Dict<string, string> input1) {
.op("_test::dict_input(Dict(str, str) input) -> str", RegisterOperators::options().kernel([] (DictPtr<string, string> input1) {
return input1.at("key2");
}));
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_input", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);
@ -439,14 +440,14 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withOu
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", RegisterOperators::options().kernel([] (Dict<string, string> input) {
.op("_test::dict_output(Dict(str, str) input) -> Dict(str, str)", RegisterOperators::options().kernel([] (DictPtr<string, string> input) {
return input;
}));
auto op = c10::Dispatcher::singleton().findSchema("_test::dict_output", "");
ASSERT_TRUE(op.has_value());
Dict<string, string> dict;
DictPtr<string, string> dict = make_dict<string, string>();
dict.insert("key1", "value1");
dict.insert("key2", "value2");
auto outputs = callOp(*op, dict);

View File

@ -638,33 +638,33 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
// TODO Do we want to support list of optional ?
// dict types
c10::Dict<std::string, std::string> str_dict;
c10::DictPtr<std::string, std::string> str_dict = c10::make_dict<std::string, std::string>();
str_dict.insert("key1", "value1");
str_dict.insert("key2", "value2");
testArgTypes<c10::Dict<std::string, std::string>>::test(
str_dict, [] (c10::Dict<std::string, std::string> v) {
testArgTypes<c10::DictPtr<std::string, std::string>>::test(
str_dict, [] (c10::DictPtr<std::string, std::string> v) {
EXPECT_EQ(2, v.size());
EXPECT_EQ("value1", v.at("key1"));
EXPECT_EQ("value2", v.at("key2"));
},
str_dict, [] (const IValue& v) {
c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(std::move(v.toGenericDict()->elements()));
c10::DictPtr<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(std::move(v.toGenericDict()->elements()));
EXPECT_EQ(2, dict.size());
EXPECT_EQ("value1", dict.at("key1"));
EXPECT_EQ("value2", dict.at("key2"));
},
"(Dict(str, str) a) -> Dict(str, str)");
c10::Dict<int64_t, Tensor> tensor_dict;
c10::DictPtr<int64_t, Tensor> tensor_dict = c10::make_dict<int64_t, Tensor>();
tensor_dict.insert(1, dummyTensor(TensorType1()));
tensor_dict.insert(2, dummyTensor(TensorType2()));
testArgTypes<c10::Dict<int64_t, Tensor>>::test(
tensor_dict, [] (c10::Dict<int64_t, Tensor> v) {
testArgTypes<c10::DictPtr<int64_t, Tensor>>::test(
tensor_dict, [] (c10::DictPtr<int64_t, Tensor> v) {
EXPECT_EQ(2, v.size());
EXPECT_EQ(TensorType1(), v.at(1).type_id());
EXPECT_EQ(TensorType2(), v.at(2).type_id());
},
tensor_dict, [] (const IValue& v) {
c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(std::move(v.toGenericDict()->elements()));
c10::DictPtr<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(std::move(v.toGenericDict()->elements()));
EXPECT_EQ(2, dict.size());
EXPECT_EQ(TensorType1(), dict.at(1).type_id());
EXPECT_EQ(TensorType2(), dict.at(2).type_id());

View File

@ -60,7 +60,7 @@ struct InputToIValue<std::vector<std::string>> final {
}
};
template<class Key, class Value>
struct InputToIValue<c10::Dict<Key, Value>> final {
struct InputToIValue<c10::DictPtr<Key, Value>> final {
template<class T_>
static c10::IValue call(T_&& v) {
return c10::IValue(c10::impl::toGenericDict(std::move(v)));
@ -70,7 +70,7 @@ template<class Key, class Value>
struct InputToIValue<std::unordered_map<Key, Value>> final {
template<class T_>
static c10::IValue call(T_&& v) {
c10::impl::GenericDict dict;
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
dict.reserve(v.size());
for (auto& element : v) {
dict.insert(InputToIValue<Key>::call(element.first), InputToIValue<Value>::call(element.second));

View File

@ -227,7 +227,7 @@ bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
auto dict_type = type->expect<DictType>();
const auto& dict = ivalue.toGenericDictRef();
return std::all_of(
dict.begin(), dict.end(), [=](const c10::impl::GenericDict::const_iterator::value_type& item) {
dict.begin(), dict.end(), [=](const c10::impl::GenericDictPtr::const_iterator::value_type& item) {
return isSubvalueOf(item.key(), dict_type->getKeyType()) &&
isSubvalueOf(item.value(), dict_type->getValueType());
});

View File

@ -93,7 +93,7 @@ TEST(TorchScriptTest, TestDictArgMatching) {
def dict_op(a: Dict[str, Tensor], b: str):
return a[b]
)JIT");
c10::impl::GenericDict dict;
c10::impl::GenericDictPtr dict = c10::impl::make_generic_dict();
dict.insert("hello", torch::ones({2}));
auto output = module->run_method("dict_op", dict, std::string("hello"));
ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());

View File

@ -667,7 +667,7 @@ OpCode Unpickler::readInstruction() {
stack_.emplace_back(IValue(tuple));
} break;
case OpCode::EMPTY_DICT:
stack_.emplace_back(c10::impl::GenericDict());
stack_.emplace_back(c10::impl::make_generic_dict());
break;
case OpCode::APPENDS: {
readList();

View File

@ -115,7 +115,7 @@ inline TypedIValue toTypedIValue(py::handle input) {
} else if (PyDict_Check(input.ptr())) {
// Check to make sure we can generate useful input/output types
auto dict = py::cast<py::dict>(input);
c10::impl::GenericDict elems;
c10::impl::GenericDictPtr elems = c10::impl::make_generic_dict();
size_t len = py::len(dict);
if (!len) {
@ -205,7 +205,7 @@ inline IValue createGenericDict(
py::handle obj,
const TypePtr& key_type,
const TypePtr& value_type) {
c10::impl::GenericDict elems;
c10::impl::GenericDictPtr elems = c10::impl::make_generic_dict();
elems.reserve(py::len(obj));
for (auto key : obj) {
elems.insert(

View File

@ -875,7 +875,7 @@ RegisterOperators reg(
"DictConstruct must have an even number of inputs");
}
return [=](Stack& stack) {
c10::impl::GenericDict vals;
c10::impl::GenericDictPtr vals = c10::impl::make_generic_dict();
for (size_t i = 0; i < num_inputs; i += 2) {
auto val = pop(stack);
auto key = pop(stack);