diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 5f9d5c85750b..db046428bb68 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -220,6 +220,8 @@ static void check_argmax_argmin( const char* name, const Tensor& self, const std::optional<int64_t>& dim) { + TORCH_CHECK(!self.is_complex(), name, ": does not support complex input"); + TORCH_CHECK(self.scalar_type() != kBool, name, ": does not support bool input"); if (dim.has_value()) { auto dim_ = maybe_wrap_dim(dim.value(), self.dim()); native::zero_numel_check_dims(self, dim_, name); diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 1bdc806a3b4e..44215a26018f 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -59,6 +59,8 @@ TORCH_META_FUNC(topk) "selected index k out of range"); int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim); TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension"); + TORCH_CHECK(!self.is_complex(), "topk does not support complex dtypes on CPU"); + TORCH_CHECK(self.scalar_type() != kBool, "topk does not support bool dtypes on CPU"); // Build the output size, which is the dim being selected set to // size k @@ -74,11 +76,7 @@ TORCH_META_FUNC2(sort, stable) (const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) { maybe_wrap_dim(dim, self.dim()); - const auto self_dtype = self.dtype(); - TORCH_CHECK_VALUE( - self_dtype != ScalarType::ComplexFloat && - self_dtype != ScalarType::ComplexDouble, - "Sort currently does not support complex dtypes on CPU."); + TORCH_CHECK(!self.is_complex(), " Sort does not support complex dtypes on CPU"); // See issue: https://github.com/pytorch/pytorch/issues/65863 // Strides should be dense, so as not to allocate too much memory.
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index cef00f83eb72..78e64278cb1e 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -399,6 +399,38 @@ skip_noncontig = { "as_strided_copy", } +bool_unsupported_ordered_ops = { + "topk", + "argmin", + "ceil", + "argmax", + "floor", +} +bool_ordered_op_db = tuple( + filter(lambda op: op.name in bool_unsupported_ordered_ops, op_db) +) + +complex_unsupported_ordered_ops = { + "sort", + "topk", + "lt", + "argmin", + "le", + "ge", + "amax", + "maximum", + "minimum", + "clamp", + "amin", + "gt", + "ceil", + "argmax", + "floor", +} +complex_ordered_op_db = tuple( + filter(lambda op: op.name in complex_unsupported_ordered_ops, op_db) +) + @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant") @unMarkDynamoStrictTest @@ -2968,6 +3000,39 @@ class TestOperators(TestCase): actual_fn(torch.ones_like(actual_o)), ) + @ops(bool_ordered_op_db, dtypes=[torch.bool]) + def test_ordered_bool_raises(self, device, dtype, op): + # Generate sample inputs for the op + sample_inputs = op.sample_inputs(device, dtype) + + for sample_input in sample_inputs: + # The op must reject bool inputs with a RuntimeError (TORCH_CHECK failure) + self.assertRaises( + RuntimeError, + op, + sample_input.input, + *sample_input.args, + **sample_input.kwargs, + ) + + @ops( + complex_ordered_op_db, + dtypes=[torch.complex32, torch.complex64, torch.complex128], + ) + def test_ordered_complex_raises(self, device, dtype, op): + # Generate sample inputs for the op + sample_inputs = op.sample_inputs(device, dtype) + + for sample_input in sample_inputs: + # The op must reject complex inputs with a RuntimeError (TORCH_CHECK failure) + self.assertRaises( + RuntimeError, + op, + sample_input.input, + *sample_input.args, + **sample_input.kwargs, + ) + only_for = ("cpu", "cuda") instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) diff --git a/test/test_sort_and_select.py 
b/test/test_sort_and_select.py index 669f165529e7..5be175818646 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -179,7 +179,7 @@ class TestSortAndSelect(TestCase): def test_complex_unsupported_cpu(self): x = torch.tensor([3.0 + 2j, 4.0 + 3j]) with self.assertRaisesRegex( - ValueError, "Sort currently does not support complex dtypes on CPU." + RuntimeError, "Sort does not support complex dtypes on CPU" ): torch.sort(input=x)