[Fix] Inbound check of sorter indices in searchsorted (#95109)

Fixes https://github.com/pytorch/pytorch/issues/91606, but in C++14 style.

Prior fix (https://github.com/pytorch/pytorch/pull/94863) was in C++17 which might violate some builds.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95109
Approved by: https://github.com/ngimel
This commit is contained in:
ganler
2023-02-20 04:59:08 +00:00
committed by PyTorch MergeBot
parent 286d821e61
commit 3dcf8b6140
2 changed files with 15 additions and 0 deletions

View File

@ -134,6 +134,13 @@ inline void searchsorted_pre_check(
TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
"dtype but got dtype ", sorter.scalar_type());
if (sorter.numel() > 0) {
auto minmax = sorter.aminmax();
int64_t vmin = std::get<0>(minmax).item().toLong();
int64_t vmax = std::get<1>(minmax).item().toLong();
TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
}
}
TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),

View File

@ -1563,6 +1563,14 @@ class TestReductions(TestCase):
_, sorted_idx = torch.sort(sequence)
torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))
# invalid sorter value, out of bound (>= innermost size)
with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))
# invalid sorter value, out of bound (< 0)
with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))
# scalar type bfloat16
if self.device_type == 'cpu':
def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):