[dynamo] fix functools.reduce() function with None as initial (#116398)

The `initial` argument in `functools.reduce` can be `None`.

```python
initial_missing = object()

def reduce(function, iterable, initial=initial_missing, /):
    it = iter(iterable)
    if initial is initial_missing:
        value = next(it)
    else:
        value = initial
    for element in it:
        value = function(value, element)
    return value
```

Reference:

- python/cpython#102759

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116398
Approved by: https://github.com/Skylion007
This commit is contained in:
Xuehai Pan
2023-12-25 21:23:28 +00:00
committed by PyTorch MergeBot
parent c7e9c15102
commit 039fbeb016
3 changed files with 40 additions and 7 deletions

View File

@ -63,11 +63,19 @@ def func_with_default(a, b, some_default_arg=True):
return a - b
def make_test(fn):
def make_test(fn=None, expected_frame_count=1):
if fn is None:
return lambda fn: make_test(fn, expected_frame_count=expected_frame_count)
nargs = len(inspect.signature(fn).parameters)
def test_fn(self):
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)
return torch._dynamo.testing.standard_test(
self,
fn=fn,
nargs=nargs,
expected_frame_count=expected_frame_count,
)
return test_fn
@ -870,6 +878,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
def test_reduce(a, b, c, d):
return functools.reduce(operator.add, [a, b, c, d])
@make_test
def test_reduce_with_initial(a, b, c, d):
return functools.reduce(operator.add, [b, c, d], a)
@make_test(expected_frame_count=0)
def test_reduce_with_single(x):
return functools.reduce(lambda a, b: (a, b), [x])
@make_test(expected_frame_count=0)
def test_reduce_with_single_with_initial(x, y):
return functools.reduce(lambda a, b: (a, b), [y], x)
@make_test(expected_frame_count=0)
def test_reduce_with_none_initial(x):
return functools.reduce(lambda a, b: (a, b), [x], None)
@make_test
def test_tuple_contains(a, b):
v1 = "a"

View File

@ -244,7 +244,14 @@ def normalize_gm(gm_str) -> str:
return remove_trailing_space(strip_comment(gm_str))
def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None):
def standard_test(
self,
fn,
nargs,
expected_ops=None,
expected_ops_dynamic=None,
expected_frame_count=1,
):
if not config.assume_static_by_default and expected_ops_dynamic is not None:
expected_ops = expected_ops_dynamic
@ -265,7 +272,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
self.assertTrue(same(val1b, correct1))
self.assertTrue(same(val2a, correct2))
self.assertTrue(same(val2b, correct2))
self.assertEqual(actual.frame_count, 1)
self.assertEqual(actual.frame_count, expected_frame_count)
if expected_ops is not None:
self.assertEqual(actual.op_count, expected_ops)

View File

@ -86,6 +86,8 @@ def _polyfill_call_impl(name):
class BuiltinVariable(VariableTracker):
_SENTINEL = object()
@staticmethod
@functools.lru_cache(None)
def _constant_fold_functions():
@ -1100,13 +1102,13 @@ class BuiltinVariable(VariableTracker):
{},
)
def call_reduce(self, tx, function, iterable, initializer=None):
def call_reduce(self, tx, function, iterable, initial=_SENTINEL):
if iterable.has_unpack_var_sequence(tx):
items = iterable.unpack_var_sequence(tx)
if initializer is None:
if initial is self._SENTINEL:
value, items = items[0], items[1:]
else:
value = initializer
value = initial
for element in items:
value = function.call_function(tx, [value, element], {})
return value