mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] make requires_stride_order more unbacked-symint-aware (#137063)
Previously, we tried to sort SymInt strides to determine the stride order. This PR makes the sorting more unbacked-symint-aware: given a Tensor with sizes (u0, u1, u2), it has strides (u1 * u2, u1, 1), which is sortable under the guard_size_oblivious assumptions.

Test Plan:
- test case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137063
Approved by: https://github.com/eellison
This commit is contained in:
@ -10864,6 +10864,39 @@ class CommonTemplate:
|
||||
check_lowp=False,
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_custom_op_unbacked_symints(self):
    """
    Compile a chain of two custom ops and check that running it does not raise.

    The fake impl of the first op returns a tensor whose three sizes are all
    fresh dynamic sizes; the second op's fake impl returns ``empty_like`` of
    its input.
    """

    @torch.library.custom_op("mylib::foo", mutates_args={})
    def foo(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @foo.register_fake
    def _foo_fake(x):
        # Each output dimension is a new dynamic size taken from the fake ctx.
        ctx = torch.library.get_ctx()
        dims = [ctx.new_dynamic_size() for _ in range(3)]
        return x.new_empty(*dims)

    @torch.library.custom_op("mylib::bar", mutates_args={})
    def bar(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @bar.register_fake
    def _bar_fake(x):
        return torch.empty_like(x)

    inp = torch.randn(2, 3, 4)

    def chain(t):
        # Feed foo's output straight into bar.
        return bar(foo(t))

    compiled = torch.compile(chain, fullgraph=True)

    # No error
    compiled(inp)
|
||||
|
||||
@requires_gpu()
|
||||
@torch._inductor.config.patch("layout_optimization", True)
|
||||
@torch._inductor.config.patch("keep_output_stride", False)
|
||||
|
@ -78,6 +78,7 @@ from .runtime.benchmarking import benchmarker
|
||||
from .runtime.hints import ReductionHint
|
||||
from .utils import (
|
||||
argsort,
|
||||
argsort_sym,
|
||||
cache_on_self,
|
||||
ceildiv,
|
||||
convert_shape_to_inductor,
|
||||
@ -235,11 +236,17 @@ NHWC_STRIDE_ORDER = [3, 0, 2, 1]
|
||||
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
|
||||
|
||||
|
||||
def get_fill_order(
    seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env=None
) -> Sequence[int]:
    """
    Convert strides to fill order (argsort)

    Args:
        seq: the strides to sort.
        shape_env: optional shape environment. When it is None the strides
            are sorted with ``argsort``; otherwise ``argsort_sym`` is used.

    Returns:
        The indices of ``seq`` in sorted order.
    """
    if shape_env is None:
        return argsort(seq)
    # argsort_sym handles unbacked symints (with the help of the shape_env)
    return argsort_sym(shape_env, seq)
|
||||
|
||||
|
||||
@ -255,11 +262,13 @@ def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[in
|
||||
return fill_order
|
||||
|
||||
|
||||
def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]:
|
||||
def get_stride_order(
|
||||
seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env=None
|
||||
) -> Sequence[int]:
|
||||
"""
|
||||
Convert strides to stride order
|
||||
"""
|
||||
sorted_idx: Sequence[int] = get_fill_order(seq)
|
||||
sorted_idx: Sequence[int] = get_fill_order(seq, shape_env)
|
||||
out = [0 for _ in range(len(seq))]
|
||||
for i, elem in enumerate(sorted_idx):
|
||||
out[elem] = i
|
||||
@ -3017,10 +3026,15 @@ class Layout(IRNode):
|
||||
# reorder the stride given order
|
||||
stride_ordered = [-1] * len(order)
|
||||
for i in range(len(order)):
|
||||
stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
|
||||
stride_ordered[order[i]] = stride[i]
|
||||
# check if it is in ascending order
|
||||
for i in range(len(order) - 1):
|
||||
if stride_ordered[i] > stride_ordered[i + 1]:
|
||||
expr = stride_ordered[i] > stride_ordered[i + 1]
|
||||
if not isinstance(expr, bool):
|
||||
expr = V.graph._shape_env.evaluate_expr(
|
||||
stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True
|
||||
)
|
||||
if expr:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -2285,7 +2285,9 @@ def require_channels_last(_, *args, **kwargs):
|
||||
def constrain_to_fx_strides(fx_node, *args, **kwargs):
|
||||
def apply_constraint(arg, fx_arg):
|
||||
if isinstance(arg, ir.IRNode):
|
||||
stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
|
||||
stride_order = ir.get_stride_order(
|
||||
fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env
|
||||
)
|
||||
return ir.ExternKernel.require_stride_order(arg, stride_order)
|
||||
if isinstance(arg, dict):
|
||||
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()}
|
||||
|
@ -857,6 +857,42 @@ def argsort(seq) -> List[int]:
|
||||
return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
|
||||
|
||||
|
||||
def argsort_sym(
    shape_env, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
) -> List[int]:
    """
    Return the indices that sort ``seq`` in ascending order.

    Comparisons that do not produce a plain Python ``bool`` are resolved with
    ``shape_env.evaluate_expr(..., size_oblivious=True)``. Entries that compare
    neither smaller nor larger are ordered by descending original index.

    Args:
        shape_env: object providing ``evaluate_expr(expr, size_oblivious=...)``.
            It is only consulted when a comparison is not already a ``bool``.
        seq: the values to sort. ``torch.SymInt`` entries are replaced by their
            ``node.expr`` before comparing.

    Returns:
        A list of indices into ``seq``.
    """

    def evaluate(expr):
        # Comparing plain ints already gives a bool; anything else is handed
        # to the shape_env to decide.
        if isinstance(expr, bool):
            return expr
        return shape_env.evaluate_expr(expr, size_oblivious=True)

    def cmp(a, b):
        a_idx, a_val = a
        b_idx, b_val = b

        if evaluate(a_val < b_val):
            return -1
        if evaluate(a_val > b_val):
            return 1
        # If strides are the same, prefer the original order.
        # (this matches argsort's algorithm).
        # For strides = [2048, 2048, 16, 1], this is
        # [3, 2, 1, 0].
        if a_idx < b_idx:
            return 1
        if a_idx > b_idx:
            return -1
        return 0

    # Strategy: convert all symints to sympy.Expr, then use a custom comparator
    indexed = [
        (idx, s.node.expr if isinstance(s, torch.SymInt) else s)
        for idx, s in enumerate(seq)
    ]
    ordered = sorted(indexed, key=functools.cmp_to_key(cmp))
    return [idx for idx, _ in ordered]
|
||||
|
||||
|
||||
@functools.lru_cache(8)
def get_dtype_size(dtype):
    """Return the element size of ``dtype``, measured on an empty 0-dim tensor."""
    probe = torch.empty((), dtype=dtype)
    return probe.element_size()
|
||||
|
Reference in New Issue
Block a user