Add dtype checks in meta dispatch for various ordering ops (#159556)
This adds dtype checks for the unsupported bool and complex types to argmax/argmin, topk, sort, minimum, and maximum, as listed here: 0a99b026d6/torch/testing/_internal/common_methods_invocations.py (L21076)
Currently these ops fail during the CPU or CUDA computation rather than at the meta dispatch stage, as already happens for max: 0a99b026d6/aten/src/ATen/native/TensorCompare.cpp (L285)
Performing the checks in the meta functions catches the error early.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159556
Approved by: https://github.com/janeyx99
commit 077cb38974, parent cd8d8c18f5, committed by PyTorch MergeBot
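
For background (an illustration, not part of the diff): tensors on the "meta" device carry only shape and dtype, so calling an op on them runs just the meta (shape-inference) function, which is where this PR places the checks.

import torch

# Meta tensors have no data; this executes only argmax's meta function,
# the code path this PR adds dtype checks to.
x = torch.empty(8, 3, device="meta")
print(torch.argmax(x, dim=1).shape)  # torch.Size([8]) -- no kernel runs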
@@ -220,6 +220,8 @@ static void check_argmax_argmin(
     const char* name,
     const Tensor& self,
     const std::optional<int64_t>& dim) {
+  TORCH_CHECK(!self.is_complex(), name, ": does not support complex input");
+  TORCH_CHECK(!(self.scalar_type() == kBool), name, ": does not support bool input");
   if (dim.has_value()) {
     auto dim_ = maybe_wrap_dim(dim.value(), self.dim());
     native::zero_numel_check_dims(self, dim_, name);
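
With the checks in check_argmax_argmin, a bool or complex argmax/argmin should now fail already under meta dispatch. A minimal sketch (the exact error text is an assumption; TORCH_CHECK surfaces as RuntimeError in Python):

import torch

b = torch.zeros(4, dtype=torch.bool, device="meta")
try:
    torch.argmax(b)
except RuntimeError as e:
    # Expected to be along the lines of "argmax: does not support bool input"
    print(e)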
@@ -59,6 +59,8 @@ TORCH_META_FUNC(topk)
       "selected index k out of range");
   int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
   TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");
+  TORCH_CHECK(!self.is_complex(), " topk does not support complex dtypes on CPU");
+  TORCH_CHECK(!(self.scalar_type() == kBool), "topk does not support bool dtypes on CPU");

   // Build the output size, which is the dim being selected set to
   // size k
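
The same early failure now applies to topk; a sketch assuming the check fires before any device kernel is chosen:

import torch

c = torch.zeros(5, dtype=torch.complex64, device="meta")
try:
    torch.topk(c, k=2)
except RuntimeError as e:
    # Message should mention that topk does not support complex dtypes
    print(e)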
@@ -74,11 +76,7 @@ TORCH_META_FUNC2(sort, stable)
 (const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
   maybe_wrap_dim(dim, self.dim());

-  const auto self_dtype = self.dtype();
-  TORCH_CHECK_VALUE(
-      self_dtype != ScalarType::ComplexFloat &&
-          self_dtype != ScalarType::ComplexDouble,
-      "Sort currently does not support complex dtypes on CPU.");
+  TORCH_CHECK(!self.is_complex(), " Sort does not support complex dtypes on CPU");

   // See issue: https://github.com/pytorch/pytorch/issues/65863
   // Strides should be dense, so as not to allocate too much memory.
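
Note that swapping TORCH_CHECK_VALUE for TORCH_CHECK changes the Python-visible exception from ValueError to RuntimeError, which is why the test in the last hunk is updated. A minimal sketch of the new behavior:

import torch

x = torch.tensor([3.0 + 2j, 4.0 + 3j])
try:
    torch.sort(x)
except RuntimeError as e:  # previously ValueError via TORCH_CHECK_VALUE
    print(e)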
@@ -399,6 +399,38 @@ skip_noncontig = {
     "as_strided_copy",
 }

+bool_unsupported_ordered_ops = {
+    "topk",
+    "argmin",
+    "ceil",
+    "argmax",
+    "floor",
+}
+bool_ordered_op_db = tuple(
+    filter(lambda op: op.name in bool_unsupported_ordered_ops, op_db)
+)
+
+complex_unsupported_ordered_ops = {
+    "sort",
+    "topk",
+    "lt",
+    "argmin",
+    "le",
+    "ge",
+    "amax",
+    "maximum",
+    "minimum",
+    "clamp",
+    "amin",
+    "gt",
+    "ceil",
+    "argmax",
+    "floor",
+}
+complex_ordered_op_db = tuple(
+    filter(lambda op: op.name in complex_unsupported_ordered_ops, op_db)
+)
+

 @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
 @unMarkDynamoStrictTest
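
The *_ordered_op_db tuples above are plain filters over the shared OpInfo database. A standalone sketch of the same idea (the chosen names are illustrative, not from the PR):

from torch.testing._internal.common_methods_invocations import op_db

wanted = {"topk", "argmax", "argmin"}  # illustrative subset
picked = tuple(op for op in op_db if op.name in wanted)
print(sorted(o.name for o in picked))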
@@ -2968,6 +3000,39 @@ class TestOperators(TestCase):
             actual_fn(torch.ones_like(actual_o)),
         )

+    @ops(bool_ordered_op_db, dtypes=[torch.bool])
+    def test_ordered_bool_raises(self, device, dtype, op):
+        # Generate sample inputs for the op
+        sample_inputs = op.sample_inputs(device, dtype)
+
+        for sample_input in sample_inputs:
+            # Check that the op raises NotImplementedError or appropriate failure
+            self.assertRaises(
+                RuntimeError,
+                op,
+                sample_input.input,
+                *sample_input.args,
+                **sample_input.kwargs,
+            )
+
+    @ops(
+        complex_ordered_op_db,
+        dtypes=[torch.complex32, torch.complex64, torch.complex128],
+    )
+    def test_ordered_complex_raises(self, device, dtype, op):
+        # Generate sample inputs for the op
+        sample_inputs = op.sample_inputs(device, dtype)
+
+        for sample_input in sample_inputs:
+            # Check that the op raises NotImplementedError or appropriate failure
+            self.assertRaises(
+                RuntimeError,
+                op,
+                sample_input.input,
+                *sample_input.args,
+                **sample_input.kwargs,
+            )
+

 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
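
Each generated test iterates the op's OpInfo sample inputs and asserts the call raises RuntimeError. A rough standalone equivalent for a single op (assumes OpInfo's callable interface forwards to the underlying op):

import torch
from torch.testing._internal.common_methods_invocations import op_db

op = next(o for o in op_db if o.name == "topk")
for si in op.sample_inputs("cpu", torch.bool):
    try:
        op(si.input, *si.args, **si.kwargs)  # should hit the new TORCH_CHECK
    except RuntimeError as e:
        print("raised:", e)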
@@ -179,7 +179,7 @@ class TestSortAndSelect(TestCase):
     def test_complex_unsupported_cpu(self):
         x = torch.tensor([3.0 + 2j, 4.0 + 3j])
         with self.assertRaisesRegex(
-            ValueError, "Sort currently does not support complex dtypes on CPU."
+            RuntimeError, " Sort does not support complex dtypes on CPU"
         ):
             torch.sort(input=x)
