mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Perform value range analysis with rationals when possible (#105137)
This is particularly useful for guards to avoid rounding errors, as most guards (all?) are rational functions. Fixes https://github.com/pytorch/pytorch/issues/105097 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105137 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
634659e262
commit
d1fedad080
@ -224,6 +224,21 @@ class TestValueRanges(TestCase):
|
||||
if r.is_finite:
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
def test_rational_bounds(self):
|
||||
# Repro from https://github.com/pytorch/pytorch/issues/105097
|
||||
from sympy import floor, Eq
|
||||
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
|
||||
new_expr = (
|
||||
Eq(30 * floor(4 * (((shape_0 + 1) // 96)) *
|
||||
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647 +
|
||||
2584 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 647),
|
||||
2880 * floor((((shape_0 + 1) // 96)) *
|
||||
(((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 15528 +
|
||||
323 * (((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646))) / 7764)))
|
||||
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
|
||||
self.assertTrue(new_expr.subs({shape_0: 95}))
|
||||
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
|
||||
|
||||
|
||||
class TestSympyInterp(TestCase):
|
||||
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
|
||||
|
Reference in New Issue
Block a user