mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Symintify repeat_interleave (#154660)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154660 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
f6275bf0fe
commit
e22be781b7
@ -74,7 +74,7 @@ Tensor repeat_interleave_symint(
|
||||
}
|
||||
|
||||
Tensor repeats_ = repeats;
|
||||
if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) {
|
||||
if (repeats.dim() == 0 || (repeats.dim() == 1 && TORCH_GUARD_OR_FALSE(repeats.sym_size(0).sym_eq(1)))) {
|
||||
repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())});
|
||||
} else if (repeats.dim() == 1) {
|
||||
TORCH_CHECK(
|
||||
|
@ -13373,6 +13373,25 @@ def forward(self, x):
|
||||
self.assertTrue(torch.allclose(comp_mod(inp1), mod(inp1)))
|
||||
self.assertTrue(torch.allclose(comp_mod(inp2), mod(inp2)))
|
||||
|
||||
def test_repeat_interleave(self):
    """Export a module calling ``torch.repeat_interleave`` with a symbolic
    repeats tensor and check the exported program matches eager.

    Exercises the size-oblivious path: with ``backed_size_oblivious`` enabled,
    export must not specialize on the dim-0 size being 1, so the program
    traced from a size-1 input must also be correct for a size-2 input.
    """

    class M(torch.nn.Module):
        def forward(self, values, batch_sizes):
            # arange over values.shape[0] produces a tensor whose length is
            # symbolic under dynamic_shapes; batch_sizes supplies per-element
            # repeat counts of the same (symbolic) length.
            return torch.repeat_interleave(
                torch.arange(
                    values.shape[0],
                ),
                batch_sizes,
            )

    inp = (torch.randint(0, 10, (1, 3)), torch.randint(0, 10, (1,)))
    # Save and restore the global flag so this test does not leak
    # size-oblivious semantics into subsequent tests in the suite.
    prev = torch.fx.experimental._config.backed_size_oblivious
    torch.fx.experimental._config.backed_size_oblivious = True
    try:
        ep = torch.export.export(
            M(), inp, dynamic_shapes=({0: Dim("dim")}, {0: Dim("dim")})
        )
    finally:
        torch.fx.experimental._config.backed_size_oblivious = prev
    self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp)))
    # Re-run the exported program on a larger batch: the trace taken at
    # size 1 must generalize to size 2 (no specialization on ==1).
    inp = (torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (2,)))
    self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp)))
|
||||
|
||||
def test_automatic_dynamic_shapes_simple_equality(self):
|
||||
# The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism
|
||||
# leads to replacement symbols being set for equalities, and inferred relationships being checked
|
||||
|
Reference in New Issue
Block a user