diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp
index 7d44463fa830..162468ecc711 100644
--- a/aten/src/ATen/native/TensorTransformations.cpp
+++ b/aten/src/ATen/native/TensorTransformations.cpp
@@ -34,6 +34,11 @@ namespace at::native {
 
 Tensor flip(const Tensor& self, IntArrayRef dims) {
+  TORCH_CHECK(
+      self.scalar_type() != at::kQUInt4x2 &&
+          self.scalar_type() != at::kQUInt2x4,
+      "flip is not supported for tensor with data type ",
+      self.scalar_type());
   const int64_t total_dims = self.dim();
   // It wraps the dims and checks that there are no repeated dims
   auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
 
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index f89bb81745f8..a60b04c4f8b5 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -581,6 +581,16 @@ class TestShapeOps(TestCase):
         self.compare_with_numpy(torch_fn, np_fn, t_in)
         del t_in
 
+    @onlyCPU
+    @unittest.expectedFailure
+    @dtypes(torch.quint4x2, torch.quint2x4)
+    def test_flip_unsupported_dtype(self, dtype):
+        scale, zero_point = 0.1, 5
+        qt = torch.quantize_per_tensor(
+            torch.randn(16, 16), scale=scale, zero_point=zero_point, dtype=dtype
+        )
+        torch.flip(qt, dims=(0,))
+
     def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
         for dim in range(min_dim, max_dim + 1):
             shape = self._rand_shape(dim, 5, 10)
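For reference, a minimal sketch of how the new `TORCH_CHECK` is expected to surface in Python, mirroring the added `test_flip_unsupported_dtype` test. The tensor shape, scale, and zero point are taken from that test; the exact error text is an assumption based on the message added above.

```python
import torch

# Quantize to the packed 4-bit dtype, as in the new test (quint2x4 behaves the same way).
qt = torch.quantize_per_tensor(
    torch.randn(16, 16), scale=0.1, zero_point=5, dtype=torch.quint4x2
)

try:
    torch.flip(qt, dims=(0,))
except RuntimeError as e:
    # With the patch applied, this should print something like:
    #   flip is not supported for tensor with data type QUInt4x2
    print(e)
```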