Perform value range analysis with rationals when possible (#105137)

This is particularly useful for guards to avoid rounding errors, as most
guards (all?) are rational functions.

Fixes https://github.com/pytorch/pytorch/issues/105097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105137
Approved by: https://github.com/ezyang
This commit is contained in:
lezcano
2023-07-13 14:04:07 +00:00
committed by PyTorch MergeBot
parent 634659e262
commit d1fedad080
3 changed files with 48 additions and 17 deletions

View File

@ -224,6 +224,21 @@ class TestValueRanges(TestCase):
if r.is_finite: if r.is_finite:
self.assertIn(r, ref_r) self.assertIn(r, ref_r)
def test_rational_bounds(self):
# Repro from https://github.com/pytorch/pytorch/issues/105097
from sympy import floor, Eq
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
new_expr = (
Eq(30 * floor(4 * (((shape_0 + 1) // 96)) *
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647 +
2584 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647),
2880 * floor((((shape_0 + 1) // 96)) *
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 15528 +
323 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 7764)))
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
self.assertTrue(new_expr.subs({shape_0: 95}))
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
class TestSympyInterp(TestCase): class TestSympyInterp(TestCase):
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)

View File

@ -66,15 +66,16 @@ def sympy_interp(
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
): ):
# Handle base cases # Handle base cases
# TODO: not really sure if I'm passing the right dtype here dtype = None
# TODO: wouldn't it be better to pass the sympy expression through if isinstance(expr, BooleanAtom):
# sometimes? dtype = torch.bool
if isinstance(expr, sympy.Integer): elif isinstance(expr, sympy.Integer):
return analysis.constant(int(expr), torch.int64) dtype = torch.int64
elif isinstance(expr, sympy.Number): elif isinstance(expr, sympy.Number):
return analysis.constant(float(expr), torch.double) dtype = torch.double
elif isinstance(expr, BooleanAtom):
return analysis.constant(bool(expr), torch.bool) if dtype is not None:
return analysis.constant(expr, dtype)
elif isinstance(expr, sympy.Symbol): elif isinstance(expr, sympy.Symbol):
return env[expr] return env[expr]

View File

@ -162,15 +162,27 @@ class SymPyValueRangeAnalysis:
@staticmethod @staticmethod
def constant(value, dtype): def constant(value, dtype):
# NB: value is NOT a sympy expression, it's a constant! # NB: value is NOT a sympy expression, it's a constant!
assert isinstance(value, (int, float, bool)) is_python = isinstance(value, (int, float, bool))
assert is_python or isinstance(value, (BooleanAtom, sympy.Integer, sympy.Number))
# using nan makes subsequent computation throw, and for the purposes of optimization # using nan makes subsequent computation throw, and for the purposes of optimization
# returning -math.inf - math.inf is equivalent to giving up # returning -math.inf - math.inf is equivalent to giving up
if math.isnan(value): if math.isnan(value):
return ValueRanges.unknown() return ValueRanges.unknown()
type_ = dtype_to_type(dtype) if is_python:
value = type_(value) type_ = dtype_to_type(dtype)
value = type_(value)
else:
# We do a type check on a best-effort basis
# We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
if dtype == torch.bool:
assert isinstance(value, BooleanAtom)
elif dtype.is_floating_point:
assert not value.is_finite or value.is_real
else:
# dtype is intXX
assert value.is_integer
return ValueRanges.wrap(value) return ValueRanges.wrap(value)
@ -314,7 +326,9 @@ class SymPyValueRangeAnalysis:
return ValueRanges.wrap(r) return ValueRanges.wrap(r)
if b == 0: if b == 0:
type_ = sympy.Float if a.lower.is_Float else sympy.Integer if not a.lower.is_finite:
return ValueRanges.unknown()
type_ = sympy.Float if a.lower.is_real else sympy.Integer
return ValueRanges.wrap(type_(1)) return ValueRanges.wrap(type_(1))
if b < 0: if b < 0:
@ -381,12 +395,13 @@ class SymPyValueRangeAnalysis:
def fn_(x, y): def fn_(x, y):
# Poorman's version of upcasting in Sympy # Poorman's version of upcasting in Sympy
# Inf is not a float... # Inf is not a float...
if x.is_Float or not x.is_finite or y.is_Float or not y.is_finite: if x.is_Integer and y.is_Integer:
result_type = sympy.Float
else:
assert x.is_Integer
assert y.is_Integer
result_type = sympy.Integer result_type = sympy.Integer
elif x.is_rational and y.is_rational:
result_type = sympy.Rational
else:
assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
result_type = sympy.Float
return fn(result_type(x), result_type(y)) return fn(result_type(x), result_type(y))
return ValueRanges.coordinatewise_increasing_map(a, b, fn_) return ValueRanges.coordinatewise_increasing_map(a, b, fn_)