mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636 Approved by: https://github.com/yewentao256, https://github.com/mlazos ghstack dependencies: #156311, #156609
897 lines
29 KiB
Python
897 lines
29 KiB
Python
# Owner(s): ["module: meta tensors"]
|
|
|
|
import copy
|
|
import gc
|
|
import random
|
|
import threading
|
|
import unittest
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
|
|
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
|
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
|
|
|
|
|
|
def C():
    """Build a brand-new one-element random tensor.

    Tests use the result as a throwaway key for WeakIdKeyDictionary.
    """
    fresh_key = torch.randn(1)
    return fresh_key


# These tests are ported from cpython/Lib/test/test_weakref.py,
|
|
# but adapted to use tensor rather than object
|
|
class WeakTest(TestCase):
|
|
COUNT = 10
|
|
|
|
def test_make_weak_keyed_dict_from_dict(self):
|
|
o = torch.randn(2)
|
|
dict = WeakIdKeyDictionary({o: 364})
|
|
self.assertEqual(dict[o], 364)
|
|
|
|
def test_make_weak_keyed_dict_from_weak_keyed_dict(self):
|
|
o = torch.randn(3)
|
|
dict = WeakIdKeyDictionary({o: 364})
|
|
self.assertEqual(dict[o], 364)
|
|
dict2 = WeakIdKeyDictionary(dict)
|
|
self.assertEqual(dict2[o], 364)
|
|
|
|
def check_popitem(self, klass, key1, value1, key2, value2):
|
|
weakdict = klass()
|
|
weakdict[key1] = value1
|
|
weakdict[key2] = value2
|
|
self.assertEqual(len(weakdict), 2)
|
|
k, v = weakdict.popitem()
|
|
self.assertEqual(len(weakdict), 1)
|
|
if k is key1:
|
|
self.assertIs(v, value1)
|
|
else:
|
|
self.assertIs(v, value2)
|
|
k, v = weakdict.popitem()
|
|
self.assertEqual(len(weakdict), 0)
|
|
if k is key1:
|
|
self.assertIs(v, value1)
|
|
else:
|
|
self.assertIs(v, value2)
|
|
|
|
def test_weak_keyed_dict_popitem(self):
|
|
self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2")
|
|
|
|
def check_setdefault(self, klass, key, value1, value2):
|
|
self.assertIsNot(
|
|
value1,
|
|
value2,
|
|
"invalid test -- value parameters must be distinct objects",
|
|
)
|
|
weakdict = klass()
|
|
o = weakdict.setdefault(key, value1)
|
|
self.assertIs(o, value1)
|
|
self.assertIn(key, weakdict)
|
|
self.assertIs(weakdict.get(key), value1)
|
|
self.assertIs(weakdict[key], value1)
|
|
|
|
o = weakdict.setdefault(key, value2)
|
|
self.assertIs(o, value1)
|
|
self.assertIn(key, weakdict)
|
|
self.assertIs(weakdict.get(key), value1)
|
|
self.assertIs(weakdict[key], value1)
|
|
|
|
def test_weak_keyed_dict_setdefault(self):
|
|
self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2")
|
|
|
|
def check_update(self, klass, dict):
|
|
#
|
|
# This exercises d.update(), len(d), d.keys(), k in d,
|
|
# d.get(), d[].
|
|
#
|
|
weakdict = klass()
|
|
weakdict.update(dict)
|
|
self.assertEqual(len(weakdict), len(dict))
|
|
for k in weakdict.keys():
|
|
self.assertIn(k, dict, "mysterious new key appeared in weak dict")
|
|
v = dict.get(k)
|
|
self.assertIs(v, weakdict[k])
|
|
self.assertIs(v, weakdict.get(k))
|
|
for k in dict.keys():
|
|
self.assertIn(k, weakdict, "original key disappeared in weak dict")
|
|
v = dict[k]
|
|
self.assertIs(v, weakdict[k])
|
|
self.assertIs(v, weakdict.get(k))
|
|
|
|
def test_weak_keyed_dict_update(self):
|
|
self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3})
|
|
|
|
def test_weak_keyed_delitem(self):
|
|
d = WeakIdKeyDictionary()
|
|
o1 = torch.randn(1)
|
|
o2 = torch.randn(2)
|
|
d[o1] = "something"
|
|
d[o2] = "something"
|
|
self.assertEqual(len(d), 2)
|
|
del d[o1]
|
|
self.assertEqual(len(d), 1)
|
|
self.assertEqual(list(d.keys()), [o2])
|
|
|
|
def test_weak_keyed_union_operators(self):
|
|
try:
|
|
{} | {}
|
|
except TypeError:
|
|
self.skipTest("dict union not supported in this Python")
|
|
|
|
o1 = C()
|
|
o2 = C()
|
|
o3 = C()
|
|
wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2})
|
|
wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4})
|
|
wkd3 = wkd1.copy()
|
|
d1 = {o2: "5", o3: "6"}
|
|
pairs = [(o2, 7), (o3, 8)]
|
|
|
|
tmp1 = wkd1 | wkd2 # Between two WeakKeyDictionaries
|
|
self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2))
|
|
self.assertIs(type(tmp1), WeakIdKeyDictionary)
|
|
wkd1 |= wkd2
|
|
self.assertEqual(wkd1, tmp1)
|
|
|
|
tmp2 = wkd2 | d1 # Between WeakKeyDictionary and mapping
|
|
self.assertEqual(dict(tmp2), dict(wkd2) | d1)
|
|
self.assertIs(type(tmp2), WeakIdKeyDictionary)
|
|
wkd2 |= d1
|
|
self.assertEqual(wkd2, tmp2)
|
|
|
|
tmp3 = wkd3.copy() # Between WeakKeyDictionary and iterable key, value
|
|
tmp3 |= pairs
|
|
self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs))
|
|
self.assertIs(type(tmp3), WeakIdKeyDictionary)
|
|
|
|
tmp4 = d1 | wkd3 # Testing .__ror__
|
|
self.assertEqual(dict(tmp4), d1 | dict(wkd3))
|
|
self.assertIs(type(tmp4), WeakIdKeyDictionary)
|
|
|
|
del o1
|
|
self.assertNotIn(4, tmp1.values())
|
|
self.assertNotIn(4, tmp2.values())
|
|
self.assertNotIn(1, tmp3.values())
|
|
self.assertNotIn(1, tmp4.values())
|
|
|
|
def test_weak_keyed_bad_delitem(self):
|
|
d = WeakIdKeyDictionary()
|
|
o = torch.randn(1)
|
|
# An attempt to delete an object that isn't there should raise
|
|
# KeyError. It didn't before 2.3.
|
|
self.assertRaises(KeyError, d.__delitem__, o)
|
|
self.assertRaises(KeyError, d.__getitem__, o)
|
|
|
|
# If a key isn't of a weakly referenceable type, __getitem__ and
|
|
# __setitem__ raise TypeError. __delitem__ should too.
|
|
self.assertRaises(TypeError, d.__delitem__, 13)
|
|
self.assertRaises(TypeError, d.__getitem__, 13)
|
|
self.assertRaises(TypeError, d.__setitem__, 13, 13)
|
|
|
|
def test_make_weak_keyed_dict_repr(self):
|
|
dict = WeakIdKeyDictionary()
|
|
self.assertRegex(repr(dict), "<WeakIdKeyDictionary at 0x.*>")
|
|
|
|
def check_threaded_weak_dict_copy(self, type_, deepcopy):
|
|
# `deepcopy` should be either True or False.
|
|
exc = []
|
|
|
|
# Cannot give these slots as weakrefs weren't supported
|
|
# on these objects until later versions of Python
|
|
class DummyKey: # noqa: B903
|
|
def __init__(self, ctr):
|
|
self.ctr = ctr
|
|
|
|
class DummyValue: # noqa: B903
|
|
def __init__(self, ctr):
|
|
self.ctr = ctr
|
|
|
|
def dict_copy(d, exc):
|
|
try:
|
|
if deepcopy is True:
|
|
_ = copy.deepcopy(d)
|
|
else:
|
|
_ = d.copy()
|
|
except Exception as ex:
|
|
exc.append(ex)
|
|
|
|
def pop_and_collect(lst):
|
|
gc_ctr = 0
|
|
while lst:
|
|
i = random.randint(0, len(lst) - 1)
|
|
gc_ctr += 1
|
|
lst.pop(i)
|
|
if gc_ctr % 10000 == 0:
|
|
gc.collect() # just in case
|
|
|
|
d = type_()
|
|
keys = []
|
|
values = []
|
|
# Initialize d with many entries
|
|
for i in range(70000):
|
|
k, v = DummyKey(i), DummyValue(i)
|
|
keys.append(k)
|
|
values.append(v)
|
|
d[k] = v
|
|
del k
|
|
del v
|
|
|
|
t_copy = threading.Thread(target=dict_copy, args=(d, exc))
|
|
t_collect = threading.Thread(target=pop_and_collect, args=(keys,))
|
|
|
|
t_copy.start()
|
|
t_collect.start()
|
|
|
|
t_copy.join()
|
|
t_collect.join()
|
|
|
|
# Test exceptions
|
|
if exc:
|
|
raise exc[0]
|
|
|
|
def test_threaded_weak_key_dict_copy(self):
|
|
# Issue #35615: Weakref keys or values getting GC'ed during dict
|
|
# copying should not result in a crash.
|
|
self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False)
|
|
|
|
def test_threaded_weak_key_dict_deepcopy(self):
|
|
# Issue #35615: Weakref keys or values getting GC'ed during dict
|
|
# copying should not result in a crash.
|
|
self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True)
|
|
|
|
|
|
# Adapted from cpython/Lib/test/mapping_tests.py
|
|
class WeakKeyDictionaryTestCase(TestCase):
|
|
__ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3}
|
|
type2test = WeakIdKeyDictionary
|
|
|
|
def _reference(self):
|
|
return self.__ref.copy()
|
|
|
|
def _empty_mapping(self):
|
|
"""Return an empty mapping object"""
|
|
return self.type2test()
|
|
|
|
def _full_mapping(self, data):
|
|
"""Return a mapping object with the value contained in data
|
|
dictionary"""
|
|
x = self._empty_mapping()
|
|
for key, value in data.items():
|
|
x[key] = value
|
|
return x
|
|
|
|
def __init__(self, *args, **kw):
|
|
unittest.TestCase.__init__(self, *args, **kw)
|
|
self.reference = self._reference().copy()
|
|
|
|
# A (key, value) pair not in the mapping
|
|
key, value = self.reference.popitem()
|
|
self.other = {key: value}
|
|
|
|
# A (key, value) pair in the mapping
|
|
key, value = self.reference.popitem()
|
|
self.inmapping = {key: value}
|
|
self.reference[key] = value
|
|
|
|
def test_read(self):
|
|
# Test for read only operations on mapping
|
|
p = self._empty_mapping()
|
|
p1 = dict(p) # workaround for singleton objects
|
|
d = self._full_mapping(self.reference)
|
|
if d is p:
|
|
p = p1
|
|
# Indexing
|
|
for key, value in self.reference.items():
|
|
self.assertEqual(d[key], value)
|
|
knownkey = next(iter(self.other.keys()))
|
|
self.assertRaises(KeyError, lambda: d[knownkey])
|
|
# len
|
|
self.assertEqual(len(p), 0)
|
|
self.assertEqual(len(d), len(self.reference))
|
|
# __contains__
|
|
for k in self.reference:
|
|
self.assertIn(k, d)
|
|
for k in self.other:
|
|
self.assertNotIn(k, d)
|
|
# cmp
|
|
self.assertTrue(
|
|
p == p
|
|
) # NB: don't use assertEqual, that doesn't actually use ==
|
|
self.assertTrue(d == d)
|
|
self.assertTrue(p != d)
|
|
self.assertTrue(d != p)
|
|
# bool
|
|
if p:
|
|
self.fail("Empty mapping must compare to False")
|
|
if not d:
|
|
self.fail("Full mapping must compare to True")
|
|
|
|
# keys(), items(), iterkeys() ...
|
|
def check_iterandlist(iter, lst, ref):
|
|
self.assertTrue(hasattr(iter, "__next__"))
|
|
self.assertTrue(hasattr(iter, "__iter__"))
|
|
x = list(iter)
|
|
self.assertTrue(set(x) == set(lst) == set(ref))
|
|
|
|
check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
|
|
check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
|
|
check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
|
|
check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
|
|
# get
|
|
key, value = next(iter(d.items()))
|
|
knownkey, knownvalue = next(iter(self.other.items()))
|
|
self.assertEqual(d.get(key, knownvalue), value)
|
|
self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
|
|
self.assertNotIn(knownkey, d)
|
|
|
|
def test_write(self):
|
|
# Test for write operations on mapping
|
|
p = self._empty_mapping()
|
|
# Indexing
|
|
for key, value in self.reference.items():
|
|
p[key] = value
|
|
self.assertEqual(p[key], value)
|
|
for key in self.reference.keys():
|
|
del p[key]
|
|
self.assertRaises(KeyError, lambda: p[key])
|
|
p = self._empty_mapping()
|
|
# update
|
|
p.update(self.reference)
|
|
self.assertEqual(dict(p), self.reference)
|
|
items = list(p.items())
|
|
p = self._empty_mapping()
|
|
p.update(items)
|
|
self.assertEqual(dict(p), self.reference)
|
|
d = self._full_mapping(self.reference)
|
|
# setdefault
|
|
key, value = next(iter(d.items()))
|
|
knownkey, knownvalue = next(iter(self.other.items()))
|
|
self.assertEqual(d.setdefault(key, knownvalue), value)
|
|
self.assertEqual(d[key], value)
|
|
self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
|
|
self.assertEqual(d[knownkey], knownvalue)
|
|
# pop
|
|
self.assertEqual(d.pop(knownkey), knownvalue)
|
|
self.assertNotIn(knownkey, d)
|
|
self.assertRaises(KeyError, d.pop, knownkey)
|
|
default = 909
|
|
d[knownkey] = knownvalue
|
|
self.assertEqual(d.pop(knownkey, default), knownvalue)
|
|
self.assertNotIn(knownkey, d)
|
|
self.assertEqual(d.pop(knownkey, default), default)
|
|
# popitem
|
|
key, value = d.popitem()
|
|
self.assertNotIn(key, d)
|
|
self.assertEqual(value, self.reference[key])
|
|
p = self._empty_mapping()
|
|
self.assertRaises(KeyError, p.popitem)
|
|
|
|
def test_constructor(self):
|
|
self.assertEqual(self._empty_mapping(), self._empty_mapping())
|
|
|
|
def test_bool(self):
|
|
self.assertTrue(not self._empty_mapping())
|
|
self.assertTrue(self.reference)
|
|
self.assertTrue(bool(self._empty_mapping()) is False)
|
|
self.assertTrue(bool(self.reference) is True)
|
|
|
|
def test_keys(self):
|
|
d = self._empty_mapping()
|
|
self.assertEqual(list(d.keys()), [])
|
|
d = self.reference
|
|
self.assertIn(next(iter(self.inmapping.keys())), d.keys())
|
|
self.assertNotIn(next(iter(self.other.keys())), d.keys())
|
|
self.assertRaises(TypeError, d.keys, None)
|
|
|
|
def test_values(self):
|
|
d = self._empty_mapping()
|
|
self.assertEqual(list(d.values()), [])
|
|
|
|
self.assertRaises(TypeError, d.values, None)
|
|
|
|
def test_items(self):
|
|
d = self._empty_mapping()
|
|
self.assertEqual(list(d.items()), [])
|
|
|
|
self.assertRaises(TypeError, d.items, None)
|
|
|
|
def test_len(self):
|
|
d = self._empty_mapping()
|
|
self.assertEqual(len(d), 0)
|
|
|
|
def test_getitem(self):
|
|
d = self.reference
|
|
self.assertEqual(
|
|
d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values()))
|
|
)
|
|
|
|
self.assertRaises(TypeError, d.__getitem__)
|
|
|
|
def test_update(self):
|
|
# mapping argument
|
|
d = self._empty_mapping()
|
|
d.update(self.other)
|
|
self.assertEqual(list(d.items()), list(self.other.items()))
|
|
|
|
# No argument
|
|
d = self._empty_mapping()
|
|
d.update()
|
|
self.assertEqual(d, self._empty_mapping())
|
|
|
|
# item sequence
|
|
d = self._empty_mapping()
|
|
d.update(self.other.items())
|
|
self.assertEqual(list(d.items()), list(self.other.items()))
|
|
|
|
# Iterator
|
|
d = self._empty_mapping()
|
|
d.update(self.other.items())
|
|
self.assertEqual(list(d.items()), list(self.other.items()))
|
|
|
|
# FIXME: Doesn't work with UserDict
|
|
# self.assertRaises((TypeError, AttributeError), d.update, None)
|
|
self.assertRaises((TypeError, AttributeError), d.update, 42)
|
|
|
|
outerself = self
|
|
|
|
class SimpleUserDict:
|
|
def __init__(self) -> None:
|
|
self.d = outerself.reference
|
|
|
|
def keys(self):
|
|
return self.d.keys()
|
|
|
|
def __getitem__(self, i):
|
|
return self.d[i]
|
|
|
|
d.clear()
|
|
d.update(SimpleUserDict())
|
|
i1 = sorted((id(k), v) for k, v in d.items())
|
|
i2 = sorted((id(k), v) for k, v in self.reference.items())
|
|
self.assertEqual(i1, i2)
|
|
|
|
class Exc(Exception):
|
|
pass
|
|
|
|
d = self._empty_mapping()
|
|
|
|
class FailingUserDict:
|
|
def keys(self):
|
|
raise Exc
|
|
|
|
self.assertRaises(Exc, d.update, FailingUserDict())
|
|
|
|
d.clear()
|
|
|
|
class FailingUserDict:
|
|
def keys(self):
|
|
class BogonIter:
|
|
def __init__(self) -> None:
|
|
self.i = 1
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.i:
|
|
self.i = 0
|
|
return "a"
|
|
raise Exc
|
|
|
|
return BogonIter()
|
|
|
|
def __getitem__(self, key):
|
|
return key
|
|
|
|
self.assertRaises(Exc, d.update, FailingUserDict())
|
|
|
|
class FailingUserDict:
|
|
def keys(self):
|
|
class BogonIter:
|
|
def __init__(self) -> None:
|
|
self.i = ord("a")
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.i <= ord("z"):
|
|
rtn = chr(self.i)
|
|
self.i += 1
|
|
return rtn
|
|
raise StopIteration
|
|
|
|
return BogonIter()
|
|
|
|
def __getitem__(self, key):
|
|
raise Exc
|
|
|
|
self.assertRaises(Exc, d.update, FailingUserDict())
|
|
|
|
d = self._empty_mapping()
|
|
|
|
class badseq:
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
raise Exc
|
|
|
|
self.assertRaises(Exc, d.update, badseq())
|
|
|
|
self.assertRaises(ValueError, d.update, [(1, 2, 3)])
|
|
|
|
# no test_fromkeys or test_copy as both os.environ and selves don't support it
|
|
|
|
def test_get(self):
|
|
d = self._empty_mapping()
|
|
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
|
|
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
|
|
d = self.reference
|
|
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
|
|
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
|
|
self.assertEqual(
|
|
d.get(next(iter(self.inmapping.keys()))),
|
|
next(iter(self.inmapping.values())),
|
|
)
|
|
self.assertEqual(
|
|
d.get(next(iter(self.inmapping.keys())), 3),
|
|
next(iter(self.inmapping.values())),
|
|
)
|
|
self.assertRaises(TypeError, d.get)
|
|
self.assertRaises(TypeError, d.get, None, None, None)
|
|
|
|
def test_setdefault(self):
|
|
d = self._empty_mapping()
|
|
self.assertRaises(TypeError, d.setdefault)
|
|
|
|
def test_popitem(self):
|
|
d = self._empty_mapping()
|
|
self.assertRaises(KeyError, d.popitem)
|
|
self.assertRaises(TypeError, d.popitem, 42)
|
|
|
|
def test_pop(self):
|
|
d = self._empty_mapping()
|
|
k, v = next(iter(self.inmapping.items()))
|
|
d[k] = v
|
|
self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))
|
|
|
|
self.assertEqual(d.pop(k), v)
|
|
self.assertEqual(len(d), 0)
|
|
|
|
self.assertRaises(KeyError, d.pop, k)
|
|
|
|
|
|
# Adapted from cpython/Lib/test/mapping_tests.py
class WeakKeyDictionaryScriptObjectTestCase(TestCase):
    """Mapping-protocol tests for WeakIdKeyDictionary(ref_type=_WeakHashRef).

    Keys are TorchScript custom-class objects (_TorchScriptTesting._Foo),
    which require the torchbind test library to be loadable.
    """

    def _reference(self):
        """Build fresh reference data keyed by _Foo objects and return a copy.

        The originals stay referenced from ``self.__ref`` so the keys remain
        alive for the test's lifetime.
        """
        self.__ref = {
            torch.classes._TorchScriptTesting._Foo(1, 2): 1,
            torch.classes._TorchScriptTesting._Foo(2, 3): 2,
            torch.classes._TorchScriptTesting._Foo(3, 4): 3,
        }
        return self.__ref.copy()

    def _empty_mapping(self):
        """Return an empty mapping object"""
        return WeakIdKeyDictionary(ref_type=_WeakHashRef)

    def _full_mapping(self, data):
        """Return a mapping object with the value contained in data
        dictionary"""
        x = self._empty_mapping()
        for key, value in data.items():
            x[key] = value
        return x

    def setUp(self):
        """Skip when the torchbind library is unavailable, then defer to the base."""
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # __init__ records why loading the torchbind library was skipped; in
        # that case the fixtures (self.reference etc.) were never created.
        skip_reason = getattr(self, "_torchbind_skip_reason", None)
        if skip_reason is not None:
            raise unittest.SkipTest(skip_reason)
        super().setUp()

    def __init__(self, *args, **kw):
        # NOTE(review): this calls unittest.TestCase.__init__ directly, which
        # skips any __init__ defined on the imported TestCase base; kept as-is
        # to preserve existing behavior.
        unittest.TestCase.__init__(self, *args, **kw)
        self._torchbind_skip_reason = None
        try:
            load_torchbind_test_lib()
        except unittest.SkipTest as e:
            # Defer the skip to setUp() so the test is reported as skipped
            # instead of failing on the fixtures that are not built below.
            self._torchbind_skip_reason = str(e) or "torchbind test library unavailable"
            return  # Skip in setup

        self.reference = self._reference().copy()

        # A (key, value) pair not in the mapping
        key, value = self.reference.popitem()
        self.other = {key: value}

        # A (key, value) pair in the mapping
        key, value = self.reference.popitem()
        self.inmapping = {key: value}
        self.reference[key] = value

    def test_read(self):
        # Test for read only operations on mapping
        p = self._empty_mapping()
        p1 = dict(p)  # workaround for singleton objects
        d = self._full_mapping(self.reference)
        if d is p:
            p = p1
        # Indexing
        for key, value in self.reference.items():
            self.assertEqual(d[key], value)
        knownkey = next(iter(self.other.keys()))
        self.assertRaises(KeyError, lambda: d[knownkey])
        # len
        self.assertEqual(len(p), 0)
        self.assertEqual(len(d), len(self.reference))
        # __contains__
        for k in self.reference:
            self.assertIn(k, d)
        for k in self.other:
            self.assertNotIn(k, d)
        # cmp
        self.assertTrue(
            p == p
        )  # NB: don't use assertEqual, that doesn't actually use ==
        self.assertTrue(d == d)
        self.assertTrue(p != d)
        self.assertTrue(d != p)
        # bool
        if p:
            self.fail("Empty mapping must compare to False")
        if not d:
            self.fail("Full mapping must compare to True")

        # keys(), items(), iterkeys() ...
        def check_iterandlist(it, lst, ref):
            # `it` must be a real iterator yielding the same elements as both
            # the materialized list and the reference view.
            self.assertTrue(hasattr(it, "__next__"))
            self.assertTrue(hasattr(it, "__iter__"))
            x = list(it)
            self.assertTrue(set(x) == set(lst) == set(ref))

        check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
        check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
        # get
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.get(key, knownvalue), value)
        self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
        self.assertNotIn(knownkey, d)

    def test_write(self):
        # Test for write operations on mapping
        p = self._empty_mapping()
        # Indexing
        for key, value in self.reference.items():
            p[key] = value
            self.assertEqual(p[key], value)
        for key in self.reference.keys():
            del p[key]
            self.assertRaises(KeyError, lambda: p[key])
        p = self._empty_mapping()
        # update
        p.update(self.reference)
        self.assertEqual(dict(p), self.reference)
        items = list(p.items())
        p = self._empty_mapping()
        p.update(items)
        self.assertEqual(dict(p), self.reference)
        d = self._full_mapping(self.reference)
        # setdefault
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.setdefault(key, knownvalue), value)
        self.assertEqual(d[key], value)
        self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
        self.assertEqual(d[knownkey], knownvalue)
        # pop
        self.assertEqual(d.pop(knownkey), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertRaises(KeyError, d.pop, knownkey)
        default = 909
        d[knownkey] = knownvalue
        self.assertEqual(d.pop(knownkey, default), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertEqual(d.pop(knownkey, default), default)
        # popitem
        key, value = d.popitem()
        self.assertNotIn(key, d)
        self.assertEqual(value, self.reference[key])
        p = self._empty_mapping()
        self.assertRaises(KeyError, p.popitem)

    def test_constructor(self):
        self.assertEqual(self._empty_mapping(), self._empty_mapping())

    def test_bool(self):
        self.assertTrue(not self._empty_mapping())
        self.assertTrue(self.reference)
        self.assertTrue(bool(self._empty_mapping()) is False)
        self.assertTrue(bool(self.reference) is True)

    def test_keys(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        # NOTE(review): as in the CPython original, the remaining checks run
        # against the plain reference dict, not the mapping under test.
        d = self.reference
        self.assertIn(next(iter(self.inmapping.keys())), d.keys())
        self.assertNotIn(next(iter(self.other.keys())), d.keys())
        self.assertRaises(TypeError, d.keys, None)

    def test_values(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.values()), [])

        self.assertRaises(TypeError, d.values, None)

    def test_items(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.items()), [])

        self.assertRaises(TypeError, d.items, None)

    def test_len(self):
        d = self._empty_mapping()
        self.assertEqual(len(d), 0)

    def test_getitem(self):
        d = self.reference
        self.assertEqual(
            d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values()))
        )

        self.assertRaises(TypeError, d.__getitem__)

    def test_update(self):
        # mapping argument
        d = self._empty_mapping()
        d.update(self.other)
        self.assertEqual(list(d.items()), list(self.other.items()))

        # No argument
        d = self._empty_mapping()
        d.update()
        self.assertEqual(d, self._empty_mapping())

        # item sequence
        d = self._empty_mapping()
        d.update(self.other.items())
        self.assertEqual(list(d.items()), list(self.other.items()))

        # Iterator: pass a genuine one-shot iterator so this case differs from
        # the item-sequence case above.
        d = self._empty_mapping()
        d.update(iter(self.other.items()))
        self.assertEqual(list(d.items()), list(self.other.items()))

        # FIXME: Doesn't work with UserDict
        # self.assertRaises((TypeError, AttributeError), d.update, None)
        self.assertRaises((TypeError, AttributeError), d.update, 42)

        outerself = self

        class SimpleUserDict:
            # Minimal mapping exposing only keys() and __getitem__.
            def __init__(self) -> None:
                self.d = outerself.reference

            def keys(self):
                return self.d.keys()

            def __getitem__(self, i):
                return self.d[i]

        d.clear()
        d.update(SimpleUserDict())
        # Compare on id() of the keys rather than on the key objects themselves.
        i1 = sorted((id(k), v) for k, v in d.items())
        i2 = sorted((id(k), v) for k, v in self.reference.items())
        self.assertEqual(i1, i2)

        class Exc(Exception):
            pass

        d = self._empty_mapping()

        class FailingUserDict:
            def keys(self):
                raise Exc

        self.assertRaises(Exc, d.update, FailingUserDict())

        d.clear()

        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = 1

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i:
                            self.i = 0
                            return "a"
                        raise Exc

                return BogonIter()

            def __getitem__(self, key):
                return key

        self.assertRaises(Exc, d.update, FailingUserDict())

        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = ord("a")

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i <= ord("z"):
                            rtn = chr(self.i)
                            self.i += 1
                            return rtn
                        raise StopIteration

                return BogonIter()

            def __getitem__(self, key):
                raise Exc

        self.assertRaises(Exc, d.update, FailingUserDict())

        d = self._empty_mapping()

        class badseq:
            def __iter__(self):
                return self

            def __next__(self):
                raise Exc

        self.assertRaises(Exc, d.update, badseq())

        self.assertRaises(ValueError, d.update, [(1, 2, 3)])

    # no test_fromkeys or test_copy as both os.environ and selves don't support it

    def test_get(self):
        d = self._empty_mapping()
        self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
        self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
        d = self.reference
        self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
        self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
        self.assertEqual(
            d.get(next(iter(self.inmapping.keys()))),
            next(iter(self.inmapping.values())),
        )
        self.assertEqual(
            d.get(next(iter(self.inmapping.keys())), 3),
            next(iter(self.inmapping.values())),
        )
        self.assertRaises(TypeError, d.get)
        self.assertRaises(TypeError, d.get, None, None, None)

    def test_setdefault(self):
        d = self._empty_mapping()
        self.assertRaises(TypeError, d.setdefault)

    def test_popitem(self):
        d = self._empty_mapping()
        self.assertRaises(KeyError, d.popitem)
        self.assertRaises(TypeError, d.popitem, 42)

    def test_pop(self):
        d = self._empty_mapping()
        k, v = next(iter(self.inmapping.items()))
        d[k] = v
        self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))

        self.assertEqual(d.pop(k), v)
        self.assertEqual(len(d), 0)

        self.assertRaises(KeyError, d.pop, k)


if __name__ == "__main__":
    # Script entry point: hand control to the shared PyTorch test runner
    # imported from torch.testing._internal.common_utils.
    run_tests()