mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] Support zip_longest (#131497)
Fixes #121348 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131497 Approved by: https://github.com/mlazos, https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
c9888c2739
commit
e76e566cfb
@ -983,6 +983,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
return torch.ones(2, 2)
|
||||
return x.sin()
|
||||
|
||||
@make_test
|
||||
def test_zip_longest(x):
|
||||
list1 = [1, 2, 3]
|
||||
list2 = ["a", "b"]
|
||||
list3 = [True, False, True, False]
|
||||
return torch.sin(x + 1), list(
|
||||
itertools.zip_longest(list1, list2, list3, fillvalue=None)
|
||||
)
|
||||
|
||||
def test_dict_param_keys(self):
|
||||
a_param = torch.nn.Parameter(torch.ones([4, 4]))
|
||||
|
||||
|
@ -97,6 +97,28 @@ def dropwhile(predicate, iterable):
|
||||
yield from iterable
|
||||
|
||||
|
||||
def zip_longest(*iterables, fillvalue=None):
|
||||
# Create a list of iterators from the input iterables
|
||||
iterators = [iter(it) for it in iterables]
|
||||
result = []
|
||||
while True:
|
||||
row = []
|
||||
active = False
|
||||
for it in iterators:
|
||||
try:
|
||||
# Try to get the next item from the iterator
|
||||
value = next(it)
|
||||
row.append(value)
|
||||
active = True
|
||||
except StopIteration:
|
||||
# If the iterator is exhausted, use the fillvalue
|
||||
row.append(fillvalue)
|
||||
if not active:
|
||||
break
|
||||
result.append(tuple(row))
|
||||
return result
|
||||
|
||||
|
||||
def getattr_and_trace(*args, **kwargs):
|
||||
wrapper_obj = args[0]
|
||||
attr_name = args[1]
|
||||
|
@ -196,6 +196,10 @@ class ItertoolsVariable(VariableTracker):
|
||||
return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
elif self.value is itertools.zip_longest:
|
||||
return variables.UserFunctionVariable(polyfill.zip_longest).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
else:
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user