[atomically_apply_size_hint] Make unbacked replacements reconcile to a single expr (#164324)

## Problem
There are limitations with today's `atomically_apply_size_hint`, though it works for most of the failures we've observed so far. However, it's easy to come up with an edge case that it does not handle.

Suppose you encounter this setup.
```
a: [s0 + u0]
b: [s1 + u1]
c: [u2 + u3]
d: [u100]
```

Today, we use a few heuristics to specify the LHS and RHS for replacements.

10d2734d9b/torch/_inductor/sizevars.py (L730-L759)

It's possible to end up with these replacement rules. Notice how there's no replacement for `s1 + u1` and `u2 + u3` :( That's because today picking the LHS and RHS matters a lot, and `s1 + u1` & `u2 + u3` happened to end up on the RHS.
```
s0 + u0 => s1 + u1
s0 + u0 => u2 + u3         # overrides previous replacement; each expr only gets one replacement
s0 + u0 => u100            # overrides previous replacement; ditto
```

I believe what we really want is this: everybody gets a replacement! And they all should (eventually) settle at the same canonical expr (i.e. `u100`) when running the replacement several times.
```
s1 + u1 ==> s0 + u0
u2 + u3 ==> s0 + u0
s0 + u0 ==> u100
```

We can just short-cut this by using the canonical expr as the replacement.
```
s1 + u1 ==> u100
u2 + u3 ==> u100
s0 + u0 ==> u100
```

## Implementation

I offer one way to deal with this:
1. ensure every expression has one canonical replacement (e.g. `u100`)
2. if two expressions are equal (inferred from `deferred_runtime_asserts`), then they must have the same canonical replacement

 We can implement the above with union find.
* Whenever you see `Eq(lhs, rhs)` then do `union(lhs, rhs)`.
* Whenever you want to find the canonical replacement for a given expr then do `find(expr)`.
* When picking the canonical replacement, we can use a few heuristics, such as (1) preferring a fully backed expr, (2) preferring a sub-expression over the expression that contains it, and any others we find useful.

Differential Revision: [D84549260](https://our.internmc.facebook.com/intern/diff/D84549260)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164324
Approved by: https://github.com/laithsakka
This commit is contained in:
Colin Peppler
2025-10-13 17:10:32 -07:00
committed by PyTorch MergeBot
parent 56d6229ff9
commit 306c55ba27
2 changed files with 222 additions and 39 deletions

View File

@ -1664,6 +1664,82 @@ class AOTInductorTestsTemplate:
)
self.check_model(Repro(), example_inputs)
@skipIfMPS
@config.patch({"unbacked_symint_fallback": 12})
@parametrize("shift_k", [0, 1, 2, 3])
@parametrize("use_static_size", [True, False])
def test_unbacked_expr_replacements(self, shift_k, use_static_size):
    """
    Check that equal-by-assertion unbacked expressions compile and run
    without an illegal memory access (IMA).

    Test parameters
    - shift_k: Validates that torch._check assertion order doesn't affect
    results by shifting the order of torch._checks
    - use_static_size: Tests torch._check compatibility between unbacked
    symbolic expressions and static shapes
    """
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("Need triton for user-defined triton kernel")

    def realize_out_tensor_with_size(size):
        STATIC_DIM = 256  # large enough to hit IMA w/o compute-sanitizer
        tensor = torch.ones((size, STATIC_DIM), device=self.device)
        # Realize the tensor as an intermediate buffer
        nrows, ncols = tensor.shape
        numel = tensor.numel()
        add_kernel[nrows,](
            in_ptr0=tensor,
            in_ptr1=tensor,
            out_ptr=tensor,
            n_elements=numel,
            BLOCK_SIZE=ncols,
        )
        return tensor

    class Repro(torch.nn.Module):
        def forward(self, x, y, lst):
            STATIC_SIZE = 300
            s0, s1 = x.shape
            s2, s3 = y.shape
            u0, u1, u2, u3, u100 = lst.tolist()
            expr1 = s0 + u0
            expr2 = s1 + u1
            expr3 = (s2 * s3) + (u2 // u3)  # make this one a lil complicated
            expr4 = STATIC_SIZE if use_static_size else u100
            t1 = realize_out_tensor_with_size(expr1)
            t2 = realize_out_tensor_with_size(expr2)
            t3 = realize_out_tensor_with_size(expr3)
            t4 = realize_out_tensor_with_size(expr4)
            # shift tensors to change up the torch._check order
            tensors = [t1, t2, t3, t4]
            shifted_tensors = tensors[shift_k:] + tensors[:shift_k]
            # torch.cat implicitly runs torch._check(lhs == rhs)
            cat = torch.cat(shifted_tensors, dim=1)
            return cat * cat

    # Disable cuda caching allocator to check for IMA
    torch.cuda.caching_allocator_enable(False)
    try:
        model = Repro()
        example_inputs = (
            # s0, s1
            torch.randn((100, 200), device=self.device),
            # s2, s3
            torch.randn((100, 3), device=self.device),
            # u0, u1, u2, u3, u100
            torch.tensor([200, 100, 0, 1, 300], device=self.device, dtype=torch.int),
        )
        spec = {
            "x": (Dim.DYNAMIC, Dim.DYNAMIC),
            "y": (Dim.DYNAMIC, Dim.DYNAMIC),
            "lst": (Dim.STATIC,),
        }
        self.check_model(model, example_inputs, dynamic_shapes=spec)
    finally:
        # Always restore the allocator, even when check_model fails, so the
        # disabled state does not leak into tests that run afterwards.
        torch.cuda.caching_allocator_enable(True)
@skipIfMPS
@config.patch({"unbacked_symint_fallback": 12})
@config.patch({"triton.autotune_at_compile_time": None})

View File

@ -2,6 +2,7 @@
import functools
import itertools
import logging
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union
@ -730,56 +731,162 @@ class SizeVarAllocator:
return strides
def _get_unbacked_replacements(self) -> dict[Expr, Expr]:
"""
This helps with covering unbacked symint cases where you may have two
expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1
via deferred_runtime_asserts.
For example in atomically_apply_size_hint, it must return the same size
hint for both s0 + u0 and u1, but it first needs to know they are equal.
Then it can substitute s0 + u0 for u1.
"""
if self.unbacked_replacements is not None:
return self.unbacked_replacements
def should_keep_src_dst(lhs: Expr, rhs: Expr):
# assuming lhs is the expr to be replaced (src), rhs is the replacement (dst)
# checking if we should keep them for the replacement rule or swap
class CanonicalExprFinder:
"""
Purpose:
A disjoint-set/union-find data structure that can return the
"canonical" expression for a group of equivalent expressions.
- The canonical expression must come from the input eq_graph.
- The heuristics used to choose a leader determines which
expression becomes the canonical expression.
if not has_free_unbacked_symbols(rhs):
# prioritize replacing unbacked exprs with backed expressions
# e.g. u0 + s3 ==> s0 + s1
return True
elif not has_free_unbacked_symbols(lhs):
return False
elif lhs.has(rhs):
# handles cases where LHS is a sub-expression of the RHS
# e.g. Max(2, u0) == s1 * Max(2, u0)
return True
elif rhs.has(lhs):
return False
else:
# fallback to sympy.Basic.compare for a deterministic ordering
return lhs.compare(rhs) == 1
Problem:
Given any unbacked expression, we should be able to find a size_hint
for the unbacked expression, that adheres to the ShapeEnv's deferred
runtime assertions. Otherwise, we may generate conflicting size hints.
In other words, even though we know u0 + s0 == u2, we may generate
size hints, such that, size_hint(u0 + s0) != size_hint(u2).
NOTE: At this time, only deferred runtime asserts that are equalities
(i.e. Eq(lhs, rhs)) are considered in this data structure.
self.unbacked_replacements = {}
Examples:
- u0 + u1 == 9000, then find_expr(u0 + u1) == find_expr(9000)
- u0 + u1 == s9, then find_expr(u0 + u1) == find_expr(s9)
- u0 + s0 == u10, then find_expr(u0 + s0) == find_expr(u10)
Inputs:
- equality_graph: An adjacency set of expressions where the edge
connects two expressions that are found equal to each other. The
edges are sourced from ShapeEnv's deferred_runtime_asserts.
Usage:
- Call union_expr(a, b) to merge a & b into a single set which
shares the same canonical expression.
- Call find_expr(x) to find the canonical expression for x.
"""
def __init__(self, eq_graph: dict[Expr, OrderedSet[Expr]]):
    """
    Index every expression in ``eq_graph`` and merge the connected ones.

    ``eq_graph`` maps each expression to the set of expressions it is
    adjacent to; its keys, in iteration order, become the nodes.
    """
    self.eq_graph = eq_graph
    # Node i of the disjoint-set forest stands for self.expressions[i].
    self.expressions = list(eq_graph)
    node_count = len(self.expressions)
    # Inverse lookup: expression -> node index.
    self.reverse_expressions = {}
    for idx, node_expr in enumerate(self.expressions):
        self.reverse_expressions[node_expr] = idx
    # Every node starts out as the leader of its own singleton set.
    self.leader = list(range(node_count))
    # Per-node weight, starting at 1; union() adds the weights of merged sets.
    self.rank = [1] * node_count
    # Fold every edge of the undirected graph into the forest.
    self._build_canonical_expr_mapping()
def _build_canonical_expr_mapping(self):
for expr, edges in self.eq_graph.items():
for adj in edges:
self.union_expr(expr, adj)
def union_expr(self, a: Expr, b: Expr):
    """
    Merge the sets holding expressions ``a`` and ``b``.

    Both expressions must already be keys of ``self.reverse_expressions``
    (otherwise a KeyError is raised). Returns whatever ``self.union``
    returns for the two node indices.
    """
    idx_a = self.reverse_expressions[a]
    idx_b = self.reverse_expressions[b]
    return self.union(idx_a, idx_b)
def union(self, a: int, b: int):
    """
    Merge the sets containing node indices ``a`` and ``b``.

    Returns False when both nodes already share a root, and True after
    an actual merge.
    """
    root_a, root_b = self.find(a), self.find(b)
    if root_a != root_b:
        # choose_leader decides which root survives as the set's leader.
        winner, loser = self.choose_leader(root_a, root_b)
        self.leader[loser] = winner
        # Fold the absorbed set's weight into the surviving root.
        self.rank[winner] += self.rank[loser]
        return True
    return False  # already connected
def find_expr(self, expr: Expr):
    """
    Return the canonical expression of the set that contains ``expr``.

    ``expr`` must be a key of ``self.reverse_expressions``; otherwise a
    KeyError is raised.
    """
    # expression -> node index -> root index -> root expression
    return self.expressions[self.find(self.reverse_expressions[expr])]
def find(self, x: int):
    """Return the root (leader) index of the set containing node ``x``."""
    # First pass: walk up the parent links until a self-led node is reached.
    root = x
    while self.leader[root] != root:
        root = self.leader[root]
    # Second pass (path compression): re-point every visited node at the root.
    node = x
    while node != root:
        parent = self.leader[node]
        self.leader[node] = root
        node = parent
    return root
def choose_leader(self, a: int, b: int):
    """
    Decide which of two root indices leads the merged set.

    Returns ``(leader, other)``; the leader's expression becomes the
    canonical expression. Tie-breakers, applied in order:
    1. An expression without free unbacked symbols (backed expression
       or constant) beats one that has them
    2. When one expression contains the other, the contained
       sub-expression wins
    3. The expression with more neighbours in ``self.eq_graph`` wins
    4. The root with the larger ``self.rank`` value wins
    5. Fallback to sympy.Basic.compare
    """
    first, second = self.expressions[a], self.expressions[b]

    def _first_leads() -> bool:
        # Examples:
        #   u0 + s3 vs. s0 + s1, then leader is s0 + s1
        #   u2 vs. 300, then leader is 300
        first_unbacked = has_free_unbacked_symbols(first)
        second_unbacked = has_free_unbacked_symbols(second)
        if first_unbacked != second_unbacked:
            return bool(second_unbacked)

        # Example: s1 * Max(2, u0) vs. Max(2, u0), then leader is Max(2, u0)
        if first.has(second):
            return False
        if second.has(first):
            return True

        # Prefer the expression that shows up in more equalities.
        first_degree = len(self.eq_graph[first])
        second_degree = len(self.eq_graph[second])
        if first_degree != second_degree:
            return first_degree > second_degree

        # Attach the lighter tree under the heavier one to keep the
        # leader trees flat.
        if self.rank[a] != self.rank[b]:
            return self.rank[a] > self.rank[b]

        # Fallback to sympy.Basic.compare for a deterministic ordering.
        return first.compare(second) == -1

    return (a, b) if _first_leads() else (b, a)
# Build an undirected graph using ShapeEnv's deferred runtime assertions.
self.equality_graph: dict[Expr, OrderedSet[Expr]] = defaultdict(OrderedSet)
for assertions in self.shape_env.deferred_runtime_asserts.values():
for assertion in assertions:
if not isinstance(assertion.expr, sympy.Equality):
# We're ignoring other relationals for now. If you need to
# account for relationals, then you may need a solver solution.
continue
lhs = sympy.sympify(assertion.expr.lhs) # sympify helps with ints
rhs = sympy.sympify(assertion.expr.rhs)
self.equality_graph[lhs].add(rhs)
self.equality_graph[rhs].add(lhs)
lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
should_keep = should_keep_src_dst(lhs, rhs)
src = lhs if should_keep else rhs
dst = rhs if should_keep else lhs
# Use the undirected graph to create a DSU data structure, so we can
# query for a "canonical" expression.
uf = CanonicalExprFinder(self.equality_graph)
# Start building the unbacked replacements mapping using CanonicalExprFinder
# The mapping is from Expr to its "canonical" Expr.
self.unbacked_replacements = {}
for expr in self.equality_graph.keys():
canonical_expr = uf.find_expr(expr)
if expr != canonical_expr:
self.unbacked_replacements[expr] = canonical_expr
existing_replacement = self.unbacked_replacements.get(src, None)
if existing_replacement and isinstance(
existing_replacement, sympy.Symbol
):
# Prefer to keep replacements with symbols.
continue
self.unbacked_replacements[src] = dst
return self.unbacked_replacements
@functools.lru_cache # noqa: B019