Add dtype checks in meta dispatch for various ordering ops (#159556)

This adds data type checks for the unsupported bool and complex types for argmax/min topk, sort, minimum, maximum. As listed here:

0a99b026d6/torch/testing/_internal/common_methods_invocations.py (L21076)

Currently the ops will fail on CPU or CUDA calculation, rather than at meta dispatch stage as with for example max: 0a99b026d6/aten/src/ATen/native/TensorCompare.cpp (L285) . This will catch it early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159556
Approved by: https://github.com/janeyx99
This commit is contained in:
Matthew Haddock
2025-08-14 17:06:23 +00:00
committed by PyTorch MergeBot
parent cd8d8c18f5
commit 077cb38974
4 changed files with 71 additions and 6 deletions

View File

@ -220,6 +220,8 @@ static void check_argmax_argmin(
const char* name,
const Tensor& self,
const std::optional<int64_t>& dim) {
TORCH_CHECK(!self.is_complex(), name, ": does not support complex input");
TORCH_CHECK(!(self.scalar_type() == kBool), name, ": does not support bool input");
if (dim.has_value()) {
auto dim_ = maybe_wrap_dim(dim.value(), self.dim());
native::zero_numel_check_dims(self, dim_, name);

View File

@ -59,6 +59,8 @@ TORCH_META_FUNC(topk)
"selected index k out of range");
int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");
TORCH_CHECK(!self.is_complex(), " topk does not support complex dtypes on CPU");
TORCH_CHECK(!(self.scalar_type() == kBool), "topk does not support bool dtypes on CPU");
// Build the output size, which is the dim being selected set to
// size k
@ -74,11 +76,7 @@ TORCH_META_FUNC2(sort, stable)
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
maybe_wrap_dim(dim, self.dim());
const auto self_dtype = self.dtype();
TORCH_CHECK_VALUE(
self_dtype != ScalarType::ComplexFloat &&
self_dtype != ScalarType::ComplexDouble,
"Sort currently does not support complex dtypes on CPU.");
TORCH_CHECK(!self.is_complex(), " Sort does not support complex dtypes on CPU");
// See issue: https://github.com/pytorch/pytorch/issues/65863
// Strides should be dense, so as not to allocate too much memory.

View File

@ -399,6 +399,38 @@ skip_noncontig = {
"as_strided_copy",
}
bool_unsupported_ordered_ops = {
"topk",
"argmin",
"ceil",
"argmax",
"floor",
}
bool_ordered_op_db = tuple(
filter(lambda op: op.name in bool_unsupported_ordered_ops, op_db)
)
complex_unsupported_ordered_ops = {
"sort",
"topk",
"lt",
"argmin",
"le",
"ge",
"amax",
"maximum",
"minimum",
"clamp",
"amin",
"gt",
"ceil",
"argmax",
"floor",
}
complex_ordered_op_db = tuple(
filter(lambda op: op.name in complex_unsupported_ordered_ops, op_db)
)
@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
@unMarkDynamoStrictTest
@ -2968,6 +3000,39 @@ class TestOperators(TestCase):
actual_fn(torch.ones_like(actual_o)),
)
@ops(bool_ordered_op_db, dtypes=[torch.bool])
def test_ordered_bool_raises(self, device, dtype, op):
# Generate sample inputs for the op
sample_inputs = op.sample_inputs(device, dtype)
for sample_input in sample_inputs:
# Check that the op raises NotImplementedError or appropriate failure
self.assertRaises(
RuntimeError,
op,
sample_input.input,
*sample_input.args,
**sample_input.kwargs,
)
@ops(
complex_ordered_op_db,
dtypes=[torch.complex32, torch.complex64, torch.complex128],
)
def test_ordered_complex_raises(self, device, dtype, op):
# Generate sample inputs for the op
sample_inputs = op.sample_inputs(device, dtype)
for sample_input in sample_inputs:
# Check that the op raises NotImplementedError or appropriate failure
self.assertRaises(
RuntimeError,
op,
sample_input.input,
*sample_input.args,
**sample_input.kwargs,
)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

View File

@ -179,7 +179,7 @@ class TestSortAndSelect(TestCase):
def test_complex_unsupported_cpu(self):
x = torch.tensor([3.0 + 2j, 4.0 + 3j])
with self.assertRaisesRegex(
ValueError, "Sort currently does not support complex dtypes on CPU."
RuntimeError, " Sort does not support complex dtypes on CPU"
):
torch.sort(input=x)