mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Perform value range analysis with rationals when possible (#105137)
This is particularly useful for guards to avoid rounding errors, as most guards (all?) are rational functions. Fixes https://github.com/pytorch/pytorch/issues/105097 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105137 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
634659e262
commit
d1fedad080
@ -224,6 +224,21 @@ class TestValueRanges(TestCase):
|
||||
if r.is_finite:
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
def test_rational_bounds(self):
|
||||
# Repro from https://github.com/pytorch/pytorch/issues/105097
|
||||
from sympy import floor, Eq
|
||||
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
|
||||
new_expr = (
|
||||
Eq(30 * floor(4 * (((shape_0 + 1) // 96)) *
|
||||
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647 +
|
||||
2584 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647),
|
||||
2880 * floor((((shape_0 + 1) // 96)) *
|
||||
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 15528 +
|
||||
323 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 7764)))
|
||||
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
|
||||
self.assertTrue(new_expr.subs({shape_0: 95}))
|
||||
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
|
||||
|
||||
|
||||
class TestSympyInterp(TestCase):
|
||||
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
|
||||
|
||||
@ -66,15 +66,16 @@ def sympy_interp(
|
||||
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
|
||||
):
|
||||
# Handle base cases
|
||||
# TODO: not really sure if I'm passing the right dtype here
|
||||
# TODO: wouldn't it be better to pass the sympy expression through
|
||||
# sometimes?
|
||||
if isinstance(expr, sympy.Integer):
|
||||
return analysis.constant(int(expr), torch.int64)
|
||||
dtype = None
|
||||
if isinstance(expr, BooleanAtom):
|
||||
dtype = torch.bool
|
||||
elif isinstance(expr, sympy.Integer):
|
||||
dtype = torch.int64
|
||||
elif isinstance(expr, sympy.Number):
|
||||
return analysis.constant(float(expr), torch.double)
|
||||
elif isinstance(expr, BooleanAtom):
|
||||
return analysis.constant(bool(expr), torch.bool)
|
||||
dtype = torch.double
|
||||
|
||||
if dtype is not None:
|
||||
return analysis.constant(expr, dtype)
|
||||
elif isinstance(expr, sympy.Symbol):
|
||||
return env[expr]
|
||||
|
||||
|
||||
@ -162,15 +162,27 @@ class SymPyValueRangeAnalysis:
|
||||
@staticmethod
|
||||
def constant(value, dtype):
|
||||
# NB: value is NOT a sympy expression, it's a constant!
|
||||
assert isinstance(value, (int, float, bool))
|
||||
is_python = isinstance(value, (int, float, bool))
|
||||
assert is_python or isinstance(value, (BooleanAtom, sympy.Integer, sympy.Number))
|
||||
|
||||
# using nan makes subsequent computation throw, and for the purposes of optimization
|
||||
# returning -math.inf - math.inf is equivalent to giving up
|
||||
if math.isnan(value):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
type_ = dtype_to_type(dtype)
|
||||
value = type_(value)
|
||||
if is_python:
|
||||
type_ = dtype_to_type(dtype)
|
||||
value = type_(value)
|
||||
else:
|
||||
# We do a type check on a best-effort basis
|
||||
# We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
|
||||
if dtype == torch.bool:
|
||||
assert isinstance(value, BooleanAtom)
|
||||
elif dtype.is_floating_point:
|
||||
assert not value.is_finite or value.is_real
|
||||
else:
|
||||
# dtype is intXX
|
||||
assert value.is_integer
|
||||
|
||||
return ValueRanges.wrap(value)
|
||||
|
||||
@ -314,7 +326,9 @@ class SymPyValueRangeAnalysis:
|
||||
return ValueRanges.wrap(r)
|
||||
|
||||
if b == 0:
|
||||
type_ = sympy.Float if a.lower.is_Float else sympy.Integer
|
||||
if not a.lower.is_finite:
|
||||
return ValueRanges.unknown()
|
||||
type_ = sympy.Float if a.lower.is_real else sympy.Integer
|
||||
return ValueRanges.wrap(type_(1))
|
||||
|
||||
if b < 0:
|
||||
@ -381,12 +395,13 @@ class SymPyValueRangeAnalysis:
|
||||
def fn_(x, y):
|
||||
# Poorman's version of upcasting in Sympy
|
||||
# Inf is not a float...
|
||||
if x.is_Float or not x.is_finite or y.is_Float or not y.is_finite:
|
||||
result_type = sympy.Float
|
||||
else:
|
||||
assert x.is_Integer
|
||||
assert y.is_Integer
|
||||
if x.is_Integer and y.is_Integer:
|
||||
result_type = sympy.Integer
|
||||
elif x.is_rational and y.is_rational:
|
||||
result_type = sympy.Rational
|
||||
else:
|
||||
assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
|
||||
result_type = sympy.Float
|
||||
return fn(result_type(x), result_type(y))
|
||||
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, fn_)
|
||||
|
||||
Reference in New Issue
Block a user