Compare commits

...

15 Commits

SHA1 Message Date
5559787d5e Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-11 21:49:51 -08:00
374dd7a1d7 Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-11 21:49:51 -08:00
31efc3b2cb Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-11 11:48:11 -08:00
ea7e19d6cf Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-11 11:48:11 -08:00
5fec0f9289 Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 15:47:36 -08:00
cc9f4fb4fc Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 15:47:35 -08:00
c75fb29df6 Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 13:24:04 -08:00
b78751c0ce Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 13:24:04 -08:00
eeff2c5f9b Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 13:03:47 -08:00
cc89226e97 Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 13:03:47 -08:00
0a9d56448e Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 09:17:48 -08:00
047c7713f7 Update base for Update on "harden backed_size_oblivious and broadcast_shapes"
We probably need something similar for expand


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 09:17:48 -08:00
ab2db578d3 harden backed_size_oblivious and broadcast_shapes
[ghstack-poisoned]
2025-11-06 09:16:11 -08:00
7104addf1e Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 07:52:23 -08:00
0deaae4852 deprecate sizelike and guard_size_oblivious
[ghstack-poisoned]
2025-11-05 22:12:04 -08:00
3 changed files with 83 additions and 1 deletion


@@ -4465,6 +4465,54 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         res = f(x, start, 0)
         self.assertEqual(res.shape, torch.Size([0]))
 
+    @skipIfTorchDynamo()
+    @torch.fx.experimental._config.patch("backed_size_oblivious", True)
+    def test_backed_size_oblivious_broadcast(self):
+        cnt = CompileCounterWithBackend("inductor")
+        torch._dynamo.reset()
+
+        def func(a, b):
+            torch.broadcast_shapes(a.size(), b.size())
+            return a + b
+
+        compiled = torch.compile(func, fullgraph=True, backend=cnt, dynamic=True)
+
+        def run(a, b):
+            self.assertEqual(compiled(a, b), func(a, b))
+
+        # No 0/1 specializations and no broadcasts,
+        # but a[0] == b[0] and a[1] == b[1] are asserted.
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        run(torch.rand(1, 1), torch.rand(1, 1))
+        run(torch.rand(10, 10), torch.rand(10, 10))
+        self.assertEqual(cnt.frame_count, 1)
+
+        run(torch.rand(10, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
+        cnt.clear()
+        torch._dynamo.reset()
+
+        # Specialize a[0] == 1; b[0] is not specialized.
+        run(torch.rand(1, 10), torch.rand(9, 10))
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 1)
+
+        # If we change a[0] we get a recompilation.
+        run(torch.rand(10, 10), torch.rand(10, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
+        cnt.clear()
+        torch._dynamo.reset()
+
+        # TODO: duck sizing should probably be disabled when
+        # backed_size_oblivious is on.
+
+        # Specialize b[0] == 1; a[0] is not specialized.
+        run(torch.rand(10, 11), torch.rand(1, 11))
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 1)
+
+        run(torch.rand(2, 10), torch.rand(2, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
 
 instantiate_parametrized_tests(TestUnbacked)
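
A minimal repro of the behavior this test pins down, outside the test harness. This is a sketch that assumes the patch above is applied; it also assumes torch.fx.experimental._config.patch can be used as a context manager, mirroring its decorator use in the test.

    import torch
    from torch.fx.experimental import _config

    def func(a, b):
        torch.broadcast_shapes(a.size(), b.size())
        return a + b

    compiled = torch.compile(func, fullgraph=True, dynamic=True)

    with _config.patch("backed_size_oblivious", True):
        # Hints a[0]=1 vs b[0]=9: broadcasting is the only way to compile this
        # example input, so a[0] is specialized to 1; b[0] stays dynamic.
        compiled(torch.rand(1, 10), torch.rand(9, 10))
        # The same graph is reused: a[0] is still 1, b[0] is dynamic.
        compiled(torch.rand(1, 10), torch.rand(1, 10))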


@@ -385,7 +385,13 @@ def handle_noncontiguous_outputs(input_tlist, output):
 
 def _broadcast_shapes(*_shapes):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        is_nested_int,
+        size_hint,
+    )
+
+    backed_so = torch.fx.experimental._config.backed_size_oblivious
 
     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x
@@ -418,6 +424,22 @@ def _broadcast_shapes(*_shapes):
             ):
                 continue
             else:
+                # When backed size oblivious is used, we specialize for
+                # broadcasting only if it's the only way to compile the
+                # example input. E.g.:
+                #   s0:1, s1:1 ==> assert s0 == s1; no specialization on
+                #                  ==1 or !=1; the non-broadcast path is picked.
+                #   s0:1, s1:4 ==> specialize s0 to be 1.
+                #   s0:4, s1:1 ==> specialize s1 to be 1.
+                if backed_so:
+                    a = size_hint(shape[idx], allow_none=True)
+                    b = size_hint(common_shape[idx], allow_none=True)
+                    if a == 1 and b != 1:
+                        torch._check(shape[idx] == 1)
+                    if b == 1 and a != 1:
+                        torch._check(common_shape[idx] == 1)
+
                 if guard_or_false(shape[idx] == common_shape[idx]):
                     continue
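
The decision table in that comment can be read as a small pure function. The sketch below is illustrative only (pick_specialization, a_hint, and b_hint are hypothetical names, not code from the PR); it mirrors the hint comparison performed under backed_so.

    def pick_specialization(a_hint, b_hint):
        # a_hint/b_hint are the example-input hints for the two symbolic sizes.
        if a_hint == 1 and b_hint != 1:
            return "specialize lhs == 1"   # s0:1, s1:4
        if b_hint == 1 and a_hint != 1:
            return "specialize rhs == 1"   # s0:4, s1:1
        return None  # s0:1, s1:1 (or 4,4): just assert s0 == s1, no 0/1 guard

    assert pick_specialization(1, 4) == "specialize lhs == 1"
    assert pick_specialization(4, 1) == "specialize rhs == 1"
    assert pick_specialization(1, 1) is None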


@@ -131,6 +131,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
 aten = torch._ops.ops.aten  # type: ignore[has-type]
 
 __all__ = [
+    "size_hint",
     "guard_or_false",
     "guard_or_true",
     "has_symbolic_sizes_strides",
@@ -255,6 +256,17 @@ def _nested_int_aware_sort(
 )
 
 
+def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None:
+    """Gets a size hint for a given expression from the underlying shapes we had.
+
+    Does not introduce a guard, so only use this when you can guarantee that
+    your code is still valid for arbitrary shapes (such as optimization
+    decisions).
+    """
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, torch.SymInt)
+    return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none)
+
+
 # Wrapper on lru_cache that reports statistics at process end
 def lru_cache(
     maxsize: Optional[int],
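
A hedged usage sketch for the new helper. maybe_use_fast_path is a hypothetical caller, not code from this PR; the point is that reading the hint only steers an optimization choice without installing a guard, so the compiled code must remain correct for any runtime size.

    # Assumes this PR's size_hint is available in symbolic_shapes.
    from torch.fx.experimental.symbolic_shapes import size_hint

    def maybe_use_fast_path(s):
        # s may be a plain int or a torch.SymInt. No guard is introduced, so
        # this may only influence decisions that are valid for any size.
        hint = size_hint(s, allow_none=True)
        return hint is not None and hint >= 1024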