[dynamo] support itertools.permutations (#159694)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159694
Approved by: https://github.com/guilhermeleobas
ghstack dependencies: #159693
This commit is contained in:
Rob Timpe
2025-08-08 23:26:50 +00:00
committed by PyTorch MergeBot
parent e07c52b2c0
commit 5ed4f91779
7 changed files with 102 additions and 34 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
index 7d5ba727389..f1cabfe2111 100644
index 7d5ba727389..d15d83a2184 100644
--- a/test/dynamo/cpython/3_13/test_itertools.py
+++ b/test/dynamo/cpython/3_13/test_itertools.py
@@ -1,3 +1,25 @@
@ -50,7 +50,41 @@ index 7d5ba727389..f1cabfe2111 100644
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
"""Test that an iterator is the same after pickling, also when part-consumed"""
@@ -756,7 +778,7 @@ class TestBasicOps(unittest.TestCase):
@@ -454,14 +476,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
- @pickle_deprecated
def test_permutations(self):
- self.assertRaises(TypeError, permutations) # too few arguments
- self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
- self.assertRaises(TypeError, permutations, None) # pool is not iterable
- self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
self.assertEqual(list(permutations('abc', 32)), []) # r > n
- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -498,7 +514,7 @@ class TestBasicOps(unittest.TestCase):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
- for n in range(7):
+ for n in range(5):
values = [5*x-12 for x in range(n)]
for r in range(n+2):
result = list(permutations(values, r))
@@ -515,9 +531,6 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.pickletest(proto, permutations(values, r)) # test pickling
-
@support.bigaddrspacetest
def test_permutations_overflow(self):
with self.assertRaises((OverflowError, MemoryError)):
@@ -756,7 +769,7 @@ class TestBasicOps(unittest.TestCase):
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
@ -59,7 +93,7 @@ index 7d5ba727389..f1cabfe2111 100644
self.assertRaises(TypeError, cycle, 5)
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase):
@@ -888,7 +901,7 @@ class TestBasicOps(unittest.TestCase):
# Check normal pickled
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
dup = []
@ -68,7 +102,7 @@ index 7d5ba727389..f1cabfe2111 100644
for elem in g:
self.assertEqual(k, elem[0])
dup.append(elem)
@@ -896,8 +918,8 @@ class TestBasicOps(unittest.TestCase):
@@ -896,8 +909,8 @@ class TestBasicOps(unittest.TestCase):
# Check nested case
dup = []
@ -79,7 +113,7 @@ index 7d5ba727389..f1cabfe2111 100644
for elem in ig:
self.assertEqual(k, elem[0])
self.assertEqual(ik, elem[2])
@@ -907,8 +929,8 @@ class TestBasicOps(unittest.TestCase):
@@ -907,8 +920,8 @@ class TestBasicOps(unittest.TestCase):
# Check nested and pickled
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
dup = []
@ -90,7 +124,7 @@ index 7d5ba727389..f1cabfe2111 100644
for elem in ig:
self.assertEqual(k, elem[0])
self.assertEqual(ik, elem[2])
@@ -917,7 +939,7 @@ class TestBasicOps(unittest.TestCase):
@@ -917,7 +930,7 @@ class TestBasicOps(unittest.TestCase):
# Check case where inner iterator is not used
@ -99,7 +133,7 @@ index 7d5ba727389..f1cabfe2111 100644
expectedkeys = set([r[0] for r in s])
self.assertEqual(set(keys), expectedkeys)
self.assertEqual(len(keys), len(expectedkeys))
@@ -925,7 +947,7 @@ class TestBasicOps(unittest.TestCase):
@@ -925,7 +938,7 @@ class TestBasicOps(unittest.TestCase):
# Check case where inner iterator is used after advancing the groupby
# iterator
s = list(zip('AABBBAAAA', range(9)))
@ -108,7 +142,7 @@ index 7d5ba727389..f1cabfe2111 100644
_, g1 = next(it)
_, g2 = next(it)
_, g3 = next(it)
@@ -936,7 +958,7 @@ class TestBasicOps(unittest.TestCase):
@@ -936,7 +949,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(g3), [])
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@ -117,7 +151,7 @@ index 7d5ba727389..f1cabfe2111 100644
_, g = next(it)
next(it)
next(it)
@@ -1002,27 +1024,29 @@ class TestBasicOps(unittest.TestCase):
@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
@ -166,7 +200,7 @@ index 7d5ba727389..f1cabfe2111 100644
@pickle_deprecated
def test_filterfalse(self):
@@ -1047,8 +1071,8 @@ class TestBasicOps(unittest.TestCase):
@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3)))
self.assertEqual(list(zip('abcdef')), lzip('abcdef'))
self.assertEqual(list(zip()), lzip())
@ -177,7 +211,7 @@ index 7d5ba727389..f1cabfe2111 100644
self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')],
lzip('abc', 'def'))
self.assertEqual([pair for pair in zip('abc', 'def')],
@@ -1105,19 +1129,19 @@ class TestBasicOps(unittest.TestCase):
@@ -1105,19 +1120,19 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(zip_longest('abc', 'defg', **{})),
list(zip(list('abc')+[None], 'defg'))) # empty keyword dict
@ -210,7 +244,7 @@ index 7d5ba727389..f1cabfe2111 100644
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
list(zip('abc', 'def')))
@@ -1296,7 +1320,6 @@ class TestBasicOps(unittest.TestCase):
@@ -1296,7 +1311,6 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(product(*(args*r))),
list(product(*args, **dict(repeat=r))))
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
@ -218,7 +252,7 @@ index 7d5ba727389..f1cabfe2111 100644
def product1(*args, **kwds):
pools = list(map(tuple, args)) * kwds.get('repeat', 1)
@@ -1336,7 +1359,8 @@ class TestBasicOps(unittest.TestCase):
@@ -1336,7 +1350,8 @@ class TestBasicOps(unittest.TestCase):
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
@ -228,7 +262,7 @@ index 7d5ba727389..f1cabfe2111 100644
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)
self.assertEqual(list(product(*args)), list(product1(*args)))
@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1767,6 +1782,7 @@ class TestBasicOps(unittest.TestCase):
script_helper.assert_python_ok("-c", script)
# Issue 13454: Crash when deleting backward iterator from tee()
@ -236,7 +270,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_tee_del_backward(self):
forward, backward = tee(repeat(None, 20000000))
try:
@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1920,7 +1936,7 @@ class TestBasicOps(unittest.TestCase):
tp.foobar = 1
@ -245,7 +279,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_accumulate(self):
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase):
@@ -2032,7 +2048,7 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
@ -254,7 +288,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_batched_recipe(self):
def batched_recipe(iterable, n):
@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
@@ -2081,6 +2097,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
for i, element in zip(range(i + 1, stop), iterable):
pass
@ -262,7 +296,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_islice_recipe(self):
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
@@ -2265,7 +2282,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
raise
@ -271,7 +305,7 @@ index 7d5ba727389..f1cabfe2111 100644
def makecycle(self, iterator, container):
container.append(iterator)
@@ -2465,7 +2491,7 @@ def L(seqn):
@@ -2465,7 +2482,7 @@ def L(seqn):
return chain(map(lambda x:x, R(Ig(G(seqn)))))
@ -280,7 +314,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_accumulate(self):
s = [1,2,3,4,5]
@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
@@ -2644,7 +2661,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, tee, N(s))
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
@ -289,7 +323,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_repeat(self):
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase):
@@ -2657,7 +2674,7 @@ class LengthTransparency(unittest.TestCase):
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
@ -298,7 +332,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_sf_793826(self):
# Fix Armin Rigo's successful efforts to wreak havoc
@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase):
@@ -2718,6 +2735,7 @@ class RegressionTests(unittest.TestCase):
@support.skip_if_pgo_task
@support.requires_resource('cpu')
@ -306,7 +340,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_long_chain_of_empty_iterables(self):
# Make sure itertools.chain doesn't run into recursion limits when
# dealing with long chains of empty iterables. Even with a high
@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase):
@@ -2750,7 +2768,7 @@ class RegressionTests(unittest.TestCase):
next(g, None) # shouldn't crash
@ -315,7 +349,7 @@ index 7d5ba727389..f1cabfe2111 100644
def test_keywords_in_subclass(self):
# count is not subclassable...
testcases = [
@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
@@ -2805,49 +2823,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
self.assertEqual(u.newarg, 3)

View File

@ -476,14 +476,8 @@ class TestBasicOps(__TestCase):
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
@pickle_deprecated
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, permutations, None) # pool is not iterable
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
self.assertEqual(list(permutations('abc', 32)), []) # r > n
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@ -520,7 +514,7 @@ class TestBasicOps(__TestCase):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
for n in range(7):
for n in range(5):
values = [5*x-12 for x in range(n)]
for r in range(n+2):
result = list(permutations(values, r))
@ -537,9 +531,6 @@ class TestBasicOps(__TestCase):
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, permutations(values, r)) # test pickling
@support.bigaddrspacetest
def test_permutations_overflow(self):
with self.assertRaises((OverflowError, MemoryError)):

