mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] support itertools.permutations (#159694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159694 Approved by: https://github.com/guilhermeleobas ghstack dependencies: #159693
This commit is contained in:
committed by
PyTorch MergeBot
parent
e07c52b2c0
commit
5ed4f91779
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
index 7d5ba727389..f1cabfe2111 100644
|
||||
index 7d5ba727389..d15d83a2184 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_itertools.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
@@ -1,3 +1,25 @@
|
||||
@ -50,7 +50,41 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
|
||||
"""Test that an iterator is the same after pickling, also when part-consumed"""
|
||||
@@ -756,7 +778,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -454,14 +476,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
|
||||
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
|
||||
|
||||
- @pickle_deprecated
|
||||
def test_permutations(self):
|
||||
- self.assertRaises(TypeError, permutations) # too few arguments
|
||||
- self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
||||
- self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
||||
- self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
||||
self.assertEqual(list(permutations('abc', 32)), []) # r > n
|
||||
- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
||||
self.assertEqual(list(permutations(range(3), 2)),
|
||||
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
||||
|
||||
@@ -498,7 +514,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
if len(set(indices)) == r:
|
||||
yield tuple(pool[i] for i in indices)
|
||||
|
||||
- for n in range(7):
|
||||
+ for n in range(5):
|
||||
values = [5*x-12 for x in range(n)]
|
||||
for r in range(n+2):
|
||||
result = list(permutations(values, r))
|
||||
@@ -515,9 +531,6 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
||||
self.assertEqual(result, list(permutations(values))) # test default r
|
||||
|
||||
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
- self.pickletest(proto, permutations(values, r)) # test pickling
|
||||
-
|
||||
@support.bigaddrspacetest
|
||||
def test_permutations_overflow(self):
|
||||
with self.assertRaises((OverflowError, MemoryError)):
|
||||
@@ -756,7 +769,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
def test_cycle(self):
|
||||
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
|
||||
self.assertEqual(list(cycle('')), [])
|
||||
@ -59,7 +93,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
self.assertRaises(TypeError, cycle, 5)
|
||||
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
|
||||
|
||||
@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -888,7 +901,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
# Check normal pickled
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
dup = []
|
||||
@ -68,7 +102,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
for elem in g:
|
||||
self.assertEqual(k, elem[0])
|
||||
dup.append(elem)
|
||||
@@ -896,8 +918,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -896,8 +909,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
# Check nested case
|
||||
dup = []
|
||||
@ -79,7 +113,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
for elem in ig:
|
||||
self.assertEqual(k, elem[0])
|
||||
self.assertEqual(ik, elem[2])
|
||||
@@ -907,8 +929,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -907,8 +920,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
# Check nested and pickled
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
dup = []
|
||||
@ -90,7 +124,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
for elem in ig:
|
||||
self.assertEqual(k, elem[0])
|
||||
self.assertEqual(ik, elem[2])
|
||||
@@ -917,7 +939,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -917,7 +930,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
|
||||
# Check case where inner iterator is not used
|
||||
@ -99,7 +133,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
expectedkeys = set([r[0] for r in s])
|
||||
self.assertEqual(set(keys), expectedkeys)
|
||||
self.assertEqual(len(keys), len(expectedkeys))
|
||||
@@ -925,7 +947,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -925,7 +938,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
# Check case where inner iterator is used after advancing the groupby
|
||||
# iterator
|
||||
s = list(zip('AABBBAAAA', range(9)))
|
||||
@ -108,7 +142,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
_, g1 = next(it)
|
||||
_, g2 = next(it)
|
||||
_, g3 = next(it)
|
||||
@@ -936,7 +958,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -936,7 +949,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(list(g3), [])
|
||||
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
@ -117,7 +151,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
_, g = next(it)
|
||||
next(it)
|
||||
next(it)
|
||||
@@ -1002,27 +1024,29 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
|
||||
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
|
||||
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
|
||||
@ -166,7 +200,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
@pickle_deprecated
|
||||
def test_filterfalse(self):
|
||||
@@ -1047,8 +1071,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3)))
|
||||
self.assertEqual(list(zip('abcdef')), lzip('abcdef'))
|
||||
self.assertEqual(list(zip()), lzip())
|
||||
@ -177,7 +211,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')],
|
||||
lzip('abc', 'def'))
|
||||
self.assertEqual([pair for pair in zip('abc', 'def')],
|
||||
@@ -1105,19 +1129,19 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1105,19 +1120,19 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
self.assertEqual(list(zip_longest('abc', 'defg', **{})),
|
||||
list(zip(list('abc')+[None], 'defg'))) # empty keyword dict
|
||||
@ -210,7 +244,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
|
||||
list(zip('abc', 'def')))
|
||||
@@ -1296,7 +1320,6 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1296,7 +1311,6 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(list(product(*(args*r))),
|
||||
list(product(*args, **dict(repeat=r))))
|
||||
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
|
||||
@ -218,7 +252,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def product1(*args, **kwds):
|
||||
pools = list(map(tuple, args)) * kwds.get('repeat', 1)
|
||||
@@ -1336,7 +1359,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1336,7 +1350,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
|
||||
set('abcdefg'), range(11), tuple(range(13))]
|
||||
for i in range(100):
|
||||
@ -228,7 +262,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
expected_len = prod(map(len, args))
|
||||
self.assertEqual(len(list(product(*args))), expected_len)
|
||||
self.assertEqual(list(product(*args)), list(product1(*args)))
|
||||
@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1767,6 +1782,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
script_helper.assert_python_ok("-c", script)
|
||||
|
||||
# Issue 13454: Crash when deleting backward iterator from tee()
|
||||
@ -236,7 +270,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
def test_tee_del_backward(self):
|
||||
forward, backward = tee(repeat(None, 20000000))
|
||||
try:
|
||||
@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1920,7 +1936,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
tp.foobar = 1
|
||||
|
||||
|
||||
@ -245,7 +279,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def test_accumulate(self):
|
||||
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
|
||||
@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase):
|
||||
@@ -2032,7 +2048,7 @@ class TestExamples(unittest.TestCase):
|
||||
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
|
||||
|
||||
|
||||
@ -254,7 +288,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def test_batched_recipe(self):
|
||||
def batched_recipe(iterable, n):
|
||||
@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
@@ -2081,6 +2097,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
for i, element in zip(range(i + 1, stop), iterable):
|
||||
pass
|
||||
|
||||
@ -262,7 +296,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
def test_islice_recipe(self):
|
||||
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
|
||||
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
|
||||
@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
@@ -2265,7 +2282,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
raise
|
||||
|
||||
|
||||
@ -271,7 +305,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def makecycle(self, iterator, container):
|
||||
container.append(iterator)
|
||||
@@ -2465,7 +2491,7 @@ def L(seqn):
|
||||
@@ -2465,7 +2482,7 @@ def L(seqn):
|
||||
return chain(map(lambda x:x, R(Ig(G(seqn)))))
|
||||
|
||||
|
||||
@ -280,7 +314,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def test_accumulate(self):
|
||||
s = [1,2,3,4,5]
|
||||
@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
||||
@@ -2644,7 +2661,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
||||
self.assertRaises(TypeError, tee, N(s))
|
||||
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
|
||||
|
||||
@ -289,7 +323,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def test_repeat(self):
|
||||
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
|
||||
@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase):
|
||||
@@ -2657,7 +2674,7 @@ class LengthTransparency(unittest.TestCase):
|
||||
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
|
||||
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
|
||||
|
||||
@ -298,7 +332,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
|
||||
def test_sf_793826(self):
|
||||
# Fix Armin Rigo's successful efforts to wreak havoc
|
||||
@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase):
|
||||
@@ -2718,6 +2735,7 @@ class RegressionTests(unittest.TestCase):
|
||||
|
||||
@support.skip_if_pgo_task
|
||||
@support.requires_resource('cpu')
|
||||
@ -306,7 +340,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
def test_long_chain_of_empty_iterables(self):
|
||||
# Make sure itertools.chain doesn't run into recursion limits when
|
||||
# dealing with long chains of empty iterables. Even with a high
|
||||
@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase):
|
||||
@@ -2750,7 +2768,7 @@ class RegressionTests(unittest.TestCase):
|
||||
next(g, None) # shouldn't crash
|
||||
|
||||
|
||||
@ -315,7 +349,7 @@ index 7d5ba727389..f1cabfe2111 100644
|
||||
def test_keywords_in_subclass(self):
|
||||
# count is not subclassable...
|
||||
testcases = [
|
||||
@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
|
||||
@@ -2805,49 +2823,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
|
@ -476,14 +476,8 @@ class TestBasicOps(__TestCase):
|
||||
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
|
||||
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
|
||||
|
||||
@pickle_deprecated
|
||||
def test_permutations(self):
|
||||
self.assertRaises(TypeError, permutations) # too few arguments
|
||||
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
||||
self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
||||
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
||||
self.assertEqual(list(permutations('abc', 32)), []) # r > n
|
||||
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
||||
self.assertEqual(list(permutations(range(3), 2)),
|
||||
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
||||
|
||||
@ -520,7 +514,7 @@ class TestBasicOps(__TestCase):
|
||||
if len(set(indices)) == r:
|
||||
yield tuple(pool[i] for i in indices)
|
||||
|
||||
for n in range(7):
|
||||
for n in range(5):
|
||||
values = [5*x-12 for x in range(n)]
|
||||
for r in range(n+2):
|
||||
result = list(permutations(values, r))
|
||||
@ -537,9 +531,6 @@ class TestBasicOps(__TestCase):
|
||||
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
||||
self.assertEqual(result, list(permutations(values))) # test default r
|
||||
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
self.pickletest(proto, permutations(values, r)) # test pickling
|
||||
|
||||
@support.bigaddrspacetest
|
||||
def test_permutations_overflow(self):
|
||||
with self.assertRaises((OverflowError, MemoryError)):
|
||||
|
@ -285,6 +285,31 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
)
|
||||
return a
|
||||
|
||||
def test_itertools_permutations_basic(self):
|
||||
def fn():
|
||||
return torch.tensor(list(itertools.permutations([1, 2, 3], 2)))
|
||||
|
||||
actual = torch.compile(fn, backend="eager", fullgraph=True)()
|
||||
expected = fn()
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_itertools_permutations_args(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(*args, **kwargs):
|
||||
return torch.tensor(list(itertools.permutations(*args, **kwargs)))
|
||||
|
||||
self.assertRaises(Unsupported, fn)
|
||||
self.assertRaises(Unsupported, fn, [1, 2, 3], 1, 2)
|
||||
self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1)
|
||||
|
||||
@make_test
|
||||
def test_itertools_permutations_various_iterators(a, b):
|
||||
itertools.permutations([a, b])
|
||||
itertools.permutations(zip([1, 2], [3, 4]))
|
||||
itertools.permutations(map(lambda x: x, [1, 2]))
|
||||
itertools.permutations(filter(lambda x: True, [1, 2]))
|
||||
return a
|
||||
|
||||
@make_test
|
||||
def test_itertools_chain(a, b):
|
||||
v = a
|
||||
|
@ -190,6 +190,24 @@ class ItertoolsVariable(VariableTracker):
|
||||
return variables.CountIteratorVariable(
|
||||
*args, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif (
|
||||
self.value is itertools.permutations
|
||||
and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant()))
|
||||
and not kwargs
|
||||
):
|
||||
if len(args) == 2:
|
||||
r = args[1].as_python_constant()
|
||||
else:
|
||||
r = None
|
||||
items = [
|
||||
variables.TupleVariable(list(item))
|
||||
for item in itertools.permutations(
|
||||
args[0].force_unpack_var_sequence(tx), r
|
||||
)
|
||||
]
|
||||
return variables.ListIteratorVariable(
|
||||
items, mutation_type=ValueMutationNew()
|
||||
)
|
||||
else:
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user