remove guard_size_oblivious from unbind. (#148815)

unbind will always specialize on dim, because it determines the number of output tensors.
guard_size_oblivious is not useful there and is probably more confusing for code readers.
Added a comment and a test that verifies the specialization.
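
For intuition, a minimal sketch (the shapes here are illustrative, not taken from the PR): the arity of unbind's result is exactly t.shape[dim], so the compiler must know that size concretely.

    import torch

    t = torch.ones(5, 6, 7, 8)
    outs = t.unbind(dim=2)
    assert len(outs) == t.shape[2]  # 7 output tensors; a symbolic size here would leave the arity unknown
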
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148815
Approved by: https://github.com/pianpwk
Author: Laith Sakka
Date: 2025-06-20 17:50:06 -07:00
Committed by: PyTorch MergeBot
Parent: 61eaaa21a4
Commit: e15ea965a1
2 changed files with 23 additions and 4 deletions

View File

@@ -15,7 +15,7 @@ import torch.fx
 import torch.nn.functional as F
 from torch import sym_int, SymBool, SymFloat, SymInt
 from torch._C import _disabled_torch_function_impl
-from torch._dynamo.testing import CompileCounterWithBackend
+from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
 from torch._inductor.utils import fresh_cache
 from torch.fx.experimental import sym_node
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -3417,6 +3417,24 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         # throws a data dependent error.
         compiled_func(x, torch.tensor([5, 20]))
 
+    @skipIfTorchDynamo()
+    def test_unbind_not_dynamic(self):
+        cnt = CompileCounter()
+
+        @torch.compile(fullgraph=True, dynamic=True, backend=cnt)
+        def func(y):
+            return y.unbind(dim=2), y * 10
+
+        func(torch.ones(5, 6, 7, 8))
+        self.assertEqual(cnt.frame_count, 1)
+        # it can be dynamic in all dimensions except dim=2
+        func(torch.ones(4, 9, 7, 10))
+        self.assertEqual(cnt.frame_count, 1)
+
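+        # each new size seen for dim=2 forces a recompile, since unbind specializes on it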
+        func(torch.ones(5, 6, 8, 8))
+        func(torch.ones(5, 6, 9, 8))
+        self.assertEqual(cnt.frame_count, 3)
 
 instantiate_parametrized_tests(TestUnbacked)

View File

@@ -4035,14 +4035,15 @@ def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
 @register_decomposition(aten.unbind)
 def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
     dim = utils.canonicalize_dim(t.ndim, dim)
     torch._check_index(
         len(t.shape) > 0,
         lambda: "Dimension specified as 0 but tensor has no dimensions",
     )
-    if guard_size_oblivious(t.shape[dim] == 0):
+    # Note: t.shape[dim] can't be dynamic or unbacked; even if we used guard_or_false here,
+    # we would fail later in the split, since t.shape[dim] controls the number of output tensors.
+    if t.shape[dim] == 0:
         return ()
     else:
         return tuple(
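
A hedged illustration of the data-dependent case (assumed usage, not code from this PR; torch._dynamo.config.capture_dynamic_output_shape_ops is an existing Dynamo flag): when t.shape[dim] is unbacked, the plain comparison above should surface a data-dependent error rather than be hidden by guard_size_oblivious, because the length of the returned tuple cannot be determined.

    import torch

    # allow ops like nonzero to produce unbacked (data-dependent) sizes under compile
    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    @torch.compile(fullgraph=True)
    def f(x):
        u = x.nonzero()         # u's first dimension is unbacked
        return u.unbind(dim=0)  # the number of outputs would be data-dependent

    # f(torch.tensor([1, 0, 1]))  # expected to raise a data-dependent error under fullgraph=True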