mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Fix] Inbound check of sorter indices in searchsorted (#95109)
Fixes https://github.com/pytorch/pytorch/issues/91606, but in C++14 style. Prior fix (https://github.com/pytorch/pytorch/pull/94863) was in C++17 which might violate some builds. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95109 Approved by: https://github.com/ngimel
This commit is contained in:
@ -134,6 +134,13 @@ inline void searchsorted_pre_check(
|
||||
|
||||
TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
|
||||
"dtype but got dtype ", sorter.scalar_type());
|
||||
|
||||
if (sorter.numel() > 0) {
|
||||
auto minmax = sorter.aminmax();
|
||||
int64_t vmin = std::get<0>(minmax).item().toLong();
|
||||
int64_t vmax = std::get<1>(minmax).item().toLong();
|
||||
TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
|
||||
|
@ -1563,6 +1563,14 @@ class TestReductions(TestCase):
|
||||
_, sorted_idx = torch.sort(sequence)
|
||||
torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))
|
||||
|
||||
# invalid sorter value, out of bound (>= innermost size)
|
||||
with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
|
||||
torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))
|
||||
|
||||
# invalid sorter value, out of bound (< 0)
|
||||
with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
|
||||
torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))
|
||||
|
||||
# scalar type bfloat16
|
||||
if self.device_type == 'cpu':
|
||||
def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):
|
||||
|
Reference in New Issue
Block a user