mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[atomically_apply_size_hint] Make unbacked replacements reconcile to a single expr (#164324)
## Problem
Today's `atomically_apply_size_hint` has limitations, even though it works for most of the failures we've observed so far. It's easy to come up with an edge case that it does not handle.
Suppose you encounter this setup.
```
a: [s0 + u0]
b: [s1 + u1]
c: [u2 + u3]
d: [u100]
```
Today, we use a few heuristics to specify the LHS and RHS for replacements.
10d2734d9b/torch/_inductor/sizevars.py (L730-L759)
It's possible to end up with these replacement rules. Notice how there's no replacement for `s1 + u1` and `u2 + u3` :( That's because today picking the LHS and RHS matters a lot, and `s1 + u1` & `u2 + u3` happened to end up on the RHS.
```
s0 + u0 => s1 + u1
s0 + u0 => u2 + u3 # overrides previous replacement; each expr only gets one replacement
s0 + u0 => u100 # overrides previous replacement; ditto
```
I believe what we really want is this: everybody gets a replacement! And they all should (eventually) settle at the same canonical expr (i.e. `u100`) when running the replacement several times.
```
s1 + u1 ==> s0 + u0
u2 + u3 ==> s0 + u0
s0 + u0 ==> u100
```
We can just short-cut this by using the canonical expr as the replacement.
```
s1 + u1 ==> u100
u2 + u3 ==> u100
s0 + u0 ==> u100
```
## Implementation
I offer one way to deal with this:
1. ensure every expression has exactly one canonical replacement (e.g. `u100`)
2. if two expressions are equal (inferred from `deferred_runtime_asserts`), then they must have the same canonical replacement
We can implement the above with union find.
* Whenever you see `Eq(lhs, rhs)` then do `union(lhs, rhs)`.
* Whenever you want to find the canonical replacement for a given expr then do `find(expr)`.
* When picking the canonical replacement we can use a few heuristics, such as (1) preferring a fully backed expr, (2) preferring an expr that is a sub-expression of the other, and any further tie-breakers we'd like.
Differential Revision: [D84549260](https://our.internmc.facebook.com/intern/diff/D84549260)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164324
Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
56d6229ff9
commit
306c55ba27
@ -1664,6 +1664,82 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
self.check_model(Repro(), example_inputs)
|
||||
|
||||
@skipIfMPS
@config.patch({"unbacked_symint_fallback": 12})
@parametrize("shift_k", [0, 1, 2, 3])
@parametrize("use_static_size", [True, False])
def test_unbacked_expr_replacements(self, shift_k, use_static_size):
    """
    Compile a model whose intermediate buffers are sized by four different
    expressions that torch.cat forces to be equal, and check the result.

    Test parameters
    - shift_k: Validates that torch._check assertion order doesn't affect
      results by shifting the order of torch._checks
    - use_static_size: Tests torch._check compatibility between unbacked
      symbolic expressions and static shapes
    """

    if self.device != GPU_TYPE:
        raise unittest.SkipTest("Need triton for user-defined triton kernel")

    def realize_out_tensor_with_size(size):
        STATIC_DIM = 256  # large enough to hit IMA w/o compute-sanitizer
        tensor = torch.ones((size, STATIC_DIM), device=self.device)
        # Realize the tensor as an intermediate buffer
        nrows, ncols = tensor.shape
        numel = tensor.numel()
        add_kernel[nrows,](
            in_ptr0=tensor,
            in_ptr1=tensor,
            out_ptr=tensor,
            n_elements=numel,
            BLOCK_SIZE=ncols,
        )
        return tensor

    class Repro(torch.nn.Module):
        def forward(self, x, y, lst):
            STATIC_SIZE = 300
            s0, s1 = x.shape
            s2, s3 = y.shape
            u0, u1, u2, u3, u100 = lst.tolist()

            expr1 = s0 + u0
            expr2 = s1 + u1
            expr3 = (s2 * s3) + (u2 // u3)  # make this one a lil complicated
            expr4 = STATIC_SIZE if use_static_size else u100

            t1 = realize_out_tensor_with_size(expr1)
            t2 = realize_out_tensor_with_size(expr2)
            t3 = realize_out_tensor_with_size(expr3)
            t4 = realize_out_tensor_with_size(expr4)

            # shift tensors to change up the torch._check order
            tensors = [t1, t2, t3, t4]
            shifted_tensors = tensors[shift_k:] + tensors[:shift_k]

            # torch.cat implicitly runs torch._check(lhs == rhs)
            cat = torch.cat(shifted_tensors, dim=1)

            return cat * cat

    # Disable cuda caching allocator to check for IMA
    torch.cuda.caching_allocator_enable(False)
    try:
        model = Repro()
        example_inputs = (
            # s0, s1
            torch.randn((100, 200), device=self.device),
            # s2, s3
            torch.randn((100, 3), device=self.device),
            # u0, u1, u2, u3, u100
            torch.tensor([200, 100, 0, 1, 300], device=self.device, dtype=torch.int),
        )
        spec = {
            "x": (Dim.DYNAMIC, Dim.DYNAMIC),
            "y": (Dim.DYNAMIC, Dim.DYNAMIC),
            "lst": (Dim.STATIC,),
        }
        self.check_model(model, example_inputs, dynamic_shapes=spec)
    finally:
        # Re-enable the caching allocator even when check_model raises, so a
        # failure here does not leave it disabled for the tests that follow.
        torch.cuda.caching_allocator_enable(True)
|
||||
|
||||
@skipIfMPS
|
||||
@config.patch({"unbacked_symint_fallback": 12})
|
||||
@config.patch({"triton.autotune_at_compile_time": None})
|
||||
|
@ -2,6 +2,7 @@
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
@ -730,56 +731,162 @@ class SizeVarAllocator:
|
||||
return strides
|
||||
|
||||
def _get_unbacked_replacements(self) -> dict[Expr, Expr]:
    """
    This helps with covering unbacked symint cases where you may have two
    expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1
    via deferred_runtime_asserts.

    For example in atomically_apply_size_hint, it must return the same size
    hint for both s0 + u0 and u1, but it first needs to know they are equal.
    Then it can substitute s0 + u0 for u1.

    Returns:
        A mapping from each expression that appears in an equality among the
        ShapeEnv's deferred runtime asserts to the canonical expression of
        its equivalence group. Expressions that are themselves canonical are
        omitted. The mapping is computed once and cached on
        self.unbacked_replacements.
    """
    if self.unbacked_replacements is not None:
        return self.unbacked_replacements

    class CanonicalExprFinder:
        """
        Purpose:
        A disjoint-set/union-find data structure that can return the
        "canonical" expression for a group of equivalent expressions.
        - The canonical expression must come from the input eq_graph.
        - The heuristics used to choose a leader determines which
          expression becomes the canonical expression.

        Problem:
        Given any unbacked expression, we should be able to find a size_hint
        for the unbacked expression, that adheres to the ShapeEnv's deferred
        runtime assertions. Otherwise, we may generate conflicting size hints.
        In other words, even though we know u0 + s0 == u2, we may generate
        size hints, such that, size_hint(u0 + s0) != size_hint(u2).
        NOTE: At this time, only deferred runtime asserts that are equalities
        (i.e. Eq(lhs, rhs)) are considered in this data structure.

        Examples:
        - u0 + u1 == 9000, then find_expr(u0 + u1) == find_expr(9000)
        - u0 + u1 == s9, then find_expr(u0 + u1) == find_expr(s9)
        - u0 + s0 == u10, then find_expr(u0 + s0) == find_expr(u10)

        Inputs:
        - eq_graph: An adjacency set of expressions where the edge
          connects two expressions that are found equal to each other. The
          edges are sourced from ShapeEnv's deferred_runtime_asserts.

        Usage:
        - Call union_expr(a, b) to merge a & b into a single set which
          shares the same canonical expression.
        - Call find_expr(x) to find the canonical expression for x.
        """

        def __init__(self, eq_graph: dict[Expr, OrderedSet[Expr]]):
            self.eq_graph = eq_graph
            self.expressions = list(eq_graph.keys())
            # Maps each expression back to its index in self.expressions.
            self.reverse_expressions = {
                expr: i for i, expr in enumerate(self.expressions)
            }
            # Each node is its own leader/parent initially
            self.leader = list(range(len(self.expressions)))
            # Starts at 1 per node and is summed on every merge, so a root's
            # entry is the number of expressions in its set.
            self.rank = [1] * len(self.expressions)

            # Takes each edge from the undirected graph and starts merging them.
            self._build_canonical_expr_mapping()

        def _build_canonical_expr_mapping(self) -> None:
            for expr, edges in self.eq_graph.items():
                for adj in edges:
                    self.union_expr(expr, adj)

        def union_expr(self, a: Expr, b: Expr) -> bool:
            return self.union(
                self.reverse_expressions[a], self.reverse_expressions[b]
            )

        def union(self, a: int, b: int) -> bool:
            """Merge the sets of a and b; return False if already merged."""
            root_a = self.find(a)
            root_b = self.find(b)
            if root_a == root_b:
                return False  # already connected
            leader, other = self.choose_leader(root_a, root_b)
            self.leader[other] = leader
            self.rank[leader] += self.rank[other]
            return True

        def find_expr(self, expr: Expr) -> Expr:
            parent = self.find(self.reverse_expressions[expr])
            return self.expressions[parent]

        def find(self, x: int) -> int:
            """Return the root index of x's set, compressing the path."""
            root = x
            while self.leader[root] != root:
                root = self.leader[root]
            # Path compression: point every node on the walk at the root.
            while self.leader[x] != root:
                parent = self.leader[x]
                self.leader[x] = root
                x = parent
            return root

        def choose_leader(self, a: int, b: int) -> tuple[int, int]:
            """
            The leader will become the canonical expression.

            Here are the heuristics used for choosing a leader:
            1. Backed expression or constants preferred over unbacked expr
            2. Simpler sub-expr when one contains the other
            3. Higher frequency across equalities from deferred runtime assertions
            4. Rank/size of the set
            5. Fallback to sympy.Basic.compare

            Returns:
                (leader, other) as indices into self.expressions.
            """

            def _choose(x: int, y: int) -> bool:
                # True means x should lead; False means y should lead.
                lhs, rhs = self.expressions[x], self.expressions[y]

                # Prefer replacing unbacked exprs with backed expressions/constants.
                # Examples:
                #   u0 + s3 ==> s0 + s1, then leader is s0 + s1
                #   u2 ==> 300, then leader is 300
                any_unbacked_lhs = has_free_unbacked_symbols(lhs)
                any_unbacked_rhs = has_free_unbacked_symbols(rhs)
                if any_unbacked_lhs != any_unbacked_rhs:
                    return any_unbacked_rhs

                # Handles cases where LHS contains the RHS. In other words,
                # RHS is a sub-expression of LHS. For example:
                #   s1 * Max(2, u0) ==> Max(2, u0), then leader is Max(2, u0)
                if lhs.has(rhs):
                    return False
                elif rhs.has(lhs):
                    return True

                # Prefer expressions that come up more often.
                degrees_lhs = len(self.eq_graph[lhs])
                degrees_rhs = len(self.eq_graph[rhs])
                if degrees_lhs != degrees_rhs:
                    return degrees_lhs > degrees_rhs

                # Prefer the larger set as leader to keep the leader trees
                # flat.
                if self.rank[x] != self.rank[y]:
                    return self.rank[x] > self.rank[y]

                # Fallback to sympy.Basic.compare for a deterministic ordering.
                return lhs.compare(rhs) == -1

            if _choose(a, b):
                return a, b
            return b, a

    # Build an undirected graph using ShapeEnv's deferred runtime assertions.
    self.equality_graph: dict[Expr, OrderedSet[Expr]] = defaultdict(OrderedSet)
    for assertions in self.shape_env.deferred_runtime_asserts.values():
        for assertion in assertions:
            if not isinstance(assertion.expr, sympy.Equality):
                # We're ignoring other relationals for now. If you need to
                # account for relationals, then you may need a solver solution.
                continue
            lhs = sympy.sympify(assertion.expr.lhs)  # sympify helps with ints
            rhs = sympy.sympify(assertion.expr.rhs)
            self.equality_graph[lhs].add(rhs)
            self.equality_graph[rhs].add(lhs)

    # Use the undirected graph to create a DSU data structure, so we can
    # query for a "canonical" expression.
    uf = CanonicalExprFinder(self.equality_graph)

    # Start building the unbacked replacements mapping using CanonicalExprFinder
    # The mapping is from Expr to its "canonical" Expr.
    replacements: dict[Expr, Expr] = {}
    for expr in self.equality_graph:
        canonical_expr = uf.find_expr(expr)
        if expr != canonical_expr:
            replacements[expr] = canonical_expr

    # Publish the cache only once it is fully built.
    self.unbacked_replacements = replacements
    return replacements
|
||||
|
||||
@functools.lru_cache # noqa: B019
|
||||
|
Reference in New Issue
Block a user