mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Fix out-of-bounds indices in repeat_interleave decomposition (#165368)
When `repeat_interleave` is decomposed into: ```bash cumsum = repeat.cumsum(0) pos = torch.arange(output_size, device=repeat.device) indices = torch.searchsorted(cumsum, pos, right=True) ``` `searchsorted` op with `right=True` returns the insertion point after matching elements. When query values `pos` are `>= cumsum[-1]`, searchsorted returns `len(cumsum)`, which is out of bounds for indexing (valid range: `[0, len(cumsum)-1]`). These invalid indices trigger CUDA device-side assert errors in downstream indexing operations. This fix adds clamping to ensure all indices stay within the valid range [0, repeat.size(0)-1]. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165368 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
102b7885ff
commit
a63ab0b8cd
@ -14268,6 +14268,38 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
self.assertTrue("'enable_fp_fusion': False" in code)
|
||||
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
|
||||
|
||||
@skip_if_cpp_wrapper("skip cpp wrapper")
|
||||
@requires_cuda_and_triton
|
||||
def test_repeat_interleave_decomposition_has_clamp(self):
|
||||
repeat = torch.ones(2560, dtype=torch.int64, device=GPU_TYPE)
|
||||
output_size = 505450
|
||||
data = torch.arange(2560, device=GPU_TYPE)
|
||||
|
||||
if is_dynamic_shape_enabled():
|
||||
raise unittest.SkipTest(
|
||||
"repeat_interleave decomp doesn't support dynamic output size"
|
||||
)
|
||||
|
||||
@torch.compile
|
||||
def fn(repeat, output_size, data):
|
||||
indices = torch.ops.aten.repeat_interleave.Tensor(
|
||||
repeat, output_size=output_size
|
||||
)
|
||||
return data[indices]
|
||||
|
||||
result, code = run_and_get_code(fn, repeat, output_size, data)
|
||||
|
||||
self.assertEqual(result.shape[0], output_size)
|
||||
self.assertTrue(torch.all(result >= 0).item())
|
||||
self.assertTrue(torch.all(result < 2560).item())
|
||||
|
||||
code_str = "\n".join(code)
|
||||
self.assertIn(
|
||||
"triton_helpers.minimum",
|
||||
code_str,
|
||||
"Generated Triton code should use triton_helpers.minimum for clamping",
|
||||
)
|
||||
|
||||
# end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
|
@ -1188,9 +1188,10 @@ def repeat_interleave_Tensor(
|
||||
assert repeat.ndim == 1
|
||||
cumsum = repeat.cumsum(0)
|
||||
pos = torch.arange(output_size, device=repeat.device)
|
||||
return torch.searchsorted(
|
||||
indices = torch.searchsorted(
|
||||
cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
|
||||
)
|
||||
return torch.clamp(indices, max=repeat.size(0) - 1)
|
||||
|
||||
|
||||
# intentionally not regiestered
|
||||
|
Reference in New Issue
Block a user