mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
[dynamo] fix functools.reduce() function with None as initial (#116398)
The `initial` argument in `functools.reduce` can be `None`.
```python
initial_missing = object()
def reduce(function, iterable, initial=initial_missing, /):
it = iter(iterable)
if initial is initial_missing:
value = next(it)
else:
value = initial
for element in it:
value = function(value, element)
return value
```
Reference:
- python/cpython#102759
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116398
Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
c7e9c15102
commit
039fbeb016
@ -63,11 +63,19 @@ def func_with_default(a, b, some_default_arg=True):
|
||||
return a - b
|
||||
|
||||
|
||||
def make_test(fn):
|
||||
def make_test(fn=None, expected_frame_count=1):
|
||||
if fn is None:
|
||||
return lambda fn: make_test(fn, expected_frame_count=expected_frame_count)
|
||||
|
||||
nargs = len(inspect.signature(fn).parameters)
|
||||
|
||||
def test_fn(self):
|
||||
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)
|
||||
return torch._dynamo.testing.standard_test(
|
||||
self,
|
||||
fn=fn,
|
||||
nargs=nargs,
|
||||
expected_frame_count=expected_frame_count,
|
||||
)
|
||||
|
||||
return test_fn
|
||||
|
||||
@ -870,6 +878,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
def test_reduce(a, b, c, d):
|
||||
return functools.reduce(operator.add, [a, b, c, d])
|
||||
|
||||
@make_test
|
||||
def test_reduce_with_initial(a, b, c, d):
|
||||
return functools.reduce(operator.add, [b, c, d], a)
|
||||
|
||||
@make_test(expected_frame_count=0)
|
||||
def test_reduce_with_single(x):
|
||||
return functools.reduce(lambda a, b: (a, b), [x])
|
||||
|
||||
@make_test(expected_frame_count=0)
|
||||
def test_reduce_with_single_with_initial(x, y):
|
||||
return functools.reduce(lambda a, b: (a, b), [y], x)
|
||||
|
||||
@make_test(expected_frame_count=0)
|
||||
def test_reduce_with_none_initial(x):
|
||||
return functools.reduce(lambda a, b: (a, b), [x], None)
|
||||
|
||||
@make_test
|
||||
def test_tuple_contains(a, b):
|
||||
v1 = "a"
|
||||
|
||||
@ -244,7 +244,14 @@ def normalize_gm(gm_str) -> str:
|
||||
return remove_trailing_space(strip_comment(gm_str))
|
||||
|
||||
|
||||
def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None):
|
||||
def standard_test(
|
||||
self,
|
||||
fn,
|
||||
nargs,
|
||||
expected_ops=None,
|
||||
expected_ops_dynamic=None,
|
||||
expected_frame_count=1,
|
||||
):
|
||||
if not config.assume_static_by_default and expected_ops_dynamic is not None:
|
||||
expected_ops = expected_ops_dynamic
|
||||
|
||||
@ -265,7 +272,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
|
||||
self.assertTrue(same(val1b, correct1))
|
||||
self.assertTrue(same(val2a, correct2))
|
||||
self.assertTrue(same(val2b, correct2))
|
||||
self.assertEqual(actual.frame_count, 1)
|
||||
self.assertEqual(actual.frame_count, expected_frame_count)
|
||||
if expected_ops is not None:
|
||||
self.assertEqual(actual.op_count, expected_ops)
|
||||
|
||||
|
||||
@ -86,6 +86,8 @@ def _polyfill_call_impl(name):
|
||||
|
||||
|
||||
class BuiltinVariable(VariableTracker):
|
||||
_SENTINEL = object()
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def _constant_fold_functions():
|
||||
@ -1100,13 +1102,13 @@ class BuiltinVariable(VariableTracker):
|
||||
{},
|
||||
)
|
||||
|
||||
def call_reduce(self, tx, function, iterable, initializer=None):
|
||||
def call_reduce(self, tx, function, iterable, initial=_SENTINEL):
|
||||
if iterable.has_unpack_var_sequence(tx):
|
||||
items = iterable.unpack_var_sequence(tx)
|
||||
if initializer is None:
|
||||
if initial is self._SENTINEL:
|
||||
value, items = items[0], items[1:]
|
||||
else:
|
||||
value = initializer
|
||||
value = initial
|
||||
for element in items:
|
||||
value = function.call_function(tx, [value, element], {})
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user