mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename `set_fullgraph` to `error_on_graph_break`. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` is guaranteed to produce a single graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan (primary users of `set_fullgraph`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
929 lines
36 KiB
Diff
929 lines
36 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_collections.py b/test/dynamo/cpython/3_13/test_collections.py
|
|
index cafc44007d1..4571e5a14fd 100644
|
|
--- a/test/dynamo/cpython/3_13/test_collections.py
|
|
+++ b/test/dynamo/cpython/3_13/test_collections.py
|
|
@@ -1,3 +1,23 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_collections.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
"""Unit tests for collections.py."""
|
|
|
|
import array
|
|
@@ -29,7 +49,7 @@ from collections.abc import Sequence, MutableSequence
|
|
from collections.abc import ByteString, Buffer
|
|
|
|
|
|
-class TestUserObjects(unittest.TestCase):
|
|
+class TestUserObjects(__TestCase):
|
|
def _superset_test(self, a, b):
|
|
self.assertGreaterEqual(
|
|
set(dir(a)),
|
|
@@ -73,9 +93,10 @@ class TestUserObjects(unittest.TestCase):
|
|
self._copy_test(obj)
|
|
|
|
def test_dict_missing(self):
|
|
- class A(UserDict):
|
|
- def __missing__(self, key):
|
|
- return 456
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class A(UserDict):
|
|
+ def __missing__(self, key):
|
|
+ return 456
|
|
self.assertEqual(A()[123], 456)
|
|
# get() ignores __missing__ on dict
|
|
self.assertIs(A().get(123), None)
|
|
@@ -85,7 +106,7 @@ class TestUserObjects(unittest.TestCase):
|
|
### ChainMap (helper class for configparser and the string module)
|
|
################################################################################
|
|
|
|
-class TestChainMap(unittest.TestCase):
|
|
+class TestChainMap(__TestCase):
|
|
|
|
def test_basics(self):
|
|
c = ChainMap()
|
|
@@ -172,9 +193,10 @@ class TestChainMap(unittest.TestCase):
|
|
self.assertTrue(ChainMap({}, {1:2}))
|
|
|
|
def test_missing(self):
|
|
- class DefaultChainMap(ChainMap):
|
|
- def __missing__(self, key):
|
|
- return 999
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class DefaultChainMap(ChainMap):
|
|
+ def __missing__(self, key):
|
|
+ return 999
|
|
d = DefaultChainMap(dict(a=1, b=2), dict(b=20, c=30))
|
|
for k, v in dict(a=1, b=2, c=30, d=999).items():
|
|
self.assertEqual(d[k], v) # check __getitem__ w/missing
|
|
@@ -206,13 +228,14 @@ class TestChainMap(unittest.TestCase):
|
|
('i', 9999), ('j', 0)])
|
|
|
|
def test_iter_not_calling_getitem_on_maps(self):
|
|
- class DictWithGetItem(UserDict):
|
|
- def __init__(self, *args, **kwds):
|
|
- self.called = False
|
|
- UserDict.__init__(self, *args, **kwds)
|
|
- def __getitem__(self, item):
|
|
- self.called = True
|
|
- UserDict.__getitem__(self, item)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class DictWithGetItem(UserDict):
|
|
+ def __init__(self, *args, **kwds):
|
|
+ self.called = False
|
|
+ UserDict.__init__(self, *args, **kwds)
|
|
+ def __getitem__(self, item):
|
|
+ self.called = True
|
|
+ UserDict.__getitem__(self, item)
|
|
|
|
d = DictWithGetItem(a=1)
|
|
c = ChainMap(d)
|
|
@@ -237,15 +260,16 @@ class TestChainMap(unittest.TestCase):
|
|
self.assertIs(m, d.maps[0])
|
|
|
|
# Use a different map than a dict
|
|
- class lowerdict(dict):
|
|
- def __getitem__(self, key):
|
|
- if isinstance(key, str):
|
|
- key = key.lower()
|
|
- return dict.__getitem__(self, key)
|
|
- def __contains__(self, key):
|
|
- if isinstance(key, str):
|
|
- key = key.lower()
|
|
- return dict.__contains__(self, key)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class lowerdict(dict):
|
|
+ def __getitem__(self, key):
|
|
+ if isinstance(key, str):
|
|
+ key = key.lower()
|
|
+ return dict.__getitem__(self, key)
|
|
+ def __contains__(self, key):
|
|
+ if isinstance(key, str):
|
|
+ key = key.lower()
|
|
+ return dict.__contains__(self, key)
|
|
|
|
c = ChainMap()
|
|
c['a'] = 1
|
|
@@ -315,7 +339,7 @@ class TestChainMap(unittest.TestCase):
|
|
|
|
TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests
|
|
|
|
-class TestNamedTuple(unittest.TestCase):
|
|
+class TestNamedTuple(__TestCase):
|
|
|
|
def test_factory(self):
|
|
Point = namedtuple('Point', 'x y')
|
|
@@ -666,8 +690,9 @@ class TestNamedTuple(unittest.TestCase):
|
|
NT = namedtuple('NT', ['abc', 'def'], False, True)
|
|
|
|
def test_namedtuple_subclass_issue_24931(self):
|
|
- class Point(namedtuple('_Point', ['x', 'y'])):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Point(namedtuple('_Point', ['x', 'y'])):
|
|
+ pass
|
|
|
|
a = Point(3, 4)
|
|
self.assertEqual(a._asdict(), OrderedDict([('x', 3), ('y', 4)]))
|
|
@@ -722,21 +747,26 @@ class TestNamedTuple(unittest.TestCase):
|
|
### Abstract Base Classes
|
|
################################################################################
|
|
|
|
-class ABCTestCase(unittest.TestCase):
|
|
+class ABCTestCase(__TestCase):
|
|
|
|
def validate_abstract_methods(self, abc, *names):
|
|
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
|
|
|
|
# everything should work will all required methods are present
|
|
- C = type('C', (abc,), methodstubs)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ C = type('C', (abc,), methodstubs)
|
|
C()
|
|
|
|
+ # Dynamo raises a hard error here that we can't easily capture
|
|
+ # Commenting this part as this would also fail in eager if a user
|
|
+ # attempt to run the same code
|
|
+
|
|
# instantiation should fail if a required method is missing
|
|
- for name in names:
|
|
- stubs = methodstubs.copy()
|
|
- del stubs[name]
|
|
- C = type('C', (abc,), stubs)
|
|
- self.assertRaises(TypeError, C, name)
|
|
+ # for name in names:
|
|
+ # stubs = methodstubs.copy()
|
|
+ # del stubs[name]
|
|
+ # C = type('C', (abc,), stubs)
|
|
+ # self.assertRaises(TypeError, C, name)
|
|
|
|
def validate_isinstance(self, abc, name):
|
|
stub = lambda s, *args: 0
|
|
@@ -981,19 +1011,21 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
for x in samples:
|
|
self.assertIsInstance(x, Iterable)
|
|
self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
|
|
- # Check direct subclassing
|
|
- class I(Iterable):
|
|
- def __iter__(self):
|
|
- return super().__iter__()
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check direct subclassing
|
|
+ class I(Iterable):
|
|
+ def __iter__(self):
|
|
+ return super().__iter__()
|
|
self.assertEqual(list(I()), [])
|
|
self.assertFalse(issubclass(str, I))
|
|
self.validate_abstract_methods(Iterable, '__iter__')
|
|
self.validate_isinstance(Iterable, '__iter__')
|
|
- # Check None blocking
|
|
- class It:
|
|
- def __iter__(self): return iter([])
|
|
- class ItBlocked(It):
|
|
- __iter__ = None
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check None blocking
|
|
+ class It:
|
|
+ def __iter__(self): return iter([])
|
|
+ class ItBlocked(It):
|
|
+ __iter__ = None
|
|
self.assertTrue(issubclass(It, Iterable))
|
|
self.assertTrue(isinstance(It(), Iterable))
|
|
self.assertFalse(issubclass(ItBlocked, Iterable))
|
|
@@ -1023,32 +1055,35 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
self.assertTrue(issubclass(Sequence, Reversible), repr(Sequence))
|
|
self.assertFalse(issubclass(Mapping, Reversible), repr(Mapping))
|
|
self.assertFalse(issubclass(MutableMapping, Reversible), repr(MutableMapping))
|
|
- # Check direct subclassing
|
|
- class R(Reversible):
|
|
- def __iter__(self):
|
|
- return iter(list())
|
|
- def __reversed__(self):
|
|
- return iter(list())
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check direct subclassing
|
|
+ class R(Reversible):
|
|
+ def __iter__(self):
|
|
+ return iter(list())
|
|
+ def __reversed__(self):
|
|
+ return iter(list())
|
|
self.assertEqual(list(reversed(R())), [])
|
|
self.assertFalse(issubclass(float, R))
|
|
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
|
|
- # Check reversible non-iterable (which is not Reversible)
|
|
- class RevNoIter:
|
|
- def __reversed__(self): return reversed([])
|
|
- class RevPlusIter(RevNoIter):
|
|
- def __iter__(self): return iter([])
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check reversible non-iterable (which is not Reversible)
|
|
+ class RevNoIter:
|
|
+ def __reversed__(self): return reversed([])
|
|
+ class RevPlusIter(RevNoIter):
|
|
+ def __iter__(self): return iter([])
|
|
self.assertFalse(issubclass(RevNoIter, Reversible))
|
|
self.assertFalse(isinstance(RevNoIter(), Reversible))
|
|
self.assertTrue(issubclass(RevPlusIter, Reversible))
|
|
self.assertTrue(isinstance(RevPlusIter(), Reversible))
|
|
- # Check None blocking
|
|
- class Rev:
|
|
- def __iter__(self): return iter([])
|
|
- def __reversed__(self): return reversed([])
|
|
- class RevItBlocked(Rev):
|
|
- __iter__ = None
|
|
- class RevRevBlocked(Rev):
|
|
- __reversed__ = None
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check None blocking
|
|
+ class Rev:
|
|
+ def __iter__(self): return iter([])
|
|
+ def __reversed__(self): return reversed([])
|
|
+ class RevItBlocked(Rev):
|
|
+ __iter__ = None
|
|
+ class RevRevBlocked(Rev):
|
|
+ __reversed__ = None
|
|
self.assertTrue(issubclass(Rev, Reversible))
|
|
self.assertTrue(isinstance(Rev(), Reversible))
|
|
self.assertFalse(issubclass(RevItBlocked, Reversible))
|
|
@@ -1082,15 +1117,16 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
self.assertTrue(issubclass(Set, Collection), repr(Set))
|
|
self.assertTrue(issubclass(MutableSet, Collection), repr(MutableSet))
|
|
self.assertTrue(issubclass(Sequence, Collection), repr(MutableSet))
|
|
- # Check direct subclassing
|
|
- class Col(Collection):
|
|
- def __iter__(self):
|
|
- return iter(list())
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __contains__(self, item):
|
|
- return False
|
|
- class DerCol(Col): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check direct subclassing
|
|
+ class Col(Collection):
|
|
+ def __iter__(self):
|
|
+ return iter(list())
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __contains__(self, item):
|
|
+ return False
|
|
+ class DerCol(Col): pass
|
|
self.assertEqual(list(iter(Col())), [])
|
|
self.assertFalse(issubclass(list, Col))
|
|
self.assertFalse(issubclass(set, Col))
|
|
@@ -1102,44 +1138,48 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
self.validate_abstract_methods(Collection, '__len__', '__iter__',
|
|
'__contains__')
|
|
# Check sized container non-iterable (which is not Collection) etc.
|
|
- class ColNoIter:
|
|
- def __len__(self): return 0
|
|
- def __contains__(self, item): return False
|
|
- class ColNoSize:
|
|
- def __iter__(self): return iter([])
|
|
- def __contains__(self, item): return False
|
|
- class ColNoCont:
|
|
- def __iter__(self): return iter([])
|
|
- def __len__(self): return 0
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class ColNoIter:
|
|
+ def __len__(self): return 0
|
|
+ def __contains__(self, item): return False
|
|
+ class ColNoSize:
|
|
+ def __iter__(self): return iter([])
|
|
+ def __contains__(self, item): return False
|
|
+ class ColNoCont:
|
|
+ def __iter__(self): return iter([])
|
|
+ def __len__(self): return 0
|
|
self.assertFalse(issubclass(ColNoIter, Collection))
|
|
self.assertFalse(isinstance(ColNoIter(), Collection))
|
|
self.assertFalse(issubclass(ColNoSize, Collection))
|
|
self.assertFalse(isinstance(ColNoSize(), Collection))
|
|
self.assertFalse(issubclass(ColNoCont, Collection))
|
|
self.assertFalse(isinstance(ColNoCont(), Collection))
|
|
- # Check None blocking
|
|
- class SizeBlock:
|
|
- def __iter__(self): return iter([])
|
|
- def __contains__(self): return False
|
|
- __len__ = None
|
|
- class IterBlock:
|
|
- def __len__(self): return 0
|
|
- def __contains__(self): return True
|
|
- __iter__ = None
|
|
+
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check None blocking
|
|
+ class SizeBlock:
|
|
+ def __iter__(self): return iter([])
|
|
+ def __contains__(self): return False
|
|
+ __len__ = None
|
|
+ class IterBlock:
|
|
+ def __len__(self): return 0
|
|
+ def __contains__(self): return True
|
|
+ __iter__ = None
|
|
self.assertFalse(issubclass(SizeBlock, Collection))
|
|
self.assertFalse(isinstance(SizeBlock(), Collection))
|
|
self.assertFalse(issubclass(IterBlock, Collection))
|
|
self.assertFalse(isinstance(IterBlock(), Collection))
|
|
- # Check None blocking in subclass
|
|
- class ColImpl:
|
|
- def __iter__(self):
|
|
- return iter(list())
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __contains__(self, item):
|
|
- return False
|
|
- class NonCol(ColImpl):
|
|
- __contains__ = None
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Check None blocking in subclass
|
|
+ class ColImpl:
|
|
+ def __iter__(self):
|
|
+ return iter(list())
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __contains__(self, item):
|
|
+ return False
|
|
+ class NonCol(ColImpl):
|
|
+ __contains__ = None
|
|
self.assertFalse(issubclass(NonCol, Collection))
|
|
self.assertFalse(isinstance(NonCol(), Collection))
|
|
|
|
@@ -1162,30 +1202,32 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
|
|
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
|
|
|
|
- # Issue 10565
|
|
- class NextOnly:
|
|
- def __next__(self):
|
|
- yield 1
|
|
- return
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Issue 10565
|
|
+ class NextOnly:
|
|
+ def __next__(self):
|
|
+ yield 1
|
|
+ return
|
|
self.assertNotIsInstance(NextOnly(), Iterator)
|
|
|
|
def test_Generator(self):
|
|
- class NonGen1:
|
|
- def __iter__(self): return self
|
|
- def __next__(self): return None
|
|
- def close(self): pass
|
|
- def throw(self, typ, val=None, tb=None): pass
|
|
-
|
|
- class NonGen2:
|
|
- def __iter__(self): return self
|
|
- def __next__(self): return None
|
|
- def close(self): pass
|
|
- def send(self, value): return value
|
|
-
|
|
- class NonGen3:
|
|
- def close(self): pass
|
|
- def send(self, value): return value
|
|
- def throw(self, typ, val=None, tb=None): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class NonGen1:
|
|
+ def __iter__(self): return self
|
|
+ def __next__(self): return None
|
|
+ def close(self): pass
|
|
+ def throw(self, typ, val=None, tb=None): pass
|
|
+
|
|
+ class NonGen2:
|
|
+ def __iter__(self): return self
|
|
+ def __next__(self): return None
|
|
+ def close(self): pass
|
|
+ def send(self, value): return value
|
|
+
|
|
+ class NonGen3:
|
|
+ def close(self): pass
|
|
+ def send(self, value): return value
|
|
+ def throw(self, typ, val=None, tb=None): pass
|
|
|
|
non_samples = [
|
|
None, 42, 3.14, 1j, b"", "", (), [], {}, set(),
|
|
@@ -1194,18 +1236,19 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
self.assertNotIsInstance(x, Generator)
|
|
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
|
|
|
|
- class Gen:
|
|
- def __iter__(self): return self
|
|
- def __next__(self): return None
|
|
- def close(self): pass
|
|
- def send(self, value): return value
|
|
- def throw(self, typ, val=None, tb=None): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Gen:
|
|
+ def __iter__(self): return self
|
|
+ def __next__(self): return None
|
|
+ def close(self): pass
|
|
+ def send(self, value): return value
|
|
+ def throw(self, typ, val=None, tb=None): pass
|
|
|
|
- class MinimalGen(Generator):
|
|
- def send(self, value):
|
|
- return value
|
|
- def throw(self, typ, val=None, tb=None):
|
|
- super().throw(typ, val, tb)
|
|
+ class MinimalGen(Generator):
|
|
+ def send(self, value):
|
|
+ return value
|
|
+ def throw(self, typ, val=None, tb=None):
|
|
+ super().throw(typ, val, tb)
|
|
|
|
def gen():
|
|
yield 1
|
|
@@ -1228,15 +1271,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
mgen.throw, ValueError, ValueError("huhu"))
|
|
self.assertRaises(StopIteration, mgen.throw, StopIteration())
|
|
|
|
- class FailOnClose(Generator):
|
|
- def send(self, value): return value
|
|
- def throw(self, *args): raise ValueError
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class FailOnClose(Generator):
|
|
+ def send(self, value): return value
|
|
+ def throw(self, *args): raise ValueError
|
|
|
|
self.assertRaises(ValueError, FailOnClose().close)
|
|
|
|
- class IgnoreGeneratorExit(Generator):
|
|
- def send(self, value): return value
|
|
- def throw(self, *args): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class IgnoreGeneratorExit(Generator):
|
|
+ def send(self, value): return value
|
|
+ def throw(self, *args): pass
|
|
|
|
self.assertRaises(RuntimeError, IgnoreGeneratorExit().close)
|
|
|
|
@@ -1379,15 +1424,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|
|
|
def test_direct_subclassing(self):
|
|
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
|
- class C(B):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C(B):
|
|
+ pass
|
|
self.assertTrue(issubclass(C, B))
|
|
self.assertFalse(issubclass(int, C))
|
|
|
|
def test_registration(self):
|
|
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
|
- class C:
|
|
- __hash__ = None # Make sure it isn't hashable by default
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C:
|
|
+ __hash__ = None # Make sure it isn't hashable by default
|
|
self.assertFalse(issubclass(C, B), B.__name__)
|
|
B.register(C)
|
|
self.assertTrue(issubclass(C, B))
|
|
@@ -1423,13 +1470,14 @@ class TestCollectionABCs(ABCTestCase):
|
|
self.assertIsInstance(sample(), Set)
|
|
self.assertTrue(issubclass(sample, Set))
|
|
self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')
|
|
- class MySet(Set):
|
|
- def __contains__(self, x):
|
|
- return False
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __iter__(self):
|
|
- return iter([])
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySet(Set):
|
|
+ def __contains__(self, x):
|
|
+ return False
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __iter__(self):
|
|
+ return iter([])
|
|
self.validate_comparison(MySet())
|
|
|
|
def test_hash_Set(self):
|
|
@@ -1448,15 +1496,16 @@ class TestCollectionABCs(ABCTestCase):
|
|
self.assertTrue(hash(a) == hash(b))
|
|
|
|
def test_isdisjoint_Set(self):
|
|
- class MySet(Set):
|
|
- def __init__(self, itr):
|
|
- self.contents = itr
|
|
- def __contains__(self, x):
|
|
- return x in self.contents
|
|
- def __iter__(self):
|
|
- return iter(self.contents)
|
|
- def __len__(self):
|
|
- return len([x for x in self.contents])
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySet(Set):
|
|
+ def __init__(self, itr):
|
|
+ self.contents = itr
|
|
+ def __contains__(self, x):
|
|
+ return x in self.contents
|
|
+ def __iter__(self):
|
|
+ return iter(self.contents)
|
|
+ def __len__(self):
|
|
+ return len([x for x in self.contents])
|
|
s1 = MySet((1, 2, 3))
|
|
s2 = MySet((4, 5, 6))
|
|
s3 = MySet((1, 5, 6))
|
|
@@ -1464,15 +1513,16 @@ class TestCollectionABCs(ABCTestCase):
|
|
self.assertFalse(s1.isdisjoint(s3))
|
|
|
|
def test_equality_Set(self):
|
|
- class MySet(Set):
|
|
- def __init__(self, itr):
|
|
- self.contents = itr
|
|
- def __contains__(self, x):
|
|
- return x in self.contents
|
|
- def __iter__(self):
|
|
- return iter(self.contents)
|
|
- def __len__(self):
|
|
- return len([x for x in self.contents])
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySet(Set):
|
|
+ def __init__(self, itr):
|
|
+ self.contents = itr
|
|
+ def __contains__(self, x):
|
|
+ return x in self.contents
|
|
+ def __iter__(self):
|
|
+ return iter(self.contents)
|
|
+ def __len__(self):
|
|
+ return len([x for x in self.contents])
|
|
s1 = MySet((1,))
|
|
s2 = MySet((1, 2))
|
|
s3 = MySet((3, 4))
|
|
@@ -1486,15 +1536,16 @@ class TestCollectionABCs(ABCTestCase):
|
|
self.assertNotEqual(s2, s3)
|
|
|
|
def test_arithmetic_Set(self):
|
|
- class MySet(Set):
|
|
- def __init__(self, itr):
|
|
- self.contents = itr
|
|
- def __contains__(self, x):
|
|
- return x in self.contents
|
|
- def __iter__(self):
|
|
- return iter(self.contents)
|
|
- def __len__(self):
|
|
- return len([x for x in self.contents])
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySet(Set):
|
|
+ def __init__(self, itr):
|
|
+ self.contents = itr
|
|
+ def __contains__(self, x):
|
|
+ return x in self.contents
|
|
+ def __iter__(self):
|
|
+ return iter(self.contents)
|
|
+ def __len__(self):
|
|
+ return len([x for x in self.contents])
|
|
s1 = MySet((1, 2, 3))
|
|
s2 = MySet((3, 4, 5))
|
|
s3 = s1 & s2
|
|
@@ -1516,28 +1567,29 @@ class TestCollectionABCs(ABCTestCase):
|
|
|
|
def test_issue_4920(self):
|
|
# MutableSet.pop() method did not work
|
|
- class MySet(MutableSet):
|
|
- __slots__=['__s']
|
|
- def __init__(self,items=None):
|
|
- if items is None:
|
|
- items=[]
|
|
- self.__s=set(items)
|
|
- def __contains__(self,v):
|
|
- return v in self.__s
|
|
- def __iter__(self):
|
|
- return iter(self.__s)
|
|
- def __len__(self):
|
|
- return len(self.__s)
|
|
- def add(self,v):
|
|
- result=v not in self.__s
|
|
- self.__s.add(v)
|
|
- return result
|
|
- def discard(self,v):
|
|
- result=v in self.__s
|
|
- self.__s.discard(v)
|
|
- return result
|
|
- def __repr__(self):
|
|
- return "MySet(%s)" % repr(list(self))
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySet(MutableSet):
|
|
+ __slots__=['__s']
|
|
+ def __init__(self,items=None):
|
|
+ if items is None:
|
|
+ items=[]
|
|
+ self.__s=set(items)
|
|
+ def __contains__(self,v):
|
|
+ return v in self.__s
|
|
+ def __iter__(self):
|
|
+ return iter(self.__s)
|
|
+ def __len__(self):
|
|
+ return len(self.__s)
|
|
+ def add(self,v):
|
|
+ result=v not in self.__s
|
|
+ self.__s.add(v)
|
|
+ return result
|
|
+ def discard(self,v):
|
|
+ result=v in self.__s
|
|
+ self.__s.discard(v)
|
|
+ return result
|
|
+ def __repr__(self):
|
|
+ return "MySet(%s)" % repr(list(self))
|
|
items = [5,43,2,1]
|
|
s = MySet(items)
|
|
r = s.pop()
|
|
@@ -1563,24 +1615,25 @@ class TestCollectionABCs(ABCTestCase):
|
|
def test_issue16373(self):
|
|
# Recursion error comparing comparable and noncomparable
|
|
# Set instances
|
|
- class MyComparableSet(Set):
|
|
- def __contains__(self, x):
|
|
- return False
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __iter__(self):
|
|
- return iter([])
|
|
- class MyNonComparableSet(Set):
|
|
- def __contains__(self, x):
|
|
- return False
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __iter__(self):
|
|
- return iter([])
|
|
- def __le__(self, x):
|
|
- return NotImplemented
|
|
- def __lt__(self, x):
|
|
- return NotImplemented
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyComparableSet(Set):
|
|
+ def __contains__(self, x):
|
|
+ return False
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __iter__(self):
|
|
+ return iter([])
|
|
+ class MyNonComparableSet(Set):
|
|
+ def __contains__(self, x):
|
|
+ return False
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __iter__(self):
|
|
+ return iter([])
|
|
+ def __le__(self, x):
|
|
+ return NotImplemented
|
|
+ def __lt__(self, x):
|
|
+ return NotImplemented
|
|
|
|
cs = MyComparableSet()
|
|
ncs = MyNonComparableSet()
|
|
@@ -1591,13 +1644,14 @@ class TestCollectionABCs(ABCTestCase):
|
|
|
|
def test_issue26915(self):
|
|
# Container membership test should check identity first
|
|
- class CustomSequence(Sequence):
|
|
- def __init__(self, seq):
|
|
- self._seq = seq
|
|
- def __getitem__(self, index):
|
|
- return self._seq[index]
|
|
- def __len__(self):
|
|
- return len(self._seq)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class CustomSequence(Sequence):
|
|
+ def __init__(self, seq):
|
|
+ self._seq = seq
|
|
+ def __getitem__(self, index):
|
|
+ return self._seq[index]
|
|
+ def __len__(self):
|
|
+ return len(self._seq)
|
|
|
|
nan = float('nan')
|
|
obj = support.NEVER_EQ
|
|
@@ -1622,30 +1676,31 @@ class TestCollectionABCs(ABCTestCase):
|
|
|
|
def test_Set_from_iterable(self):
|
|
"""Verify _from_iterable overridden to an instance method works."""
|
|
- class SetUsingInstanceFromIterable(MutableSet):
|
|
- def __init__(self, values, created_by):
|
|
- if not created_by:
|
|
- raise ValueError('created_by must be specified')
|
|
- self.created_by = created_by
|
|
- self._values = set(values)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class SetUsingInstanceFromIterable(MutableSet):
|
|
+ def __init__(self, values, created_by):
|
|
+ if not created_by:
|
|
+ raise ValueError('created_by must be specified')
|
|
+ self.created_by = created_by
|
|
+ self._values = set(values)
|
|
|
|
- def _from_iterable(self, values):
|
|
- return type(self)(values, 'from_iterable')
|
|
+ def _from_iterable(self, values):
|
|
+ return type(self)(values, 'from_iterable')
|
|
|
|
- def __contains__(self, value):
|
|
- return value in self._values
|
|
+ def __contains__(self, value):
|
|
+ return value in self._values
|
|
|
|
- def __iter__(self):
|
|
- yield from self._values
|
|
+ def __iter__(self):
|
|
+ yield from self._values
|
|
|
|
- def __len__(self):
|
|
- return len(self._values)
|
|
+ def __len__(self):
|
|
+ return len(self._values)
|
|
|
|
- def add(self, value):
|
|
- self._values.add(value)
|
|
+ def add(self, value):
|
|
+ self._values.add(value)
|
|
|
|
- def discard(self, value):
|
|
- self._values.discard(value)
|
|
+ def discard(self, value):
|
|
+ self._values.discard(value)
|
|
|
|
impl = SetUsingInstanceFromIterable([1, 2, 3], 'test')
|
|
|
|
@@ -1678,20 +1733,21 @@ class TestCollectionABCs(ABCTestCase):
|
|
|
|
def test_Set_interoperability_with_real_sets(self):
|
|
# Issue: 8743
|
|
- class ListSet(Set):
|
|
- def __init__(self, elements=()):
|
|
- self.data = []
|
|
- for elem in elements:
|
|
- if elem not in self.data:
|
|
- self.data.append(elem)
|
|
- def __contains__(self, elem):
|
|
- return elem in self.data
|
|
- def __iter__(self):
|
|
- return iter(self.data)
|
|
- def __len__(self):
|
|
- return len(self.data)
|
|
- def __repr__(self):
|
|
- return 'Set({!r})'.format(self.data)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class ListSet(Set):
|
|
+ def __init__(self, elements=()):
|
|
+ self.data = []
|
|
+ for elem in elements:
|
|
+ if elem not in self.data:
|
|
+ self.data.append(elem)
|
|
+ def __contains__(self, elem):
|
|
+ return elem in self.data
|
|
+ def __iter__(self):
|
|
+ return iter(self.data)
|
|
+ def __len__(self):
|
|
+ return len(self.data)
|
|
+ def __repr__(self):
|
|
+ return 'Set({!r})'.format(self.data)
|
|
|
|
r1 = set('abc')
|
|
r2 = set('bcd')
|
|
@@ -1846,13 +1902,14 @@ class TestCollectionABCs(ABCTestCase):
|
|
self.assertTrue(issubclass(sample, Mapping))
|
|
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
|
|
'__getitem__')
|
|
- class MyMapping(Mapping):
|
|
- def __len__(self):
|
|
- return 0
|
|
- def __getitem__(self, i):
|
|
- raise IndexError
|
|
- def __iter__(self):
|
|
- return iter(())
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyMapping(Mapping):
|
|
+ def __len__(self):
|
|
+ return 0
|
|
+ def __getitem__(self, i):
|
|
+ raise IndexError
|
|
+ def __iter__(self):
|
|
+ return iter(())
|
|
self.validate_comparison(MyMapping())
|
|
self.assertRaises(TypeError, reversed, MyMapping())
|
|
|
|
@@ -1860,7 +1917,7 @@ class TestCollectionABCs(ABCTestCase):
|
|
for sample in [dict]:
|
|
self.assertIsInstance(sample(), MutableMapping)
|
|
self.assertTrue(issubclass(sample, MutableMapping))
|
|
- self.validate_abstract_methods(MutableMapping, '__contains__', '__iter__', '__len__',
|
|
+ self.validate_abstract_methods(MutableMapping, '__iter__', '__len__',
|
|
'__getitem__', '__setitem__', '__delitem__')
|
|
|
|
def test_MutableMapping_subclass(self):
|
|
@@ -1903,15 +1960,16 @@ class TestCollectionABCs(ABCTestCase):
|
|
'__getitem__')
|
|
|
|
def test_Sequence_mixins(self):
|
|
- class SequenceSubclass(Sequence):
|
|
- def __init__(self, seq=()):
|
|
- self.seq = seq
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class SequenceSubclass(Sequence):
|
|
+ def __init__(self, seq=()):
|
|
+ self.seq = seq
|
|
|
|
- def __getitem__(self, index):
|
|
- return self.seq[index]
|
|
+ def __getitem__(self, index):
|
|
+ return self.seq[index]
|
|
|
|
- def __len__(self):
|
|
- return len(self.seq)
|
|
+ def __len__(self):
|
|
+ return len(self.seq)
|
|
|
|
# Compare Sequence.index() behavior to (list|str).index() behavior
|
|
def assert_index_same(seq1, seq2, index_args):
|
|
@@ -1983,24 +2041,25 @@ class TestCollectionABCs(ABCTestCase):
|
|
def test_MutableSequence_mixins(self):
|
|
# Test the mixins of MutableSequence by creating a minimal concrete
|
|
# class inherited from it.
|
|
- class MutableSequenceSubclass(MutableSequence):
|
|
- def __init__(self):
|
|
- self.lst = []
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MutableSequenceSubclass(MutableSequence):
|
|
+ def __init__(self):
|
|
+ self.lst = []
|
|
|
|
- def __setitem__(self, index, value):
|
|
- self.lst[index] = value
|
|
+ def __setitem__(self, index, value):
|
|
+ self.lst[index] = value
|
|
|
|
- def __getitem__(self, index):
|
|
- return self.lst[index]
|
|
+ def __getitem__(self, index):
|
|
+ return self.lst[index]
|
|
|
|
- def __len__(self):
|
|
- return len(self.lst)
|
|
+ def __len__(self):
|
|
+ return len(self.lst)
|
|
|
|
- def __delitem__(self, index):
|
|
- del self.lst[index]
|
|
+ def __delitem__(self, index):
|
|
+ del self.lst[index]
|
|
|
|
- def insert(self, index, value):
|
|
- self.lst.insert(index, value)
|
|
+ def insert(self, index, value):
|
|
+ self.lst.insert(index, value)
|
|
|
|
mss = MutableSequenceSubclass()
|
|
mss.append(0)
|
|
@@ -2059,7 +2118,7 @@ class CounterSubclassWithGet(Counter):
|
|
self.called = True
|
|
return Counter.get(self, key, default)
|
|
|
|
-class TestCounter(unittest.TestCase):
|
|
+class TestCounter(__TestCase):
|
|
|
|
def test_basics(self):
|
|
c = Counter('abcaba')
|
|
@@ -2225,8 +2284,9 @@ class TestCounter(unittest.TestCase):
|
|
check(Counter(words))
|
|
|
|
def test_copy_subclass(self):
|
|
- class MyCounter(Counter):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyCounter(Counter):
|
|
+ pass
|
|
c = MyCounter('slartibartfast')
|
|
d = c.copy()
|
|
self.assertEqual(d, c)
|
|
@@ -2402,10 +2462,5 @@ class TestCounter(unittest.TestCase):
|
|
self.assertFalse(Counter(a=2, b=1, c=0) > Counter('aab'))
|
|
|
|
|
|
-def load_tests(loader, tests, pattern):
|
|
- tests.addTest(doctest.DocTestSuite(collections))
|
|
- return tests
|
|
-
|
|
-
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|