mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
simplify max(1,x) to x when x known >=1 (#157189)
Creating contiguous strides creates an expression max(1, x). Often we know that x >= 1, in which case we should simplify max(1, x) to x. This appeared in two situations: 1) An internal user complained about statically_known_true(x == max(1, x)) failing (internal link: https://fb.workplace.com/groups/1028545332188949/permalink/1232958568414290). This https://github.com/pytorch/pytorch/pull/155938 won't be needed with this. 3) Not simplifying the above could result in wrong ConstraintViolationErrors. Because we assume non-trival single arg guards shall evaporate see the logic in the function issue_guard in symbolic_shapes.py with this change we longer throw ConstraintViolationErrors with the program bellow this is blocking landing this [PR](https://github.com/pytorch/pytorch/pull/155590) from landing internally. Due to internal export tests throwing ConstraintViolationErrors. like ``` Constraints violated (width)! - Not all values of width = L['x'].size()[3] in the specified range 224 <= width <= 455 satisfy the generated guard max(1, 1 + (((-1) + L['x'].size()[3]) // 2)) == (1 + (((-1) + L['x'].size()[3]) // 2)). ```` ``` x = torch.rand(10) torch._dynamo.mark_dynamic(x, 0, max=20, min=5) @torch.compile(fullgraph=True, dynamic=True) def func(x): if max(1, (-1 + x.size()[0]//2)) == (-1+x.size()[0]//2): return x*400 else: return (x*10)*100 func(x) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157189 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
836bb1941b
commit
6cc490d40b
@ -1857,6 +1857,28 @@ class TestFloorDiv(TestCase):
|
||||
|
||||
|
||||
class TestDimConstraints(TestCase):
|
||||
@skipIfTorchDynamo("mark_dynamic not supported")
|
||||
def test_simplify_max_1_0(self):
|
||||
x = torch.rand(10)
|
||||
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(x, v):
|
||||
# test that statically_known_true
|
||||
if (v == 0 or v == 1) and not statically_known_true(
|
||||
max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2)
|
||||
):
|
||||
raise AssertionError("error")
|
||||
|
||||
if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2):
|
||||
return x * 400
|
||||
else:
|
||||
return (x * 10) * 100
|
||||
|
||||
# testing that this does not throw constraint violation error.
|
||||
self.assertEqual(func(x, 1), x * 400)
|
||||
self.assertEqual(func(x, 0), x * 400)
|
||||
|
||||
def test_dim_constraints_reduce_congruences_simple(self):
|
||||
from sympy import Symbol
|
||||
|
||||
|
@ -1463,7 +1463,6 @@ def statically_known_true(x: BoolLikeType) -> bool:
|
||||
if not isinstance(x, SymBool):
|
||||
assert isinstance(x, bool)
|
||||
return x
|
||||
|
||||
result = _static_eval_sym_bool(x)
|
||||
if result is None:
|
||||
return False
|
||||
@ -6360,6 +6359,24 @@ class ShapeEnv:
|
||||
expr = safe_expand(expr)
|
||||
expr = self.replace(expr)
|
||||
|
||||
# Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced
|
||||
# expression when creating contiguous strides.
|
||||
if not size_oblivious:
|
||||
min_max_replacements = {}
|
||||
for atom in expr.atoms(Max): # type: ignore[has-type]
|
||||
if len(atom.args) > 2:
|
||||
continue
|
||||
a, b = atom.args
|
||||
if b == 1 or b == 0:
|
||||
a, b = b, a
|
||||
|
||||
if a == 1 and self._maybe_evaluate_static(sympy.Ge(b, 1)):
|
||||
min_max_replacements[atom] = b
|
||||
if a == 0 and self._maybe_evaluate_static(sympy.Ge(b, 0)):
|
||||
min_max_replacements[atom] = b
|
||||
if min_max_replacements:
|
||||
expr = expr.xreplace(min_max_replacements)
|
||||
|
||||
if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type]
|
||||
min_max_replacements = {}
|
||||
for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type]
|
||||
|
Reference in New Issue
Block a user