Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Perform value range analysis with rationals when possible (#105137)
This is particularly useful for guards to avoid rounding errors, as most guards (all?) are rational functions.

Fixes https://github.com/pytorch/pytorch/issues/105097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105137
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 634659e262
commit d1fedad080
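
As a motivating sketch, not part of this commit: guards are rational functions of integer shape symbols, so evaluating their bounds with exact rational arithmetic avoids the drift that binary floating point introduces. The sympy usage below is illustrative only.

    import sympy

    # Binary floats cannot represent 0.1 or 0.2 exactly, so the comparison fails...
    print(0.1 + 0.2 == 0.3)                                                        # False
    # ...while the same computation over sympy rationals is exact.
    print(sympy.Rational(1, 10) + sympy.Rational(2, 10) == sympy.Rational(3, 10))  # True
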
@@ -224,6 +224,21 @@ class TestValueRanges(TestCase):
             if r.is_finite:
                 self.assertIn(r, ref_r)
 
+    def test_rational_bounds(self):
+        # Repro from https://github.com/pytorch/pytorch/issues/105097
+        from sympy import floor, Eq
+        shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
+        new_expr = (
+            Eq(30 * floor(4 * (((shape_0 + 1) // 96)) *
+                          (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647 +
+                          2584 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647),
+               2880 * floor((((shape_0 + 1) // 96)) *
+                            (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 15528 +
+                            323 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 7764)))
+        new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
+        self.assertTrue(new_expr.subs({shape_0: 95}))
+        self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
+
 
 class TestSympyInterp(TestCase):
     @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
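
For orientation, here is a stripped-down version of the pattern this test exercises; the import paths and the toy expression are assumptions, not part of the diff. The idea is to evaluate a boolean sympy expression over symbol ranges and check that the satisfiable outcome is contained in the resulting range.

    import sympy
    from torch.utils._sympy.interp import sympy_interp                            # assumed module path
    from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges   # assumed module path

    s = sympy.Symbol("s", positive=True, integer=True)
    expr = sympy.Eq(2 * s, 10)                    # holds exactly at s == 5
    env = {s: ValueRanges(lower=1, upper=8)}      # the satisfying value lies inside the range
    # The boolean range computed by the analysis should therefore admit True,
    # following the same assertIn pattern as the new test above.
    assert True in sympy_interp(ValueRangeAnalysis, env, expr)
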
@@ -66,15 +66,16 @@ def sympy_interp(
     analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
 ):
     # Handle base cases
-    # TODO: not really sure if I'm passing the right dtype here
-    # TODO: wouldn't it be better to pass the sympy expression through
-    # sometimes?
-    if isinstance(expr, sympy.Integer):
-        return analysis.constant(int(expr), torch.int64)
-    elif isinstance(expr, sympy.Number):
-        return analysis.constant(float(expr), torch.double)
-    elif isinstance(expr, BooleanAtom):
-        return analysis.constant(bool(expr), torch.bool)
+    dtype = None
+    if isinstance(expr, BooleanAtom):
+        dtype = torch.bool
+    elif isinstance(expr, sympy.Integer):
+        dtype = torch.int64
+    elif isinstance(expr, sympy.Number):
+        dtype = torch.double
+
+    if dtype is not None:
+        return analysis.constant(expr, dtype)
     elif isinstance(expr, sympy.Symbol):
         return env[expr]
 
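
To make the new control flow concrete, a hedged standalone sketch of the dtype classification it performs follows; the classify helper is hypothetical and written only for illustration. The important difference from the old code is that the sympy atom itself, not a Python cast of it, is what reaches analysis.constant.

    import sympy
    import torch
    from sympy.logic.boolalg import BooleanAtom

    def classify(expr):
        # hypothetical helper mirroring the dispatch order in the diff above
        if isinstance(expr, BooleanAtom):
            return torch.bool
        elif isinstance(expr, sympy.Integer):
            return torch.int64
        elif isinstance(expr, sympy.Number):
            return torch.double
        return None

    # A Rational is a sympy.Number, so it is tagged torch.double, but the exact
    # Rational object (not float(expr)) is what would be forwarded to constant().
    print(classify(sympy.Rational(1, 3)))   # torch.float64
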
@@ -162,15 +162,27 @@ class SymPyValueRangeAnalysis:
     @staticmethod
     def constant(value, dtype):
         # NB: value is NOT a sympy expression, it's a constant!
-        assert isinstance(value, (int, float, bool))
+        is_python = isinstance(value, (int, float, bool))
+        assert is_python or isinstance(value, (BooleanAtom, sympy.Integer, sympy.Number))
 
         # using nan makes subsequent computation throw, and for the purposes of optimization
         # returning -math.inf - math.inf is equivalent to giving up
         if math.isnan(value):
             return ValueRanges.unknown()
 
-        type_ = dtype_to_type(dtype)
-        value = type_(value)
+        if is_python:
+            type_ = dtype_to_type(dtype)
+            value = type_(value)
+        else:
+            # We do a type check on a best-effort basis
+            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
+            if dtype == torch.bool:
+                assert isinstance(value, BooleanAtom)
+            elif dtype.is_floating_point:
+                assert not value.is_finite or value.is_real
+            else:
+                # dtype is intXX
+                assert value.is_integer
+
         return ValueRanges.wrap(value)
 
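
A hedged usage sketch of the widened constant(): the module path, and the assumption that ValueRangeAnalysis inherits this staticmethod unchanged, are mine. Python scalars still go through dtype_to_type, while sympy values are wrapped as-is, which is what keeps a Rational bound exact.

    import sympy
    import torch
    from torch.utils._sympy.value_ranges import ValueRangeAnalysis   # assumed module path

    r = ValueRangeAnalysis.constant(sympy.Rational(1, 3), torch.double)
    # r.lower == r.upper == Rational(1, 3): the bound is kept exact, not rounded to a float

    p = ValueRangeAnalysis.constant(2, torch.double)
    # a plain Python int is still cast through dtype_to_type, so this bound becomes a float
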
@@ -314,7 +326,9 @@ class SymPyValueRangeAnalysis:
             return ValueRanges.wrap(r)
 
         if b == 0:
-            type_ = sympy.Float if a.lower.is_Float else sympy.Integer
+            if not a.lower.is_finite:
+                return ValueRanges.unknown()
+            type_ = sympy.Float if a.lower.is_real else sympy.Integer
             return ValueRanges.wrap(type_(1))
 
         if b < 0:
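
Why the finiteness guard is needed (a side note, not from the diff): sympy's infinities are neither Float instances nor finite, so classifying an unbounded lower bound by is_Float alone was misleading; checking is_finite first lets the analysis give up cleanly.

    import sympy

    print(sympy.oo.is_Float)    # False: infinity is not a Float instance
    print(sympy.oo.is_finite)   # False: it is not finite either, hence ValueRanges.unknown()
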
@@ -381,12 +395,13 @@ class SymPyValueRangeAnalysis:
         def fn_(x, y):
             # Poorman's version of upcasting in Sympy
             # Inf is not a float...
-            if x.is_Float or not x.is_finite or y.is_Float or not y.is_finite:
-                result_type = sympy.Float
-            else:
-                assert x.is_Integer
-                assert y.is_Integer
+            if x.is_Integer and y.is_Integer:
                 result_type = sympy.Integer
+            elif x.is_rational and y.is_rational:
+                result_type = sympy.Rational
+            else:
+                assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
+                result_type = sympy.Float
             return fn(result_type(x), result_type(y))
 
         return ValueRanges.coordinatewise_increasing_map(a, b, fn_)
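
A standalone sketch of the new type-selection rule (the result_type helper below is hypothetical): two Integers stay Integer, two rationals stay Rational (the new case that keeps guard arithmetic exact), and everything else, including infinities, falls back to Float.

    import sympy

    def result_type(x, y):
        # hypothetical copy of the selection logic inside fn_ above
        if x.is_Integer and y.is_Integer:
            return sympy.Integer
        elif x.is_rational and y.is_rational:
            return sympy.Rational
        else:
            return sympy.Float

    print(result_type(sympy.Integer(2), sympy.Integer(3)))       # Integer
    print(result_type(sympy.Rational(1, 3), sympy.Integer(2)))   # Rational: exactness preserved
    print(result_type(sympy.oo, sympy.Rational(1, 3)))           # Float: infinity is not rational
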