mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][core][improvement][feature] Enabled support for quantized fill of nhwc tensors
Summary: Previously, filling a quantized tensor only worked for nchw tensors. This PR enables support for nhwc tensors. Test cases were added for per tensor and per channel quantized NHWC tensors. Test Plan: ``` python test/test_quantization.py -k test_qtensor_fill_per_channel_nhwc python test/test_quantization.py -k test_qtensor_fill_per_tensor_nhwc ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79025 Approved by: https://github.com/jerryzh168
This commit is contained in:
@ -28,7 +28,7 @@ Tensor& fill_out(Tensor& self, const Scalar& value) {
|
||||
|
||||
Tensor& fill_out_quantized(Tensor& self, const Scalar& value) {
|
||||
at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
|
||||
out = out.to(self.device());
|
||||
out = out.to(self.device()).to(self.suggest_memory_format());
|
||||
// Trust the `copy_` to handle the quantization and the boundary chacks.
|
||||
self.copy_(out);
|
||||
return self;
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <ATen/native/quantized/Copy.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/quantized/AffineQuantizer.h>
|
||||
#include <ATen/native/quantized/Copy.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at {
|
||||
@ -18,8 +18,9 @@ Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src) {
|
||||
src.scalar_type() == at::kFloat,
|
||||
"Quantized copy only works with kFloat as source Tensor");
|
||||
TORCH_CHECK(
|
||||
self.is_contiguous() && src.is_contiguous(),
|
||||
"Quantized copy only works with contiguous Tensors");
|
||||
(self.is_contiguous() && src.is_contiguous()) ||
|
||||
(self.is_contiguous(at::MemoryFormat::ChannelsLast) && src.is_contiguous(at::MemoryFormat::ChannelsLast)),
|
||||
"Quantized copy only works with contiguous and NHWC Tensors");
|
||||
TORCH_CHECK(
|
||||
self.sizes().equals(src.sizes()),
|
||||
"Quantized copy only works with Tensors with the same shape");
|
||||
|
@ -346,7 +346,7 @@ class TestQuantizedTensor(TestCase):
|
||||
qt1[:] = t2[:]
|
||||
self.assertEqual(qt1[:], qt2[:])
|
||||
# non-contiguous case **this should raise an exception**
|
||||
with self.assertRaisesRegex(RuntimeError, "Quantized copy only works with contiguous Tensors"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Quantized copy only works with contiguous and NHWC Tensors"):
|
||||
qt1[:, 0] = t2[:, 0]
|
||||
|
||||
def test_qtensor_float_assignment(self):
|
||||
@ -953,6 +953,32 @@ class TestQuantizedTensor(TestCase):
|
||||
self.assertEqual(q_filled.q_scale(), scale)
|
||||
self.assertEqual(q_filled.q_zero_point(), zero_point)
|
||||
|
||||
# Adapted from test_qtensor_fill_per_tensor but for a NHWC tensor (requires 4D)
|
||||
def test_qtensor_fill_per_tensor_nhwc(self):
|
||||
dims = torch.randint(low=1, high=10, size=(4, )).tolist()
|
||||
scale = 0.5
|
||||
zero_point = 10
|
||||
|
||||
ones = torch.ones(dims).to(torch.float)
|
||||
|
||||
qtypes = [torch.qint8, torch.quint8, torch.qint32]
|
||||
vals2fill = [-1, 1, 2**32] # positive, negative, overflow
|
||||
memory_formats = [torch.contiguous_format, torch.channels_last]
|
||||
devices = get_supported_device_types()
|
||||
for qtype, val2fill, memory_format, device in itertools.product(qtypes, vals2fill, memory_formats, devices):
|
||||
q_filled = torch._empty_affine_quantized(
|
||||
dims, scale=scale, zero_point=zero_point, device=device,
|
||||
dtype=qtype, memory_format=memory_format)
|
||||
q_filled.fill_(val2fill)
|
||||
# reference tensor for comparing q_filled
|
||||
q_ref = torch.quantize_per_tensor(ones * val2fill, scale,
|
||||
zero_point, qtype)
|
||||
self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
|
||||
self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
|
||||
# Make sure the scale and zero_point don't change
|
||||
self.assertEqual(q_filled.q_scale(), scale)
|
||||
self.assertEqual(q_filled.q_zero_point(), zero_point)
|
||||
|
||||
# adapted from test_qtensor_fill_per_tensor
|
||||
def test_qtensor_fill_per_channel(self):
|
||||
dims = [4, 5]
|
||||
@ -1062,6 +1088,37 @@ class TestQuantizedTensor(TestCase):
|
||||
|
||||
self.assertEqual(qx_ref, qx)
|
||||
|
||||
# adapted from test_qtensor_fill_per_channel and test_qtensor_fill_per_tensor_nhwc
|
||||
def test_qtensor_fill_per_channel_nhwc(self):
|
||||
dims = torch.randint(low=1, high=10, size=(4, )).tolist()
|
||||
axis = 0
|
||||
# adding a constant to avoid too small of a scale
|
||||
scales = torch.rand(dims[axis], dtype=torch.float64) + 0.1
|
||||
zero_points = torch.randint(low=0, high=10, size=(dims[axis], ))
|
||||
|
||||
ones = torch.ones(dims).to(torch.float)
|
||||
|
||||
qtypes = [torch.qint8, torch.quint8, torch.qint32]
|
||||
vals2fill = [-1, 1, 2**32] # positive, negative, overflow
|
||||
memory_formats = [torch.contiguous_format, torch.channels_last]
|
||||
devices = get_supported_device_types()
|
||||
for qtype, val2fill, memory_format, device in itertools.product(qtypes, vals2fill, memory_formats, devices):
|
||||
scales = scales.to(device)
|
||||
zero_points = zero_points.to(device)
|
||||
ones = ones.to(device)
|
||||
q_filled = torch._empty_per_channel_affine_quantized(
|
||||
dims, scales=scales, zero_points=zero_points, device=device,
|
||||
axis=axis, dtype=qtype, memory_format=memory_format)
|
||||
q_filled.fill_(val2fill)
|
||||
# reference tensor for comparing q_filled
|
||||
q_ref = torch.quantize_per_channel(ones * val2fill, scales=scales,
|
||||
zero_points=zero_points, axis=axis, dtype=qtype)
|
||||
self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
|
||||
self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
|
||||
# Make sure the scale and zero_point don't change
|
||||
self.assertEqual(q_filled.q_per_channel_scales(), scales)
|
||||
self.assertEqual(q_filled.q_per_channel_zero_points(), zero_points)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "No gpu is available.")
|
||||
def test_qtensor_index_select_cuda(self):
|
||||
self._test_qtensor_index_select('cuda')
|
||||
|
Reference in New Issue
Block a user