mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] support unbacked stack (#163867)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163867 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
f7ab8a2710
commit
80ed522910
@ -6240,6 +6240,26 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
}
|
||||
self._test_export_same_as_eager(kw_func, args, kwargs)
|
||||
|
||||
def test_unbacked_stack(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
nz = torch.nonzero(x)
|
||||
nz_size = nz.size(0)
|
||||
torch._check(nz_size % 4 == 0)
|
||||
|
||||
# Create two tensors whose leading dimensions are equivalent at
|
||||
# runtime but expressed via different SymInt formulas.
|
||||
first = torch.zeros((nz_size // 2, 4))
|
||||
second = torch.zeros(((nz_size // 4) * 2, 4))
|
||||
return torch.stack([first, second], dim=0)
|
||||
|
||||
inputs = (torch.ones((32,)),)
|
||||
|
||||
ep = export(M(), inputs)
|
||||
orig_res = M()(*inputs)
|
||||
ep_res = ep.module()(*inputs)
|
||||
self.assertTrue(torch.allclose(orig_res, ep_res))
|
||||
|
||||
def test_unbacked_slice_simple(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
|
||||
|
@ -4031,11 +4031,13 @@ def rot90(
|
||||
|
||||
|
||||
def _check_stack_inputs(tensors: TensorSequenceType) -> None:
|
||||
from torch.fx.experimental.symbolic_shapes import sym_eq
|
||||
|
||||
entry_shape = tensors[0].shape
|
||||
for i in range(1, len(tensors)):
|
||||
assert tensors[i].shape == entry_shape, (
|
||||
f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
|
||||
f"and {tensors[i].shape} at entry {i}"
|
||||
torch._check(
|
||||
sym_eq(tensors[i].shape, entry_shape),
|
||||
lambda: f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 ",
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user