Perform value range analysis with rationals when possible (#105137)

This is particularly useful for guards to avoid rounding errors, as most
guards (all?) are rational functions.

Fixes https://github.com/pytorch/pytorch/issues/105097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105137
Approved by: https://github.com/ezyang
This commit is contained in:
lezcano
2023-07-13 14:04:07 +00:00
committed by PyTorch MergeBot
parent 634659e262
commit d1fedad080
3 changed files with 48 additions and 17 deletions

View File

@ -224,6 +224,21 @@ class TestValueRanges(TestCase):
if r.is_finite:
self.assertIn(r, ref_r)
def test_rational_bounds(self):
# Repro from https://github.com/pytorch/pytorch/issues/105097
from sympy import floor, Eq
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
new_expr = (
Eq(30 * floor(4 * (((shape_0 + 1) // 96)) *
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647 +
2584 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647),
2880 * floor((((shape_0 + 1) // 96)) *
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 15528 +
323 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 7764)))
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
self.assertTrue(new_expr.subs({shape_0: 95}))
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
class TestSympyInterp(TestCase):
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)