Generalise mod value ranges (#123253)

We also add the usual comment where we note that we don't handle
negative values in mod properly.

We should also fix this in the definition of ModularIndexing. I'll do that
in a later PR, as for that one I'll also need to fix a number of tests that
are testing an incorrect behaviour.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123253
Approved by: https://github.com/peterbell10
This commit is contained in:
lezcano
2024-04-05 12:59:22 +00:00
committed by PyTorch MergeBot
parent caed7f6727
commit 7ce42ebd44
3 changed files with 46 additions and 8 deletions

View File

@ -255,7 +255,7 @@ class TestSympyInterp(TestCase):
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
def test_interp(self, fn):
# SymPy does not implement truncation for Expressions
if fn in ("div", "truncdiv", "minimum", "maximum"):
if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
return
from sympy.abc import x, y
@ -288,6 +288,10 @@ class TestSympyInterp(TestCase):
if fn in ("log", "exp"):
return
# Sympy does not support truncation on symbolic shapes
if fn in ("truncdiv", "mod"):
return
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]