[Quant] flip: throw runtime error for QUInt4x2 and QUInt2x4 input (#147430)

Fixes #147208

**Summary**
The `flip` op causes memory corruption for `torch.quint4x2` and `torch.quint2x4` inputs because the TensorIterator-based implementation does not support dtypes that pack multiple elements per byte. Since `torch.quint4x2` and `torch.quint2x4` are deprecated in PyTorch, we add a check that throws a runtime error if the input dtype is `torch.quint4x2` or `torch.quint2x4`.
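
For illustration, a minimal sketch of the new behavior (the shape, scale, and zero point below are arbitrary):

```
import torch

# A sub-byte quantized tensor; previously torch.flip could corrupt memory here.
qt = torch.quantize_per_tensor(
    torch.randn(16, 16), scale=0.1, zero_point=5, dtype=torch.quint4x2
)
# With this change, the call raises a RuntimeError along the lines of
# "flip is not supported for tensor with data type QUInt4x2".
torch.flip(qt, dims=(0,))
```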

**Test plan**
```
pytest -s test/test_shape_ops.py -k test_flip_unsupported_dtype
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147430
Approved by: https://github.com/mingfeima, https://github.com/ngimel
Author: Xia, Weiwen
Committed by: PyTorch MergeBot
Date: 2025-02-25 03:47:39 +00:00
Commit: 9478c90e2b
Parent: 20295c017e
2 changed files with 15 additions and 0 deletions


```
@@ -34,6 +34,11 @@
 namespace at::native {
 
 Tensor flip(const Tensor& self, IntArrayRef dims) {
+  TORCH_CHECK(
+      self.scalar_type() != at::kQUInt4x2 &&
+          self.scalar_type() != at::kQUInt2x4,
+      "flip is not supported for tensor with data type ",
+      self.scalar_type());
   const int64_t total_dims = self.dim();
   // It wraps the dims and checks that there are no repeated dims
   auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
```


```
@@ -581,6 +581,16 @@ class TestShapeOps(TestCase):
         self.compare_with_numpy(torch_fn, np_fn, t_in)
         del t_in
 
+    @onlyCPU
+    @unittest.expectedFailure
+    @dtypes(torch.quint4x2, torch.quint2x4)
+    def test_flip_unsupported_dtype(self, dtype):
+        scale, zero_point = 0.1, 5
+        qt = torch.quantize_per_tensor(
+            torch.randn(16, 16), scale=scale, zero_point=zero_point, dtype=dtype
+        )
+        torch.flip(qt, dims=(0,))
+
     def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
         for dim in range(min_dim, max_dim + 1):
             shape = self._rand_shape(dim, 5, 10)
```