Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 14:15:07 +08:00)
[Inductor] Fix out-of-bounds indices in repeat_interleave decomposition
[ghstack-poisoned]
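This patch hardens the Inductor decomposition of aten.repeat_interleave.Tensor, which lowers the op to a cumsum over the repeats followed by a searchsorted over the output positions. When the caller passes an output_size larger than repeat.sum() (the situation exercised by the new test below), searchsorted with right=True returns repeat.size(0) for the trailing positions, one past the last valid row, so a gather indexed by the result can read out of bounds. The change clamps the indices to repeat.size(0) - 1. A minimal eager-mode sketch of the failure and of the clamp, using small illustrative tensors rather than anything from the patch:

import torch

# Illustrative setup (not from the patch): output_size exceeds repeat.sum(),
# so some positions fall past the last cumsum entry.
repeat = torch.ones(4, dtype=torch.int64)   # 4 rows, one copy each
output_size = 6                             # > repeat.sum() == 4

cumsum = repeat.cumsum(0)                   # tensor([1, 2, 3, 4])
pos = torch.arange(output_size)             # tensor([0, 1, 2, 3, 4, 5])
raw = torch.searchsorted(cumsum, pos, right=True)
print(raw)                                  # tensor([0, 1, 2, 3, 4, 4]); 4 is out of range for 4 rows

# The clamp added by this patch keeps the subsequent gather in bounds.
safe = torch.clamp(raw, max=repeat.size(0) - 1)
data = torch.arange(4)
print(data[safe])                           # tensor([0, 1, 2, 3, 3, 3])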
@@ -12279,7 +12279,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
 
     @requires_gpu()
     @skip_if_not_triton
-    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @skip_if_cpp_wrapper("run_and_get_kernels issue")
    @config.patch(implicit_fallbacks=True)
    def test_generated_code_has_size_stride_assert(self):
        def foo(x):
@@ -14245,6 +14245,33 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue("'enable_fp_fusion': False" in code)
         torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
 
+    @skip_if_cpp_wrapper("skip cpp wrapper")
+    @requires_cuda_and_triton
+    def test_repeat_interleave_decomposition_has_clamp(self):
+        repeat = torch.ones(2560, dtype=torch.int64, device=GPU_TYPE)
+        output_size = 505450
+        data = torch.arange(2560, device=GPU_TYPE)
+
+        @torch.compile
+        def fn(repeat, output_size, data):
+            indices = torch.ops.aten.repeat_interleave.Tensor(
+                repeat, output_size=output_size
+            )
+            return data[indices]
+
+        result, code = run_and_get_code(fn, repeat, output_size, data)
+
+        self.assertEqual(result.shape[0], output_size)
+        self.assertTrue(torch.all(result >= 0).item())
+        self.assertTrue(torch.all(result < 2560).item())
+
+        code_str = "\n".join(code)
+        self.assertIn(
+            "triton_helpers.minimum",
+            code_str,
+            "Generated Triton code should use triton_helpers.minimum for clamping",
+        )
+
     # end of class CommonTemplate - add new tests here
@@ -1188,9 +1188,10 @@ def repeat_interleave_Tensor(
     assert repeat.ndim == 1
     cumsum = repeat.cumsum(0)
     pos = torch.arange(output_size, device=repeat.device)
-    return torch.searchsorted(
+    indices = torch.searchsorted(
         cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
     )
+    return torch.clamp(indices, max=repeat.size(0) - 1)
 
 
 # intentionally not regiestered
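As a sanity check on the patched decomposition, the same cumsum / searchsorted / clamp sequence can be run standalone in eager mode; the helper name below is ours, and only the clamp on the searchsorted result comes from this change:

import torch

def repeat_interleave_indices(repeat: torch.Tensor, output_size: int) -> torch.Tensor:
    # Mirrors the decomposition above: cumsum + searchsorted, with the result
    # clamped so no index can reach repeat.size(0).
    assert repeat.ndim == 1
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    indices = torch.searchsorted(
        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
    )
    return torch.clamp(indices, max=repeat.size(0) - 1)

repeat = torch.tensor([2, 0, 3, 1])

# When output_size equals repeat.sum(), the result matches eager repeat_interleave.
assert torch.equal(
    repeat_interleave_indices(repeat, int(repeat.sum())),
    torch.repeat_interleave(repeat),
)

# When output_size overshoots, trailing indices are clamped to the last row (3)
# instead of running past the end of whatever tensor is gathered with them.
print(repeat_interleave_indices(repeat, 8))  # tensor([0, 0, 2, 2, 2, 3, 3, 3])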