[inductor] fix issue for example value with unbacked strides (#163660)

## Issue

During autotuning, we're not applying size hints atomically to the example inputs used for benchmarking.

If an unbacked symint shows up in an input's strides, this can lead to a CUDA illegal memory access (IMA).

This is reproduced by the added unit test: with a stride of `[128 * u0, 128, 1]` and an unbacked fallback of 8192, `benchmark_example_value` returns a tensor with stride `[8192, 128, 1]` instead of `[128 * 8192, 128, 1]`.
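
A minimal sketch of the mismatch using plain sympy (the names `u0`, `stride_expr`, and `fallback` are illustrative, not Inductor internals): hinting the whole stride expression falls back to 8192 wholesale, while substituting the fallback for the unbacked symbol inside the expression keeps the `128 *` factor.

```python
import sympy

u0 = sympy.Symbol("u0", positive=True, integer=True)
stride_expr = 128 * u0   # symbolic stride of the benchmark input
fallback = 8192          # unbacked fallback value from the repro

# Per-expression hinting: the unbacked product has no hint of its own,
# so the whole expression collapses to the fallback value.
non_atomic_hint = fallback                           # 8192

# Atomic hinting: substitute the fallback for u0, then evaluate.
atomic_hint = int(stride_expr.subs({u0: fallback}))  # 128 * 8192 = 1048576

print(non_atomic_hint, atomic_hint)
```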

## Fix

Use the atomic API (`atomically_apply_size_hint`) when applying size hints to the input tensors' strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163660
Approved by: https://github.com/ColinPeppler
Author: q1l1
Date: 2025-10-14 20:07:47 +00:00
Committed by: PyTorch MergeBot
Parent: d7e3f493d9
Commit: 3f83e8915e
3 changed files with 43 additions and 9 deletions


@@ -653,6 +653,28 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @inductor_config.patch({"max_autotune": True})
+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    def test_autotune_with_unbacked_stride(self, device):
+        def fn(x, y, a):
+            u0 = a.item()
+            torch._check(u0 != 1)
+            unbacked = x.expand(8, u0, *x.shape).clone()
+            unbacked = torch.permute(unbacked, [0, 2, 1])
+            y = y.expand(8, *y.shape)
+            bmm = torch.ops.aten.bmm(unbacked, y)
+            return bmm
+
+        example_inputs = (
+            torch.randn((32,), dtype=torch.bfloat16, device=device),
+            torch.randn((128, 64), dtype=torch.bfloat16, device=device),
+            torch.tensor(128, device=device),
+        )
+
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
 
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
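
For intuition, here are the concrete shapes and strides the bmm input gets in eager mode with `u0 = 128`; under `torch.compile` with `capture_scalar_outputs`, `u0` is an unbacked symint, so the leading stride becomes the symbolic product `32 * u0` (the values below come from eager execution, not the compiled graph):

```python
import torch

x = torch.randn(32)
u0 = 128  # what a.item() yields in the test

unbacked = x.expand(8, u0, *x.shape).clone()   # shape (8, 128, 32), contiguous
unbacked = torch.permute(unbacked, [0, 2, 1])  # shape (8, 32, 128)
print(unbacked.stride())                       # (4096, 1, 32) == (32 * u0, 1, 32)
```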


@@ -3622,10 +3622,13 @@ class AlgorithmSelectorCache(PersistentCache):
                 fallback=config.unbacked_symint_fallback,
                 hint_override=hint_override,
             ),
-            V.graph.sizevars.size_hints(
-                node.get_stride(),
-                fallback=config.unbacked_symint_fallback,
-                hint_override=hint_override,
+            tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    stride,
+                    fallback=config.unbacked_symint_fallback,
+                    hint_override=hint_override,
+                )
+                for stride in node.get_stride()
             ),
             node.get_device(),
             node.get_dtype(),
@@ -3677,9 +3680,12 @@ class AlgorithmSelectorCache(PersistentCache):
                 node.get_size(),
                 fallback=config.unbacked_symint_fallback,
             ),
-            *sizevars.size_hints(
-                node.get_stride(),
-                fallback=config.unbacked_symint_fallback,
+            *tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    stride,
+                    fallback=config.unbacked_symint_fallback,
+                )
+                for stride in node.get_stride()
             ),
             sizevars.size_hint(
                 node.get_layout().offset,
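
To see why the collapsed hint is dangerous here: these hinted sizes and strides are used to materialize a real tensor for benchmarking the autotune candidates. A rough sketch with made-up but consistent numbers (the size `(8, 8192, 128)` is a hypothetical shape matching the strides from the repro, not taken from the PR) shows how much less storage the non-atomic strides imply, which a kernel expecting the real layout can then overrun:

```python
# Elements an empty_strided-style allocation must back for a given
# (size, stride) pair: 1 + sum((size_i - 1) * stride_i).
def storage_elems(size, stride):
    return 1 + sum((s - 1) * st for s, st in zip(size, stride))

size = (8, 8192, 128)                # hypothetical shape with u0 hinted to 8192
good_strides = (128 * 8192, 128, 1)  # atomic hint: u0 substituted inside 128 * u0
bad_strides = (8192, 128, 1)         # non-atomic hint: expression collapsed to 8192

print(storage_elems(size, good_strides))  # 8388608 (~8.4M elements)
print(storage_elems(size, bad_strides))   # 1105920 (~1.1M elements)
```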


@@ -908,7 +908,11 @@ class SizeVarAllocator:
         return expr
 
     def atomically_apply_size_hint(
-        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
+        self,
+        expr: Union[Expr, int],
+        *,
+        fallback: Optional[int] = None,
+        hint_override: Optional[int] = None,
     ) -> Union[Expr, int]:
         if isinstance(expr, (int, sympy.Integer)):
             return int(expr)
@@ -925,7 +929,9 @@ class SizeVarAllocator:
         assert isinstance(expr, Expr), type(expr)
         free_symbols = expr.free_symbols
         size_dict = {
-            symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback)
+            symbol: V.graph.sizevars.size_hint(
+                symbol, fallback=fallback, hint_override=hint_override
+            )
             for symbol in free_symbols
         }
         return expr.subs(size_dict)
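
In other words, the helper resolves every free symbol to a single concrete hint and substitutes them all at once, so expressions that share a symbol stay mutually consistent. A simplified, standalone model (`apply_atomically` and `hints` are illustrative names, not Inductor API):

```python
import sympy

def apply_atomically(expr, hints):
    # Plain ints pass through; symbolic expressions get every free symbol
    # replaced by its hint in a single substitution.
    if isinstance(expr, (int, sympy.Integer)):
        return int(expr)
    return int(expr.subs({s: hints[s] for s in expr.free_symbols}))

u0 = sympy.Symbol("u0", positive=True, integer=True)
hints = {u0: 8192}  # one hint per unbacked symbol (fallback or hint_override)

size = (8, u0, 128)
stride = (128 * u0, 128, 1)

print([apply_atomically(e, hints) for e in size])    # [8, 8192, 128]
print([apply_atomically(e, hints) for e in stride])  # [1048576, 128, 1]
# size[1] * stride[1] == stride[0] -> 8192 * 128 == 1048576: layout stays coherent.
```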