[Dynamo] Support zip_longest (#131497)

Fixes #121348

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131497
Approved by: https://github.com/mlazos, https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
Yanbo Liang
2024-07-26 14:06:10 +00:00
committed by PyTorch MergeBot
parent c9888c2739
commit e76e566cfb
3 changed files with 35 additions and 0 deletions

View File

@ -983,6 +983,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
return torch.ones(2, 2)
return x.sin()
@make_test
def test_zip_longest(x):
list1 = [1, 2, 3]
list2 = ["a", "b"]
list3 = [True, False, True, False]
return torch.sin(x + 1), list(
itertools.zip_longest(list1, list2, list3, fillvalue=None)
)
def test_dict_param_keys(self):
a_param = torch.nn.Parameter(torch.ones([4, 4]))

View File

@ -97,6 +97,28 @@ def dropwhile(predicate, iterable):
yield from iterable
def zip_longest(*iterables, fillvalue=None):
# Create a list of iterators from the input iterables
iterators = [iter(it) for it in iterables]
result = []
while True:
row = []
active = False
for it in iterators:
try:
# Try to get the next item from the iterator
value = next(it)
row.append(value)
active = True
except StopIteration:
# If the iterator is exhausted, use the fillvalue
row.append(fillvalue)
if not active:
break
result.append(tuple(row))
return result
def getattr_and_trace(*args, **kwargs):
wrapper_obj = args[0]
attr_name = args[1]

View File

@ -196,6 +196,10 @@ class ItertoolsVariable(VariableTracker):
return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
tx, args, kwargs
)
elif self.value is itertools.zip_longest:
return variables.UserFunctionVariable(polyfill.zip_longest).call_function(
tx, args, kwargs
)
else:
return super().call_function(tx, args, kwargs)