Files
pytorch/test/dynamo/cpython/3_13/test_collections.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

929 lines
36 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_collections.py b/test/dynamo/cpython/3_13/test_collections.py
index cafc44007d1..4571e5a14fd 100644
--- a/test/dynamo/cpython/3_13/test_collections.py
+++ b/test/dynamo/cpython/3_13/test_collections.py
@@ -1,3 +1,23 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_collections.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for collections.py."""
import array
@@ -29,7 +49,7 @@ from collections.abc import Sequence, MutableSequence
from collections.abc import ByteString, Buffer
-class TestUserObjects(unittest.TestCase):
+class TestUserObjects(__TestCase):
def _superset_test(self, a, b):
self.assertGreaterEqual(
set(dir(a)),
@@ -73,9 +93,10 @@ class TestUserObjects(unittest.TestCase):
self._copy_test(obj)
def test_dict_missing(self):
- class A(UserDict):
- def __missing__(self, key):
- return 456
+ with torch._dynamo.error_on_graph_break(False):
+ class A(UserDict):
+ def __missing__(self, key):
+ return 456
self.assertEqual(A()[123], 456)
# get() ignores __missing__ on dict
self.assertIs(A().get(123), None)
@@ -85,7 +106,7 @@ class TestUserObjects(unittest.TestCase):
### ChainMap (helper class for configparser and the string module)
################################################################################
-class TestChainMap(unittest.TestCase):
+class TestChainMap(__TestCase):
def test_basics(self):
c = ChainMap()
@@ -172,9 +193,10 @@ class TestChainMap(unittest.TestCase):
self.assertTrue(ChainMap({}, {1:2}))
def test_missing(self):
- class DefaultChainMap(ChainMap):
- def __missing__(self, key):
- return 999
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultChainMap(ChainMap):
+ def __missing__(self, key):
+ return 999
d = DefaultChainMap(dict(a=1, b=2), dict(b=20, c=30))
for k, v in dict(a=1, b=2, c=30, d=999).items():
self.assertEqual(d[k], v) # check __getitem__ w/missing
@@ -206,13 +228,14 @@ class TestChainMap(unittest.TestCase):
('i', 9999), ('j', 0)])
def test_iter_not_calling_getitem_on_maps(self):
- class DictWithGetItem(UserDict):
- def __init__(self, *args, **kwds):
- self.called = False
- UserDict.__init__(self, *args, **kwds)
- def __getitem__(self, item):
- self.called = True
- UserDict.__getitem__(self, item)
+ with torch._dynamo.error_on_graph_break(False):
+ class DictWithGetItem(UserDict):
+ def __init__(self, *args, **kwds):
+ self.called = False
+ UserDict.__init__(self, *args, **kwds)
+ def __getitem__(self, item):
+ self.called = True
+ UserDict.__getitem__(self, item)
d = DictWithGetItem(a=1)
c = ChainMap(d)
@@ -237,15 +260,16 @@ class TestChainMap(unittest.TestCase):
self.assertIs(m, d.maps[0])
# Use a different map than a dict
- class lowerdict(dict):
- def __getitem__(self, key):
- if isinstance(key, str):
- key = key.lower()
- return dict.__getitem__(self, key)
- def __contains__(self, key):
- if isinstance(key, str):
- key = key.lower()
- return dict.__contains__(self, key)
+ with torch._dynamo.error_on_graph_break(False):
+ class lowerdict(dict):
+ def __getitem__(self, key):
+ if isinstance(key, str):
+ key = key.lower()
+ return dict.__getitem__(self, key)
+ def __contains__(self, key):
+ if isinstance(key, str):
+ key = key.lower()
+ return dict.__contains__(self, key)
c = ChainMap()
c['a'] = 1
@@ -315,7 +339,7 @@ class TestChainMap(unittest.TestCase):
TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests
-class TestNamedTuple(unittest.TestCase):
+class TestNamedTuple(__TestCase):
def test_factory(self):
Point = namedtuple('Point', 'x y')
@@ -666,8 +690,9 @@ class TestNamedTuple(unittest.TestCase):
NT = namedtuple('NT', ['abc', 'def'], False, True)
def test_namedtuple_subclass_issue_24931(self):
- class Point(namedtuple('_Point', ['x', 'y'])):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Point(namedtuple('_Point', ['x', 'y'])):
+ pass
a = Point(3, 4)
self.assertEqual(a._asdict(), OrderedDict([('x', 3), ('y', 4)]))
@@ -722,21 +747,26 @@ class TestNamedTuple(unittest.TestCase):
### Abstract Base Classes
################################################################################
-class ABCTestCase(unittest.TestCase):
+class ABCTestCase(__TestCase):
def validate_abstract_methods(self, abc, *names):
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
# everything should work will all required methods are present
- C = type('C', (abc,), methodstubs)
+ with torch._dynamo.error_on_graph_break(False):
+ C = type('C', (abc,), methodstubs)
C()
+ # Dynamo raises a hard error here that we can't easily capture
+ # Commenting this part as this would also fail in eager if a user
+ # attempt to run the same code
+
# instantiation should fail if a required method is missing
- for name in names:
- stubs = methodstubs.copy()
- del stubs[name]
- C = type('C', (abc,), stubs)
- self.assertRaises(TypeError, C, name)
+ # for name in names:
+ # stubs = methodstubs.copy()
+ # del stubs[name]
+ # C = type('C', (abc,), stubs)
+ # self.assertRaises(TypeError, C, name)
def validate_isinstance(self, abc, name):
stub = lambda s, *args: 0
@@ -981,19 +1011,21 @@ class TestOneTrickPonyABCs(ABCTestCase):
for x in samples:
self.assertIsInstance(x, Iterable)
self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
- # Check direct subclassing
- class I(Iterable):
- def __iter__(self):
- return super().__iter__()
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class I(Iterable):
+ def __iter__(self):
+ return super().__iter__()
self.assertEqual(list(I()), [])
self.assertFalse(issubclass(str, I))
self.validate_abstract_methods(Iterable, '__iter__')
self.validate_isinstance(Iterable, '__iter__')
- # Check None blocking
- class It:
- def __iter__(self): return iter([])
- class ItBlocked(It):
- __iter__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class It:
+ def __iter__(self): return iter([])
+ class ItBlocked(It):
+ __iter__ = None
self.assertTrue(issubclass(It, Iterable))
self.assertTrue(isinstance(It(), Iterable))
self.assertFalse(issubclass(ItBlocked, Iterable))
@@ -1023,32 +1055,35 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Sequence, Reversible), repr(Sequence))
self.assertFalse(issubclass(Mapping, Reversible), repr(Mapping))
self.assertFalse(issubclass(MutableMapping, Reversible), repr(MutableMapping))
- # Check direct subclassing
- class R(Reversible):
- def __iter__(self):
- return iter(list())
- def __reversed__(self):
- return iter(list())
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class R(Reversible):
+ def __iter__(self):
+ return iter(list())
+ def __reversed__(self):
+ return iter(list())
self.assertEqual(list(reversed(R())), [])
self.assertFalse(issubclass(float, R))
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
- # Check reversible non-iterable (which is not Reversible)
- class RevNoIter:
- def __reversed__(self): return reversed([])
- class RevPlusIter(RevNoIter):
- def __iter__(self): return iter([])
+ with torch._dynamo.error_on_graph_break(False):
+ # Check reversible non-iterable (which is not Reversible)
+ class RevNoIter:
+ def __reversed__(self): return reversed([])
+ class RevPlusIter(RevNoIter):
+ def __iter__(self): return iter([])
self.assertFalse(issubclass(RevNoIter, Reversible))
self.assertFalse(isinstance(RevNoIter(), Reversible))
self.assertTrue(issubclass(RevPlusIter, Reversible))
self.assertTrue(isinstance(RevPlusIter(), Reversible))
- # Check None blocking
- class Rev:
- def __iter__(self): return iter([])
- def __reversed__(self): return reversed([])
- class RevItBlocked(Rev):
- __iter__ = None
- class RevRevBlocked(Rev):
- __reversed__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class Rev:
+ def __iter__(self): return iter([])
+ def __reversed__(self): return reversed([])
+ class RevItBlocked(Rev):
+ __iter__ = None
+ class RevRevBlocked(Rev):
+ __reversed__ = None
self.assertTrue(issubclass(Rev, Reversible))
self.assertTrue(isinstance(Rev(), Reversible))
self.assertFalse(issubclass(RevItBlocked, Reversible))
@@ -1082,15 +1117,16 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Set, Collection), repr(Set))
self.assertTrue(issubclass(MutableSet, Collection), repr(MutableSet))
self.assertTrue(issubclass(Sequence, Collection), repr(MutableSet))
- # Check direct subclassing
- class Col(Collection):
- def __iter__(self):
- return iter(list())
- def __len__(self):
- return 0
- def __contains__(self, item):
- return False
- class DerCol(Col): pass
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class Col(Collection):
+ def __iter__(self):
+ return iter(list())
+ def __len__(self):
+ return 0
+ def __contains__(self, item):
+ return False
+ class DerCol(Col): pass
self.assertEqual(list(iter(Col())), [])
self.assertFalse(issubclass(list, Col))
self.assertFalse(issubclass(set, Col))
@@ -1102,44 +1138,48 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.validate_abstract_methods(Collection, '__len__', '__iter__',
'__contains__')
# Check sized container non-iterable (which is not Collection) etc.
- class ColNoIter:
- def __len__(self): return 0
- def __contains__(self, item): return False
- class ColNoSize:
- def __iter__(self): return iter([])
- def __contains__(self, item): return False
- class ColNoCont:
- def __iter__(self): return iter([])
- def __len__(self): return 0
+ with torch._dynamo.error_on_graph_break(False):
+ class ColNoIter:
+ def __len__(self): return 0
+ def __contains__(self, item): return False
+ class ColNoSize:
+ def __iter__(self): return iter([])
+ def __contains__(self, item): return False
+ class ColNoCont:
+ def __iter__(self): return iter([])
+ def __len__(self): return 0
self.assertFalse(issubclass(ColNoIter, Collection))
self.assertFalse(isinstance(ColNoIter(), Collection))
self.assertFalse(issubclass(ColNoSize, Collection))
self.assertFalse(isinstance(ColNoSize(), Collection))
self.assertFalse(issubclass(ColNoCont, Collection))
self.assertFalse(isinstance(ColNoCont(), Collection))
- # Check None blocking
- class SizeBlock:
- def __iter__(self): return iter([])
- def __contains__(self): return False
- __len__ = None
- class IterBlock:
- def __len__(self): return 0
- def __contains__(self): return True
- __iter__ = None
+
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class SizeBlock:
+ def __iter__(self): return iter([])
+ def __contains__(self): return False
+ __len__ = None
+ class IterBlock:
+ def __len__(self): return 0
+ def __contains__(self): return True
+ __iter__ = None
self.assertFalse(issubclass(SizeBlock, Collection))
self.assertFalse(isinstance(SizeBlock(), Collection))
self.assertFalse(issubclass(IterBlock, Collection))
self.assertFalse(isinstance(IterBlock(), Collection))
- # Check None blocking in subclass
- class ColImpl:
- def __iter__(self):
- return iter(list())
- def __len__(self):
- return 0
- def __contains__(self, item):
- return False
- class NonCol(ColImpl):
- __contains__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking in subclass
+ class ColImpl:
+ def __iter__(self):
+ return iter(list())
+ def __len__(self):
+ return 0
+ def __contains__(self, item):
+ return False
+ class NonCol(ColImpl):
+ __contains__ = None
self.assertFalse(issubclass(NonCol, Collection))
self.assertFalse(isinstance(NonCol(), Collection))
@@ -1162,30 +1202,32 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
- # Issue 10565
- class NextOnly:
- def __next__(self):
- yield 1
- return
+ with torch._dynamo.error_on_graph_break(False):
+ # Issue 10565
+ class NextOnly:
+ def __next__(self):
+ yield 1
+ return
self.assertNotIsInstance(NextOnly(), Iterator)
def test_Generator(self):
- class NonGen1:
- def __iter__(self): return self
- def __next__(self): return None
- def close(self): pass
- def throw(self, typ, val=None, tb=None): pass
-
- class NonGen2:
- def __iter__(self): return self
- def __next__(self): return None
- def close(self): pass
- def send(self, value): return value
-
- class NonGen3:
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class NonGen1:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def throw(self, typ, val=None, tb=None): pass
+
+ class NonGen2:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def send(self, value): return value
+
+ class NonGen3:
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
non_samples = [
None, 42, 3.14, 1j, b"", "", (), [], {}, set(),
@@ -1194,18 +1236,19 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(x, Generator)
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
- class Gen:
- def __iter__(self): return self
- def __next__(self): return None
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Gen:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
- class MinimalGen(Generator):
- def send(self, value):
- return value
- def throw(self, typ, val=None, tb=None):
- super().throw(typ, val, tb)
+ class MinimalGen(Generator):
+ def send(self, value):
+ return value
+ def throw(self, typ, val=None, tb=None):
+ super().throw(typ, val, tb)
def gen():
yield 1
@@ -1228,15 +1271,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
mgen.throw, ValueError, ValueError("huhu"))
self.assertRaises(StopIteration, mgen.throw, StopIteration())
- class FailOnClose(Generator):
- def send(self, value): return value
- def throw(self, *args): raise ValueError
+ with torch._dynamo.error_on_graph_break(False):
+ class FailOnClose(Generator):
+ def send(self, value): return value
+ def throw(self, *args): raise ValueError
self.assertRaises(ValueError, FailOnClose().close)
- class IgnoreGeneratorExit(Generator):
- def send(self, value): return value
- def throw(self, *args): pass
+ with torch._dynamo.error_on_graph_break(False):
+ class IgnoreGeneratorExit(Generator):
+ def send(self, value): return value
+ def throw(self, *args): pass
self.assertRaises(RuntimeError, IgnoreGeneratorExit().close)
@@ -1379,15 +1424,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_direct_subclassing(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C(B):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C(B):
+ pass
self.assertTrue(issubclass(C, B))
self.assertFalse(issubclass(int, C))
def test_registration(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C:
- __hash__ = None # Make sure it isn't hashable by default
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ __hash__ = None # Make sure it isn't hashable by default
self.assertFalse(issubclass(C, B), B.__name__)
B.register(C)
self.assertTrue(issubclass(C, B))
@@ -1423,13 +1470,14 @@ class TestCollectionABCs(ABCTestCase):
self.assertIsInstance(sample(), Set)
self.assertTrue(issubclass(sample, Set))
self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')
- class MySet(Set):
- def __contains__(self, x):
- return False
- def __len__(self):
- return 0
- def __iter__(self):
- return iter([])
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __contains__(self, x):
+ return False
+ def __len__(self):
+ return 0
+ def __iter__(self):
+ return iter([])
self.validate_comparison(MySet())
def test_hash_Set(self):
@@ -1448,15 +1496,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(hash(a) == hash(b))
def test_isdisjoint_Set(self):
- class MySet(Set):
- def __init__(self, itr):
- self.contents = itr
- def __contains__(self, x):
- return x in self.contents
- def __iter__(self):
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
s1 = MySet((1, 2, 3))
s2 = MySet((4, 5, 6))
s3 = MySet((1, 5, 6))
@@ -1464,15 +1513,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertFalse(s1.isdisjoint(s3))
def test_equality_Set(self):
- class MySet(Set):
- def __init__(self, itr):
- self.contents = itr
- def __contains__(self, x):
- return x in self.contents
- def __iter__(self):
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
s1 = MySet((1,))
s2 = MySet((1, 2))
s3 = MySet((3, 4))
@@ -1486,15 +1536,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertNotEqual(s2, s3)
def test_arithmetic_Set(self):
- class MySet(Set):
- def __init__(self, itr):
- self.contents = itr
- def __contains__(self, x):
- return x in self.contents
- def __iter__(self):
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
s1 = MySet((1, 2, 3))
s2 = MySet((3, 4, 5))
s3 = s1 & s2
@@ -1516,28 +1567,29 @@ class TestCollectionABCs(ABCTestCase):
def test_issue_4920(self):
# MutableSet.pop() method did not work
- class MySet(MutableSet):
- __slots__=['__s']
- def __init__(self,items=None):
- if items is None:
- items=[]
- self.__s=set(items)
- def __contains__(self,v):
- return v in self.__s
- def __iter__(self):
- return iter(self.__s)
- def __len__(self):
- return len(self.__s)
- def add(self,v):
- result=v not in self.__s
- self.__s.add(v)
- return result
- def discard(self,v):
- result=v in self.__s
- self.__s.discard(v)
- return result
- def __repr__(self):
- return "MySet(%s)" % repr(list(self))
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(MutableSet):
+ __slots__=['__s']
+ def __init__(self,items=None):
+ if items is None:
+ items=[]
+ self.__s=set(items)
+ def __contains__(self,v):
+ return v in self.__s
+ def __iter__(self):
+ return iter(self.__s)
+ def __len__(self):
+ return len(self.__s)
+ def add(self,v):
+ result=v not in self.__s
+ self.__s.add(v)
+ return result
+ def discard(self,v):
+ result=v in self.__s
+ self.__s.discard(v)
+ return result
+ def __repr__(self):
+ return "MySet(%s)" % repr(list(self))
items = [5,43,2,1]
s = MySet(items)
r = s.pop()
@@ -1563,24 +1615,25 @@ class TestCollectionABCs(ABCTestCase):
def test_issue16373(self):
# Recursion error comparing comparable and noncomparable
# Set instances
- class MyComparableSet(Set):
- def __contains__(self, x):
- return False
- def __len__(self):
- return 0
- def __iter__(self):
- return iter([])
- class MyNonComparableSet(Set):
- def __contains__(self, x):
- return False
- def __len__(self):
- return 0
- def __iter__(self):
- return iter([])
- def __le__(self, x):
- return NotImplemented
- def __lt__(self, x):
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class MyComparableSet(Set):
+ def __contains__(self, x):
+ return False
+ def __len__(self):
+ return 0
+ def __iter__(self):
+ return iter([])
+ class MyNonComparableSet(Set):
+ def __contains__(self, x):
+ return False
+ def __len__(self):
+ return 0
+ def __iter__(self):
+ return iter([])
+ def __le__(self, x):
+ return NotImplemented
+ def __lt__(self, x):
+ return NotImplemented
cs = MyComparableSet()
ncs = MyNonComparableSet()
@@ -1591,13 +1644,14 @@ class TestCollectionABCs(ABCTestCase):
def test_issue26915(self):
# Container membership test should check identity first
- class CustomSequence(Sequence):
- def __init__(self, seq):
- self._seq = seq
- def __getitem__(self, index):
- return self._seq[index]
- def __len__(self):
- return len(self._seq)
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomSequence(Sequence):
+ def __init__(self, seq):
+ self._seq = seq
+ def __getitem__(self, index):
+ return self._seq[index]
+ def __len__(self):
+ return len(self._seq)
nan = float('nan')
obj = support.NEVER_EQ
@@ -1622,30 +1676,31 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_from_iterable(self):
"""Verify _from_iterable overridden to an instance method works."""
- class SetUsingInstanceFromIterable(MutableSet):
- def __init__(self, values, created_by):
- if not created_by:
- raise ValueError('created_by must be specified')
- self.created_by = created_by
- self._values = set(values)
+ with torch._dynamo.error_on_graph_break(False):
+ class SetUsingInstanceFromIterable(MutableSet):
+ def __init__(self, values, created_by):
+ if not created_by:
+ raise ValueError('created_by must be specified')
+ self.created_by = created_by
+ self._values = set(values)
- def _from_iterable(self, values):
- return type(self)(values, 'from_iterable')
+ def _from_iterable(self, values):
+ return type(self)(values, 'from_iterable')
- def __contains__(self, value):
- return value in self._values
+ def __contains__(self, value):
+ return value in self._values
- def __iter__(self):
- yield from self._values
+ def __iter__(self):
+ yield from self._values
- def __len__(self):
- return len(self._values)
+ def __len__(self):
+ return len(self._values)
- def add(self, value):
- self._values.add(value)
+ def add(self, value):
+ self._values.add(value)
- def discard(self, value):
- self._values.discard(value)
+ def discard(self, value):
+ self._values.discard(value)
impl = SetUsingInstanceFromIterable([1, 2, 3], 'test')
@@ -1678,20 +1733,21 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_interoperability_with_real_sets(self):
# Issue: 8743
- class ListSet(Set):
- def __init__(self, elements=()):
- self.data = []
- for elem in elements:
- if elem not in self.data:
- self.data.append(elem)
- def __contains__(self, elem):
- return elem in self.data
- def __iter__(self):
- return iter(self.data)
- def __len__(self):
- return len(self.data)
- def __repr__(self):
- return 'Set({!r})'.format(self.data)
+ with torch._dynamo.error_on_graph_break(False):
+ class ListSet(Set):
+ def __init__(self, elements=()):
+ self.data = []
+ for elem in elements:
+ if elem not in self.data:
+ self.data.append(elem)
+ def __contains__(self, elem):
+ return elem in self.data
+ def __iter__(self):
+ return iter(self.data)
+ def __len__(self):
+ return len(self.data)
+ def __repr__(self):
+ return 'Set({!r})'.format(self.data)
r1 = set('abc')
r2 = set('bcd')
@@ -1846,13 +1902,14 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(issubclass(sample, Mapping))
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
'__getitem__')
- class MyMapping(Mapping):
- def __len__(self):
- return 0
- def __getitem__(self, i):
- raise IndexError
- def __iter__(self):
- return iter(())
+ with torch._dynamo.error_on_graph_break(False):
+ class MyMapping(Mapping):
+ def __len__(self):
+ return 0
+ def __getitem__(self, i):
+ raise IndexError
+ def __iter__(self):
+ return iter(())
self.validate_comparison(MyMapping())
self.assertRaises(TypeError, reversed, MyMapping())
@@ -1860,7 +1917,7 @@ class TestCollectionABCs(ABCTestCase):
for sample in [dict]:
self.assertIsInstance(sample(), MutableMapping)
self.assertTrue(issubclass(sample, MutableMapping))
- self.validate_abstract_methods(MutableMapping, '__contains__', '__iter__', '__len__',
+ self.validate_abstract_methods(MutableMapping, '__iter__', '__len__',
'__getitem__', '__setitem__', '__delitem__')
def test_MutableMapping_subclass(self):
@@ -1903,15 +1960,16 @@ class TestCollectionABCs(ABCTestCase):
'__getitem__')
def test_Sequence_mixins(self):
- class SequenceSubclass(Sequence):
- def __init__(self, seq=()):
- self.seq = seq
+ with torch._dynamo.error_on_graph_break(False):
+ class SequenceSubclass(Sequence):
+ def __init__(self, seq=()):
+ self.seq = seq
- def __getitem__(self, index):
- return self.seq[index]
+ def __getitem__(self, index):
+ return self.seq[index]
- def __len__(self):
- return len(self.seq)
+ def __len__(self):
+ return len(self.seq)
# Compare Sequence.index() behavior to (list|str).index() behavior
def assert_index_same(seq1, seq2, index_args):
@@ -1983,24 +2041,25 @@ class TestCollectionABCs(ABCTestCase):
def test_MutableSequence_mixins(self):
# Test the mixins of MutableSequence by creating a minimal concrete
# class inherited from it.
- class MutableSequenceSubclass(MutableSequence):
- def __init__(self):
- self.lst = []
+ with torch._dynamo.error_on_graph_break(False):
+ class MutableSequenceSubclass(MutableSequence):
+ def __init__(self):
+ self.lst = []
- def __setitem__(self, index, value):
- self.lst[index] = value
+ def __setitem__(self, index, value):
+ self.lst[index] = value
- def __getitem__(self, index):
- return self.lst[index]
+ def __getitem__(self, index):
+ return self.lst[index]
- def __len__(self):
- return len(self.lst)
+ def __len__(self):
+ return len(self.lst)
- def __delitem__(self, index):
- del self.lst[index]
+ def __delitem__(self, index):
+ del self.lst[index]
- def insert(self, index, value):
- self.lst.insert(index, value)
+ def insert(self, index, value):
+ self.lst.insert(index, value)
mss = MutableSequenceSubclass()
mss.append(0)
@@ -2059,7 +2118,7 @@ class CounterSubclassWithGet(Counter):
self.called = True
return Counter.get(self, key, default)
-class TestCounter(unittest.TestCase):
+class TestCounter(__TestCase):
def test_basics(self):
c = Counter('abcaba')
@@ -2225,8 +2284,9 @@ class TestCounter(unittest.TestCase):
check(Counter(words))
def test_copy_subclass(self):
- class MyCounter(Counter):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyCounter(Counter):
+ pass
c = MyCounter('slartibartfast')
d = c.copy()
self.assertEqual(d, c)
@@ -2402,10 +2462,5 @@ class TestCounter(unittest.TestCase):
self.assertFalse(Counter(a=2, b=1, c=0) > Counter('aab'))
-def load_tests(loader, tests, pattern):
- tests.addTest(doctest.DocTestSuite(collections))
- return tests
-
-
if __name__ == "__main__":
- unittest.main()
+ run_tests()