simplify max(1,x) to x when x known >=1 (#157189)

Creating contiguous strides creates an expression max(1, x). Often we know that x >= 1, in
 which case we should simplify max(1, x) to x.

This appeared in two situations:
1) An internal user complained about statically_known_true(x == max(1, x)) failing (internal link: https://fb.workplace.com/groups/1028545332188949/permalink/1232958568414290).
This https://github.com/pytorch/pytorch/pull/155938 won't be needed with this.

3) Not simplifying the above could result in wrong ConstraintViolationErrors.
Because we assume non-trival single arg guards shall evaporate see the logic in the function
issue_guard in symbolic_shapes.py

with this change we longer throw ConstraintViolationErrors with the program bellow
this is blocking landing this [PR](https://github.com/pytorch/pytorch/pull/155590) from landing
internally. Due to internal export tests throwing ConstraintViolationErrors.
like
```
Constraints violated (width)!
  - Not all values of width = L['x'].size()[3] in the specified range 224 <= width <= 455 satisfy the generated guard max(1, 1 + (((-1) + L['x'].size()[3]) // 2)) == (1 + (((-1) + L['x'].size()[3]) // 2)).
````

```
x = torch.rand(10)
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)

@torch.compile(fullgraph=True, dynamic=True)
def func(x):
    if max(1, (-1 + x.size()[0]//2)) == (-1+x.size()[0]//2):
        return x*400
    else:
        return (x*10)*100

func(x)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157189
Approved by: https://github.com/pianpwk
This commit is contained in:
Laith Sakka
2025-06-28 10:56:17 -07:00
committed by PyTorch MergeBot
parent 836bb1941b
commit 6cc490d40b
2 changed files with 40 additions and 1 deletions

View File

@ -1857,6 +1857,28 @@ class TestFloorDiv(TestCase):
class TestDimConstraints(TestCase):
@skipIfTorchDynamo("mark_dynamic not supported")
def test_simplify_max_1_0(self):
x = torch.rand(10)
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)
@torch.compile(fullgraph=True)
def func(x, v):
# test that statically_known_true
if (v == 0 or v == 1) and not statically_known_true(
max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2)
):
raise AssertionError("error")
if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2):
return x * 400
else:
return (x * 10) * 100
# testing that this does not throw constraint violation error.
self.assertEqual(func(x, 1), x * 400)
self.assertEqual(func(x, 0), x * 400)
def test_dim_constraints_reduce_congruences_simple(self):
from sympy import Symbol

View File

@ -1463,7 +1463,6 @@ def statically_known_true(x: BoolLikeType) -> bool:
if not isinstance(x, SymBool):
assert isinstance(x, bool)
return x
result = _static_eval_sym_bool(x)
if result is None:
return False
@ -6360,6 +6359,24 @@ class ShapeEnv:
expr = safe_expand(expr)
expr = self.replace(expr)
# Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced
# expression when creating contiguous strides.
if not size_oblivious:
min_max_replacements = {}
for atom in expr.atoms(Max): # type: ignore[has-type]
if len(atom.args) > 2:
continue
a, b = atom.args
if b == 1 or b == 0:
a, b = b, a
if a == 1 and self._maybe_evaluate_static(sympy.Ge(b, 1)):
min_max_replacements[atom] = b
if a == 0 and self._maybe_evaluate_static(sympy.Ge(b, 0)):
min_max_replacements[atom] = b
if min_max_replacements:
expr = expr.xreplace(min_max_replacements)
if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type]
min_max_replacements = {}
for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type]