[Inductor] Fix out-of-bounds indices in repeat_interleave decomposition (#165368)

When `repeat_interleave` is decomposed into:
```python
cumsum = repeat.cumsum(0)
pos = torch.arange(output_size, device=repeat.device)
indices = torch.searchsorted(cumsum, pos, right=True)
```
The `searchsorted` call with `right=True` returns the insertion point after any matching elements. Whenever a query value in `pos` is `>= cumsum[-1]`, `searchsorted` returns `len(cumsum)`, which is out of bounds for indexing (the valid range is `[0, len(cumsum) - 1]`). These invalid indices trigger CUDA device-side assert errors in downstream indexing operations.

This fix adds clamping to ensure all indices stay within the valid range `[0, repeat.size(0) - 1]`.
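
As a worked illustration (not part of the PR), here is a minimal standalone sketch of the decomposition with the clamp applied; the function name and example values are made up. Passing an `output_size` larger than `repeat.sum()` is exactly the case where the unclamped indices would run past the end of `repeat`:

```python
import torch

# Hypothetical standalone version of the decomposition, for illustration only.
def repeat_interleave_decomp(repeat: torch.Tensor, output_size: int) -> torch.Tensor:
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    # right=True: each pos maps to the number of cumsum entries <= pos.
    indices = torch.searchsorted(cumsum, pos, right=True)
    # Without this clamp, any pos >= cumsum[-1] yields len(cumsum),
    # an out-of-bounds index for anything of length repeat.size(0).
    return torch.clamp(indices, max=repeat.size(0) - 1)

repeat = torch.tensor([2, 1, 3])                           # sums to 6
indices = repeat_interleave_decomp(repeat, output_size=8)  # 8 > 6
print(indices)  # tensor([0, 0, 1, 2, 2, 2, 2, 2]) -- last two positions clamped to 2
```

With the clamp, out-of-range positions simply repeat the last index instead of faulting; the test below checks that the clamp survives into the generated Triton code as `triton_helpers.minimum`.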

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165368
Approved by: https://github.com/mlazos
Author: karthickai
Date: 2025-10-14 11:43:58 -07:00
Committed by: PyTorch MergeBot
Parent: 102b7885ff
Commit: a63ab0b8cd

2 changed files with 34 additions and 1 deletion


@@ -14268,6 +14268,38 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue("'enable_fp_fusion': False" in code)
         torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
 
+    @skip_if_cpp_wrapper("skip cpp wrapper")
+    @requires_cuda_and_triton
+    def test_repeat_interleave_decomposition_has_clamp(self):
+        repeat = torch.ones(2560, dtype=torch.int64, device=GPU_TYPE)
+        output_size = 505450
+        data = torch.arange(2560, device=GPU_TYPE)
+
+        if is_dynamic_shape_enabled():
+            raise unittest.SkipTest(
+                "repeat_interleave decomp doesn't support dynamic output size"
+            )
+
+        @torch.compile
+        def fn(repeat, output_size, data):
+            indices = torch.ops.aten.repeat_interleave.Tensor(
+                repeat, output_size=output_size
+            )
+            return data[indices]
+
+        result, code = run_and_get_code(fn, repeat, output_size, data)
+        self.assertEqual(result.shape[0], output_size)
+        self.assertTrue(torch.all(result >= 0).item())
+        self.assertTrue(torch.all(result < 2560).item())
+
+        code_str = "\n".join(code)
+        self.assertIn(
+            "triton_helpers.minimum",
+            code_str,
+            "Generated Triton code should use triton_helpers.minimum for clamping",
+        )
 
 
     # end of class CommonTemplate - add new tests here


@@ -1188,9 +1188,10 @@ def repeat_interleave_Tensor(
     assert repeat.ndim == 1
     cumsum = repeat.cumsum(0)
     pos = torch.arange(output_size, device=repeat.device)
-    return torch.searchsorted(
+    indices = torch.searchsorted(
         cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
     )
+    return torch.clamp(indices, max=repeat.size(0) - 1)
 
 
 # intentionally not registered