Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] fix issue for example value with unbacked strides (#163660)
## Issue

During autotuning, size hints are not applied atomically to the example inputs used for benchmarking. If an unbacked symint appears in an input's strides, this can lead to a CUDA illegal memory access (IMA). The added unittest reproduces it: with a stride of `[128 * u0, 128, 1]` and an unbacked fallback of 8192, `benchmark_example_value` returns a tensor with stride `[8192, 128, 1]` instead of `[128 * 8192, 128, 1]`.

## Fix

Use the atomic API (`atomically_apply_size_hint`) when applying size hints to the input tensors' strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163660
Approved by: https://github.com/ColinPeppler
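For illustration, here is a minimal sketch in plain sympy (not the Inductor code touched by this PR) of the two hinting behaviors described above, using the unbacked fallback of 8192 from the commit message:

```python
import sympy

# Symbolic stride from the commit message: [128 * u0, 128, 1].
u0 = sympy.Symbol("u0", positive=True, integer=True)
stride_expr = 128 * u0
FALLBACK = 8192  # unbacked_symint_fallback used in the repro

# Non-atomic hinting (the buggy behavior described above): the expression
# containing an unbacked symbol ends up hinted as the bare fallback,
# dropping the 128x factor.
non_atomic_hint = FALLBACK  # -> 8192

# Atomic hinting: substitute the fallback for every free unbacked symbol and
# evaluate the whole expression, keeping sizes and strides consistent.
atomic_hint = int(stride_expr.subs({s: FALLBACK for s in stride_expr.free_symbols}))

print(non_atomic_hint, atomic_hint)  # 8192 1048576
```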
@@ -653,6 +653,28 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @inductor_config.patch({"max_autotune": True})
+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    def test_autotune_with_unbacked_stride(self, device):
+        def fn(x, y, a):
+            u0 = a.item()
+            torch._check(u0 != 1)
+            unbacked = x.expand(8, u0, *x.shape).clone()
+            unbacked = torch.permute(unbacked, [0, 2, 1])
+            y = y.expand(8, *y.shape)
+            bmm = torch.ops.aten.bmm(unbacked, y)
+            return bmm
+
+        example_inputs = (
+            torch.randn((32,), dtype=torch.bfloat16, device=device),
+            torch.randn((128, 64), dtype=torch.bfloat16, device=device),
+            torch.tensor(128, device=device),
+        )
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
@@ -3622,10 +3622,13 @@ class AlgorithmSelectorCache(PersistentCache):
                 fallback=config.unbacked_symint_fallback,
                 hint_override=hint_override,
             ),
-            V.graph.sizevars.size_hints(
-                node.get_stride(),
-                fallback=config.unbacked_symint_fallback,
-                hint_override=hint_override,
+            tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    stride,
+                    fallback=config.unbacked_symint_fallback,
+                    hint_override=hint_override,
+                )
+                for stride in node.get_stride()
             ),
             node.get_device(),
             node.get_dtype(),
@@ -3677,9 +3680,12 @@ class AlgorithmSelectorCache(PersistentCache):
                 node.get_size(),
                 fallback=config.unbacked_symint_fallback,
             ),
-            *sizevars.size_hints(
-                node.get_stride(),
-                fallback=config.unbacked_symint_fallback,
+            *tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    stride,
+                    fallback=config.unbacked_symint_fallback,
+                )
+                for stride in node.get_stride()
             ),
             sizevars.size_hint(
                 node.get_layout().offset,
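As a side note, here is a minimal, self-contained sketch of how hinting sizes and strides from one shared symbol assignment keeps a benchmarking tensor's layout consistent. `example_tensor` and its arguments are hypothetical and only stand in for Inductor's actual example-value generation:

```python
import sympy
import torch


def example_tensor(sizes, strides, symbol_hints, dtype=torch.bfloat16, device="cpu"):
    """Hypothetical helper: hint sizes and strides from a single
    symbol -> value mapping so the resulting layout is self-consistent."""

    def hint(expr):
        return int(expr.subs(symbol_hints)) if isinstance(expr, sympy.Expr) else int(expr)

    size_hints = [hint(s) for s in sizes]
    stride_hints = [hint(s) for s in strides]
    # empty_strided allocates just enough storage for this (size, stride) pair;
    # an under-hinted stride would under-allocate, so a benchmarked kernel
    # could read or write out of bounds.
    return torch.empty_strided(size_hints, stride_hints, dtype=dtype, device=device)


u0 = sympy.Symbol("u0", positive=True, integer=True)
t = example_tensor(
    sizes=[8, u0, 128],
    strides=[128 * u0, 128, 1],  # the stride pattern from the commit message
    symbol_hints={u0: 8192},
)
print(t.shape, t.stride())  # torch.Size([8, 8192, 128]) (1048576, 128, 1)
```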
@@ -908,7 +908,11 @@ class SizeVarAllocator:
         return expr
 
     def atomically_apply_size_hint(
-        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
+        self,
+        expr: Union[Expr, int],
+        *,
+        fallback: Optional[int] = None,
+        hint_override: Optional[int] = None,
     ) -> Union[Expr, int]:
         if isinstance(expr, (int, sympy.Integer)):
             return int(expr)
@@ -925,7 +929,9 @@ class SizeVarAllocator:
         assert isinstance(expr, Expr), type(expr)
         free_symbols = expr.free_symbols
         size_dict = {
-            symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback)
+            symbol: V.graph.sizevars.size_hint(
+                symbol, fallback=fallback, hint_override=hint_override
+            )
             for symbol in free_symbols
         }
         return expr.subs(size_dict)