View File

@ -285,6 +285,31 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
)
return a
def test_itertools_permutations_basic(self):
def fn():
return torch.tensor(list(itertools.permutations([1, 2, 3], 2)))
actual = torch.compile(fn, backend="eager", fullgraph=True)()
expected = fn()
self.assertEqual(actual, expected)
def test_itertools_permutations_args(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(*args, **kwargs):
return torch.tensor(list(itertools.permutations(*args, **kwargs)))
self.assertRaises(Unsupported, fn)
self.assertRaises(Unsupported, fn, [1, 2, 3], 1, 2)
self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1)
@make_test
def test_itertools_permutations_various_iterators(a, b):
itertools.permutations([a, b])
itertools.permutations(zip([1, 2], [3, 4]))
itertools.permutations(map(lambda x: x, [1, 2]))
itertools.permutations(filter(lambda x: True, [1, 2]))
return a
@make_test
def test_itertools_chain(a, b):
v = a

View File

@ -190,6 +190,24 @@ class ItertoolsVariable(VariableTracker):
return variables.CountIteratorVariable(
*args, mutation_type=ValueMutationNew()
)
elif (
self.value is itertools.permutations
and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant()))
and not kwargs
):
if len(args) == 2:
r = args[1].as_python_constant()
else:
r = None
items = [
variables.TupleVariable(list(item))
for item in itertools.permutations(
args[0].force_unpack_var_sequence(tx), r
)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
)
else:
return super().call_function(tx, args, kwargs)