mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Dynamo] Support zip_longest (#131497)
Fixes #121348 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131497 Approved by: https://github.com/mlazos, https://github.com/jansel, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
c9888c2739
commit
e76e566cfb
@ -983,6 +983,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||||||
return torch.ones(2, 2)
|
return torch.ones(2, 2)
|
||||||
return x.sin()
|
return x.sin()
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_zip_longest(x):
|
||||||
|
list1 = [1, 2, 3]
|
||||||
|
list2 = ["a", "b"]
|
||||||
|
list3 = [True, False, True, False]
|
||||||
|
return torch.sin(x + 1), list(
|
||||||
|
itertools.zip_longest(list1, list2, list3, fillvalue=None)
|
||||||
|
)
|
||||||
|
|
||||||
def test_dict_param_keys(self):
|
def test_dict_param_keys(self):
|
||||||
a_param = torch.nn.Parameter(torch.ones([4, 4]))
|
a_param = torch.nn.Parameter(torch.ones([4, 4]))
|
||||||
|
|
||||||
|
@ -97,6 +97,28 @@ def dropwhile(predicate, iterable):
|
|||||||
yield from iterable
|
yield from iterable
|
||||||
|
|
||||||
|
|
||||||
|
def zip_longest(*iterables, fillvalue=None):
|
||||||
|
# Create a list of iterators from the input iterables
|
||||||
|
iterators = [iter(it) for it in iterables]
|
||||||
|
result = []
|
||||||
|
while True:
|
||||||
|
row = []
|
||||||
|
active = False
|
||||||
|
for it in iterators:
|
||||||
|
try:
|
||||||
|
# Try to get the next item from the iterator
|
||||||
|
value = next(it)
|
||||||
|
row.append(value)
|
||||||
|
active = True
|
||||||
|
except StopIteration:
|
||||||
|
# If the iterator is exhausted, use the fillvalue
|
||||||
|
row.append(fillvalue)
|
||||||
|
if not active:
|
||||||
|
break
|
||||||
|
result.append(tuple(row))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def getattr_and_trace(*args, **kwargs):
|
def getattr_and_trace(*args, **kwargs):
|
||||||
wrapper_obj = args[0]
|
wrapper_obj = args[0]
|
||||||
attr_name = args[1]
|
attr_name = args[1]
|
||||||
|
@ -196,6 +196,10 @@ class ItertoolsVariable(VariableTracker):
|
|||||||
return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
|
return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
|
||||||
tx, args, kwargs
|
tx, args, kwargs
|
||||||
)
|
)
|
||||||
|
elif self.value is itertools.zip_longest:
|
||||||
|
return variables.UserFunctionVariable(polyfill.zip_longest).call_function(
|
||||||
|
tx, args, kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return super().call_function(tx, args, kwargs)
|
return super().call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user