Symintify repeat_interleave (#154660)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154660
Approved by: https://github.com/pianpwk
This commit is contained in:
angelayi
2025-06-02 15:19:39 +00:00
committed by PyTorch MergeBot
parent f6275bf0fe
commit e22be781b7
2 changed files with 20 additions and 1 deletions

View File

@ -74,7 +74,7 @@ Tensor repeat_interleave_symint(
}
Tensor repeats_ = repeats;
if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) {
if (repeats.dim() == 0 || (repeats.dim() == 1 && TORCH_GUARD_OR_FALSE(repeats.sym_size(0).sym_eq(1)))) {
repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())});
} else if (repeats.dim() == 1) {
TORCH_CHECK(

View File

@ -13373,6 +13373,25 @@ def forward(self, x):
self.assertTrue(torch.allclose(comp_mod(inp1), mod(inp1)))
self.assertTrue(torch.allclose(comp_mod(inp2), mod(inp2)))
def test_repeat_interleave(self):
class M(torch.nn.Module):
def forward(self, values, batch_sizes):
return torch.repeat_interleave(
torch.arange(
values.shape[0],
),
batch_sizes,
)
inp = (torch.randint(0, 10, (1, 3)), torch.randint(0, 10, (1,)))
torch.fx.experimental._config.backed_size_oblivious = True
ep = torch.export.export(
M(), inp, dynamic_shapes=({0: Dim("dim")}, {0: Dim("dim")})
)
self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp)))
inp = (torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (2,)))
self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp)))
def test_automatic_dynamic_shapes_simple_equality(self):
# The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism
# leads to replacement symbols being set for equalities, and inferred relationships being checked