mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[cpu/sorting] Throw an error when trying to sort complex numbers. (#144113)
It doesn't really make sense to sort complex numbers as they are not comparable. Fixes #129296 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144113 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
78eded8e00
commit
aaf56152ea
@ -74,6 +74,12 @@ TORCH_META_FUNC2(sort, stable)
|
||||
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
|
||||
maybe_wrap_dim(dim, self.dim());
|
||||
|
||||
const auto self_dtype = self.dtype();
|
||||
TORCH_CHECK_VALUE(
|
||||
self_dtype != ScalarType::ComplexFloat &&
|
||||
self_dtype != ScalarType::ComplexDouble,
|
||||
"Sort currently does not support complex dtypes on CPU.");
|
||||
|
||||
// See issue: https://github.com/pytorch/pytorch/issues/65863
|
||||
// Strides should be dense, so as not to allocate too much memory.
|
||||
// We either use 'self' strides, or infer dense strides from them.
|
||||
|
@ -175,6 +175,14 @@ class TestSortAndSelect(TestCase):
|
||||
y = x.sort(stable=None).values
|
||||
self.assertTrue(torch.all(y == torch.ones(10)).item())
|
||||
|
||||
@onlyCPU
|
||||
def test_complex_unsupported_cpu(self):
|
||||
x = torch.tensor([3.0 + 2j, 4.0 + 3j])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Sort currently does not support complex dtypes on CPU."
|
||||
):
|
||||
torch.sort(input=x)
|
||||
|
||||
@onlyCUDA
|
||||
def test_sort_large_slice(self, device):
|
||||
# tests direct cub path
|
||||
|
Reference in New Issue
Block a user