[inductor] make requires_stride_order more unbacked-symint-aware (#137063)

Previously, we tried to sort SymInt strides to determine the stride
order. This PR makes the sorting more unbacked-symint-aware: a contiguous
Tensor with sizes (u0, u1, u2) has strides (u1 * u2, u2, 1), which are
sortable under the guard_size_oblivious assumptions.

Test Plan:
- test case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137063
Approved by: https://github.com/eellison
This commit is contained in:
rzou
2024-10-30 11:59:39 -07:00
committed by PyTorch MergeBot
parent 3192bdeea4
commit ccaa2a206a
4 changed files with 92 additions and 7 deletions

View File

@ -10864,6 +10864,39 @@ class CommonTemplate:
check_lowp=False,
)
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_custom_op_unbacked_symints(self):
    """
    Compile a chain of two custom ops where the first one's fake impl
    returns a tensor whose three sizes are all fresh dynamic sizes.

    The test only checks that compiling and running does not raise.
    """

    @torch.library.custom_op("mylib::foo", mutates_args={})
    def foo(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @foo.register_fake
    def foo_fake(x):
        # Three fresh dynamic sizes, requested in order, become the output shape.
        ctx = torch.library.get_ctx()
        dynamic_sizes = [ctx.new_dynamic_size() for _ in range(3)]
        return x.new_empty(*dynamic_sizes)

    @torch.library.custom_op("mylib::bar", mutates_args={})
    def bar(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @bar.register_fake
    def bar_fake(x):
        return torch.empty_like(x)

    @torch.compile(fullgraph=True)
    def f(x):
        # bar consumes foo's output, whose sizes come from foo's fake impl.
        return bar(foo(x))

    inp = torch.randn(2, 3, 4)

    # No error
    f(inp)
@requires_gpu()
@torch._inductor.config.patch("layout_optimization", True)
@torch._inductor.config.patch("keep_output_stride", False)

View File

@ -78,6 +78,7 @@ from .runtime.benchmarking import benchmarker
from .runtime.hints import ReductionHint
from .utils import (
argsort,
argsort_sym,
cache_on_self,
ceildiv,
convert_shape_to_inductor,
@ -235,11 +236,17 @@ NHWC_STRIDE_ORDER = [3, 0, 2, 1]
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
def get_fill_order(
    seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env=None
) -> Sequence[int]:
    """
    Convert strides to fill order (argsort)

    Args:
        seq: the strides to sort; entries may be ints, SymInts or
            sympy expressions.
        shape_env: optional shape environment. When given, the sort is
            delegated to ``argsort_sym`` so that symbolic comparisons are
            resolved through it; when ``None``, plain ``argsort`` is used.

    Returns:
        The indices of ``seq`` as produced by the selected argsort.
    """
    # Pick exactly one sorting strategy. The plain argsort must not run
    # first when a shape_env is supplied, since it compares entries directly.
    if shape_env is None:
        return argsort(seq)
    # argsort_sym handles unbacked symints (with the help of the shape_env)
    return argsort_sym(shape_env, seq)
@ -255,11 +262,13 @@ def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[in
return fill_order
def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]:
def get_stride_order(
seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env=None
) -> Sequence[int]:
"""
Convert strides to stride order
"""
sorted_idx: Sequence[int] = get_fill_order(seq)
sorted_idx: Sequence[int] = get_fill_order(seq, shape_env)
out = [0 for _ in range(len(seq))]
for i, elem in enumerate(sorted_idx):
out[elem] = i
@ -3017,10 +3026,15 @@ class Layout(IRNode):
# reorder the stride given order
stride_ordered = [-1] * len(order)
for i in range(len(order)):
stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
stride_ordered[order[i]] = stride[i]
# check if it is in ascending order
for i in range(len(order) - 1):
if stride_ordered[i] > stride_ordered[i + 1]:
expr = stride_ordered[i] > stride_ordered[i + 1]
if not isinstance(expr, bool):
expr = V.graph._shape_env.evaluate_expr(
stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True
)
if expr:
return False
return True

View File

@ -2285,7 +2285,9 @@ def require_channels_last(_, *args, **kwargs):
def constrain_to_fx_strides(fx_node, *args, **kwargs):
def apply_constraint(arg, fx_arg):
if isinstance(arg, ir.IRNode):
stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
stride_order = ir.get_stride_order(
fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env
)
return ir.ExternKernel.require_stride_order(arg, stride_order)
if isinstance(arg, dict):
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()}

View File

@ -857,6 +857,42 @@ def argsort(seq) -> List[int]:
return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
def argsort_sym(
    shape_env, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
) -> List[int]:
    """
    Argsort ``seq`` in ascending order, allowing symbolic entries.

    Args:
        shape_env: object whose ``evaluate_expr(expr, size_oblivious=True)``
            is used to decide comparisons that do not reduce to a Python
            ``bool``. It is not touched when every comparison is already a
            plain ``bool`` (e.g. when all entries are ints).
        seq: the values to sort; ``torch.SymInt`` entries are replaced by
            their underlying ``node.expr`` before comparing.

    Returns:
        The indices of ``seq`` in ascending order of value. Among equal
        values the larger index comes first, e.g. ``[2048, 2048, 16, 1]``
        gives ``[3, 2, 1, 0]``.
    """

    def evaluate(expr):
        # Comparing two plain ints already yields a bool; anything else is
        # a symbolic relation that the shape_env has to decide.
        if isinstance(expr, bool):
            return expr
        return shape_env.evaluate_expr(expr, size_oblivious=True)

    def cmp(a, b):
        a_idx, a_val = a
        b_idx, b_val = b

        if evaluate(a_val < b_val):
            return -1
        if evaluate(a_val > b_val):
            return 1

        # If strides are the same, prefer the original order.
        # (this matches argsort's algorithm).
        # For strides = [2048, 2048, 16, 1], this is
        # [3, 2, 1, 0].
        if a_idx < b_idx:
            return 1
        if a_idx > b_idx:
            return -1
        return 0

    # Strategy: convert all symints to sympy.Expr, then use a custom comparator
    exprs = [
        (idx, s.node.expr if isinstance(s, torch.SymInt) else s)
        for idx, s in enumerate(seq)
    ]
    exprs = sorted(exprs, key=functools.cmp_to_key(cmp))
    return [idx for idx, _ in exprs]
@functools.lru_cache(8)
def get_dtype_size(dtype):
    """
    Return ``element_size()`` of a zero-dimensional empty tensor of ``dtype``.

    Results are memoized per dtype by the surrounding ``lru_cache``.
    """
    scalar = torch.empty((), dtype=dtype)
    return scalar.element_size()