Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add uint16 support for observer (#136238)
Summary: as titled.

Test Plan: python test/test_quantization.py -k TestObserver

Differential Revision: [D62909821](https://our.internmc.facebook.com/intern/diff/D62909821)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136238

Approved by: https://github.com/tarun292
Committed by: PyTorch MergeBot
Parent: 068c80e6b6
Commit: f2b0fc89f2
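In practice, the change lets the existing observer classes accept torch.uint16 end to end. A minimal usage sketch, assuming a PyTorch build that already includes this patch (tensor values and settings are illustrative):

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Observe a float tensor and derive affine qparams for a uint16 target.
obs = MinMaxObserver(dtype=torch.uint16, qscheme=torch.per_tensor_affine)
obs(torch.randn(4, 8))                       # forward pass records running min/max
scale, zero_point = obs.calculate_qparams()  # quant range is [0, 2**16 - 1]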
@@ -67,14 +67,30 @@ from torch.testing._internal.common_quantization import (
 NP_RANDOM_SEED = 19
 tolerance = 1e-6
+
+# copy and modified from torch/ao/quantization/observer.py
+_INT_DTYPES = (
+    torch.qint8,
+    torch.quint8,
+    torch.quint4x2,
+    torch.qint32,
+    torch.int8,
+    torch.uint8,
+    torch.int16,
+    torch.int32,
+    torch.uint16,
+)
+
 class TestObserver(QuantizationTestCase):
-    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)),
+    @given(qdtype=st.sampled_from(_INT_DTYPES),
            qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
            reduce_range=st.booleans())
     def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
         # reduce_range cannot be true for symmetric quantization with uint8
         if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32:
             reduce_range = False
+        if qdtype == torch.quint4x2:
+            return
 
         ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                         MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                     dtype=qdtype,
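The @given decorator now draws the dtype from _INT_DTYPES, so every assertion in the test runs once per dtype, uint16 included. A stripped-down sketch of the same Hypothesis pattern (the body is an illustrative smoke check, not the real test):

import torch
from hypothesis import given, strategies as st
from torch.ao.quantization.observer import MinMaxObserver

_INT_DTYPES = (torch.qint8, torch.quint8, torch.qint32, torch.int8,
               torch.uint8, torch.int16, torch.int32, torch.uint16)  # quint4x2 skipped, as in the test

@given(qdtype=st.sampled_from(_INT_DTYPES),
       qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)))
def check_observer_runs(qdtype, qscheme):
    # Illustrative only; the real test compares against _get_ref_params.
    obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme)
    obs(torch.randn(16))
    scale, zero_point = obs.calculate_qparams()
    assert torch.isfinite(scale).all()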
@@ -82,18 +98,23 @@ class TestObserver(QuantizationTestCase):
                                                     reduce_range=reduce_range)]
 
         def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val):
+            assert dtype in _INT_DTYPES, "Not supported dtype: {dtype}, supported dtypes are {_INT_DTYPES}"
             eps = torch.tensor([tolerance])
-            if dtype == torch.qint8:
+            if dtype in [torch.qint8, torch.int8]:
                 if reduce_range:
                     quant_min, quant_max = -64, 63
                 else:
                     quant_min, quant_max = -128, 127
-            elif dtype == torch.quint8:
+            elif dtype in [torch.quint8, torch.uint8]:
                 if reduce_range:
                     quant_min, quant_max = 0, 127
                 else:
                     quant_min, quant_max = 0, 255
-            elif dtype == torch.qint32:
+            elif dtype == torch.int16:
+                quant_min, quant_max = -1 * (2 ** 15), (2 ** 15) - 1
+            elif dtype == torch.uint16:
+                quant_min, quant_max = 0, (2 ** 16) - 1
+            elif dtype in [torch.qint32, torch.int32]:
                 quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
 
             min_val_neg = torch.tensor([0.])
@@ -103,12 +124,15 @@ class TestObserver(QuantizationTestCase):
             if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
                 scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2)
                 scale = torch.max(scale, eps)
-                if dtype == torch.quint8:
+                if dtype in [torch.quint8, torch.uint8]:
                     zero_point = 128
+                if dtype in [torch.uint16]:
+                    zero_point = 2 ** 15
             else:
                 scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps)
                 zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+                zero_point = torch.clamp(zero_point, quant_min, quant_max)
 
             return scale, zero_point
 
         for myobs in ObserverList:
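For torch.uint16 the affine branch of _get_ref_params boils down to scale = (max - min) / 65535 with the zero point clamped into [0, 65535]. A standalone rehearsal of that arithmetic (the min/max values are made up):

import torch

quant_min, quant_max = 0, 2 ** 16 - 1             # uint16 range from the reference above
min_val, max_val = torch.tensor([-1.0]), torch.tensor([3.0])
min_val_neg = torch.min(min_val, torch.zeros(1))  # the quantized range must include zero
max_val_pos = torch.max(max_val, torch.zeros(1))
eps = torch.tensor([1e-6])

scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
print(scale.item(), zero_point.item())            # ~6.1e-05 and 16384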
@@ -253,6 +253,7 @@ class UniformQuantizationObserverBase(ObserverBase):
             torch.int32,
             torch.float8_e5m2,
             torch.float8_e4m3fn,
+            torch.uint16,
         )
 
         assert (
@@ -368,6 +369,8 @@ class UniformQuantizationObserverBase(ObserverBase):
                     )
                 else:
                     zero_point = zero_point.new_full(zero_point.size(), 128)
+            elif self.dtype in [torch.uint16]:
+                zero_point = zero_point.new_full(zero_point.size(), 2**15)
             elif self.qscheme == torch.per_channel_affine_float_qparams:
                 scale = (max_val - min_val) / float(quant_max - quant_min)
                 scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
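With the observer change above, a symmetric uint16 observer pins its zero point at the midpoint 2**15 = 32768, while the scale covers the larger of |min| and |max|. A quick check, again assuming a patched build:

import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.uint16, qscheme=torch.per_tensor_symmetric)
obs(torch.tensor([-2.0, 0.5, 3.0]))
scale, zero_point = obs.calculate_qparams()
print(int(zero_point))   # 32768, i.e. 2**15
print(float(scale))      # about 3.0 / (65535 / 2), roughly 9.2e-05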
@@ -473,6 +473,10 @@ def calculate_qmin_qmax(
                 quant_min, quant_max = 0, 255
         elif dtype in [torch.qint32, torch.int32]:
             quant_min, quant_max = -1 * (2**31), (2**31) - 1
+        elif dtype in [torch.uint16]:
+            quant_min, quant_max = 0, 2**16 - 1
+        elif dtype in [torch.int16]:
+            quant_min, quant_max = -(2**15), 2**15 - 1
         else:
             quant_min, quant_max = 0, 15
     return quant_min, quant_max
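The full dtype-to-default-range mapping after this patch, written out as a hypothetical standalone helper for reference (the real logic lives in calculate_qmin_qmax, which additionally honours reduce_range and customized quantization ranges):

import torch

def default_qrange(dtype: torch.dtype) -> tuple:
    # Hypothetical mirror of the dtype branches shown in the diff above.
    if dtype in (torch.qint8, torch.int8):
        return -128, 127
    if dtype in (torch.quint8, torch.uint8):
        return 0, 255
    if dtype in (torch.qint32, torch.int32):
        return -(2 ** 31), 2 ** 31 - 1
    if dtype == torch.uint16:
        return 0, 2 ** 16 - 1
    if dtype == torch.int16:
        return -(2 ** 15), 2 ** 15 - 1
    return 0, 15  # e.g. torch.quint4x2

assert default_qrange(torch.uint16) == (0, 65535)
assert default_qrange(torch.int16) == (-32768, 32767)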