Keep zero check be compatible with different sympy versions (#130729)

# Motivation
I found a difference between sympy 1.12 and 1.13.
```python
# for 1.12
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
True
```
```python
# for 1.13
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
False
```
The different behavior will impact the result of [safe_mul](6beec34b1c/torch/utils/_sympy/value_ranges.py (L521-L528)), resulting in an incorrect results when `a = sympy.Number(0.0)`, `b = inf` and the result is `nan` if sympy version is 1.13. (the expected result is **0**)
```python
def safe_mul(a, b):
    # Make unknown() * wrap(0.0) == wrap(0.0)
    if a == 0.0:
        return a
    elif b == 0.0:
        return b
    else:
        return a * b
```

In different sympy versions, `sympy.Number(0)` always has the same behavior that equals to 0.0.
```python
>>> import sympy
>>> a = sympy.Number(0)
>>> a == 0.0
True # for different sympy versions
```
So, use 0.0 when checking zero in safe_mul to keep compatible with different sympy versions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130729
Approved by: https://github.com/lezcano, https://github.com/EikanWang
This commit is contained in:
Yu, Guangye
2024-07-16 01:16:18 +00:00
committed by PyTorch MergeBot
parent fedae41c57
commit 096dc444ce
2 changed files with 7 additions and 3 deletions

View File

@ -241,6 +241,10 @@ class TestValueRanges(TestCase):
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
ValueRanges.wrap(0),
)
self.assertEqual(
ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
ValueRanges.wrap(0.0),
)
@parametrize("fn", UNARY_BOOL_OPS)
def test_unary_bool_ref_range(self, fn):