mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Keep zero check be compatible with different sympy versions (#130729)
# Motivation
I found a difference between sympy 1.12 and 1.13.
```python
# for 1.12
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
True
```
```python
# for 1.13
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
False
```
The different behavior will impact the result of [safe_mul](6beec34b1c/torch/utils/_sympy/value_ranges.py (L521-L528)
), resulting in an incorrect results when `a = sympy.Number(0.0)`, `b = inf` and the result is `nan` if sympy version is 1.13. (the expected result is **0**)
```python
def safe_mul(a, b):
# Make unknown() * wrap(0.0) == wrap(0.0)
if a == 0.0:
return a
elif b == 0.0:
return b
else:
return a * b
```
In different sympy versions, `sympy.Number(0)` always has the same behavior that equals to 0.0.
```python
>>> import sympy
>>> a = sympy.Number(0)
>>> a == 0.0
True # for different sympy versions
```
So, use 0.0 when checking zero in safe_mul to keep compatible with different sympy versions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130729
Approved by: https://github.com/lezcano, https://github.com/EikanWang
This commit is contained in:
committed by
PyTorch MergeBot
parent
fedae41c57
commit
096dc444ce
@ -241,6 +241,10 @@ class TestValueRanges(TestCase):
|
||||
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
|
||||
ValueRanges.wrap(0),
|
||||
)
|
||||
self.assertEqual(
|
||||
ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
|
||||
ValueRanges.wrap(0.0),
|
||||
)
|
||||
|
||||
@parametrize("fn", UNARY_BOOL_OPS)
|
||||
def test_unary_bool_ref_range(self, fn):
|
||||
|
Reference in New Issue
Block a user