Add uint16 support for observer (#136238)

Summary:
As titled: add torch.uint16 support to the quantization observers.
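
For context, a minimal usage sketch of what this enables (illustrative only, not part of the diff; assumes this commit is applied):

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.uint16)       # now passes the dtype check
    obs(torch.randn(4, 8))                         # record running min/max
    scale, zero_point = obs.calculate_qparams()    # qparams over [0, 2**16 - 1]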

Test Plan:
python test/test_quantization.py -k TestObserver

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D62909821](https://our.internmc.facebook.com/intern/diff/D62909821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136238
Approved by: https://github.com/tarun292
Author: Jerry Zhang
Date: 2024-09-18 14:12:45 -07:00
Committed by: PyTorch MergeBot
Parent: 068c80e6b6
Commit: f2b0fc89f2
3 changed files with 36 additions and 5 deletions


@@ -67,14 +67,30 @@ from torch.testing._internal.common_quantization import (
 NP_RANDOM_SEED = 19
 tolerance = 1e-6
 
+# copy and modified from torch/ao/quantization/observer.py
+_INT_DTYPES = (
+    torch.qint8,
+    torch.quint8,
+    torch.quint4x2,
+    torch.qint32,
+    torch.int8,
+    torch.uint8,
+    torch.int16,
+    torch.int32,
+    torch.uint16,
+)
+
 class TestObserver(QuantizationTestCase):
-    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)),
+    @given(qdtype=st.sampled_from(_INT_DTYPES),
            qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
            reduce_range=st.booleans())
     def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
         # reduce_range cannot be true for symmetric quantization with uint8
         if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32:
             reduce_range = False
+        if qdtype == torch.quint4x2:
+            return
         ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                         MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                     dtype=qdtype,
@@ -82,18 +98,23 @@ class TestObserver(QuantizationTestCase):
                                                     reduce_range=reduce_range)]
 
         def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val):
+            assert dtype in _INT_DTYPES, f"Not supported dtype: {dtype}, supported dtypes are {_INT_DTYPES}"
             eps = torch.tensor([tolerance])
-            if dtype == torch.qint8:
+            if dtype in [torch.qint8, torch.int8]:
                 if reduce_range:
                     quant_min, quant_max = -64, 63
                 else:
                     quant_min, quant_max = -128, 127
-            elif dtype == torch.quint8:
+            elif dtype in [torch.quint8, torch.uint8]:
                 if reduce_range:
                     quant_min, quant_max = 0, 127
                 else:
                     quant_min, quant_max = 0, 255
-            elif dtype == torch.qint32:
+            elif dtype == torch.int16:
+                quant_min, quant_max = -1 * (2 ** 15), (2 ** 15) - 1
+            elif dtype == torch.uint16:
+                quant_min, quant_max = 0, (2 ** 16) - 1
+            elif dtype in [torch.qint32, torch.int32]:
                 quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
 
             min_val_neg = torch.tensor([0.])
@@ -103,12 +124,15 @@ class TestObserver(QuantizationTestCase):
             if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
                 scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2)
                 scale = torch.max(scale, eps)
-                if dtype == torch.quint8:
+                if dtype in [torch.quint8, torch.uint8]:
                     zero_point = 128
+                if dtype in [torch.uint16]:
+                    zero_point = 2 ** 15
             else:
                 scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps)
                 zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
                 zero_point = torch.clamp(zero_point, quant_min, quant_max)
             return scale, zero_point
 
         for myobs in ObserverList:
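
A small worked instance of the reference math above (hand-computed, not part of the diff): for observed min = -1.0 and max = 3.0 with torch.uint16 under per_tensor_affine, quant_min, quant_max = 0, 65535, so:

    scale = (3.0 - (-1.0)) / 65535                 # ~6.1e-5
    zero_point = 0 - round(-1.0 / scale)           # 16384, already within [0, 65535]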


@@ -253,6 +253,7 @@ class UniformQuantizationObserverBase(ObserverBase):
             torch.int32,
             torch.float8_e5m2,
             torch.float8_e4m3fn,
+            torch.uint16,
         )
 
         assert (
@@ -368,6 +369,8 @@ class UniformQuantizationObserverBase(ObserverBase):
                     )
                 else:
                     zero_point = zero_point.new_full(zero_point.size(), 128)
+            elif self.dtype in [torch.uint16]:
+                zero_point = zero_point.new_full(zero_point.size(), 2**15)
         elif self.qscheme == torch.per_channel_affine_float_qparams:
             scale = (max_val - min_val) / float(quant_max - quant_min)
             scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
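
To see the new symmetric branch in action, a sketch (illustrative; assumes this commit): uint16 pins the zero point at mid-range, analogous to 128 for uint8:

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.uint16, qscheme=torch.per_tensor_symmetric)
    obs(torch.tensor([-2.0, 3.0]))
    _, zero_point = obs.calculate_qparams()
    # zero_point == 2**15 == 32768, the mid-range value set above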


@@ -473,6 +473,10 @@ def calculate_qmin_qmax(
             quant_min, quant_max = 0, 255
         elif dtype in [torch.qint32, torch.int32]:
             quant_min, quant_max = -1 * (2**31), (2**31) - 1
+        elif dtype in [torch.uint16]:
+            quant_min, quant_max = 0, 2**16 - 1
+        elif dtype in [torch.int16]:
+            quant_min, quant_max = -(2**15), 2**15 - 1
         else:
             quant_min, quant_max = 0, 15
     return quant_min, quant_max
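
For reference, the new branches can be exercised directly; a sketch assuming the helper's signature is (quant_min, quant_max, has_customized_qrange, dtype, reduce_range):

    import torch
    from torch.ao.quantization.utils import calculate_qmin_qmax

    calculate_qmin_qmax(0, 0, False, torch.uint16, False)   # (0, 65535)
    calculate_qmin_qmax(0, 0, False, torch.int16, False)    # (-32768, 32767)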