[export] support unbacked stack (#163867)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163867
Approved by: https://github.com/laithsakka
This commit is contained in:
Colin Peppler
2025-09-30 17:55:05 -07:00
committed by PyTorch MergeBot
parent f7ab8a2710
commit 80ed522910
2 changed files with 25 additions and 3 deletions

View File

@ -6240,6 +6240,26 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_unbacked_stack(self):
class M(torch.nn.Module):
def forward(self, x):
nz = torch.nonzero(x)
nz_size = nz.size(0)
torch._check(nz_size % 4 == 0)
# Create two tensors whose leading dimensions are equivalent at
# runtime but expressed via different SymInt formulas.
first = torch.zeros((nz_size // 2, 4))
second = torch.zeros(((nz_size // 4) * 2, 4))
return torch.stack([first, second], dim=0)
inputs = (torch.ones((32,)),)
ep = export(M(), inputs)
orig_res = M()(*inputs)
ep_res = ep.module()(*inputs)
self.assertTrue(torch.allclose(orig_res, ep_res))
def test_unbacked_slice_simple(self):
class M(torch.nn.Module):
def forward(self, scores, score_thr, topk: torch.Tensor, results=None):

View File

@ -4031,11 +4031,13 @@ def rot90(
def _check_stack_inputs(tensors: TensorSequenceType) -> None:
from torch.fx.experimental.symbolic_shapes import sym_eq
entry_shape = tensors[0].shape
for i in range(1, len(tensors)):
assert tensors[i].shape == entry_shape, (
f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
f"and {tensors[i].shape} at entry {i}"
torch._check(
sym_eq(tensors[i].shape, entry_shape),
lambda: f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 ",
)