mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Back out "[quant][observer] Add histogram observer" (#26236)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26236 Original diff broke oss CI. Reverting. Original commit changeset: 0f047d3349cb ghstack-source-id: 90125990 Test Plan: testinprod Reviewed By: hx89 Differential Revision: D17385490 fbshipit-source-id: 4258502bbc0e3a6dd6852c8ce01ed05eee618b1a
This commit is contained in:
committed by
Facebook Github Bot
parent
3051e36e05
commit
9f6b6b8101
@ -1,5 +1,3 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -11,7 +9,7 @@ from torch.quantization import \
|
||||
QConfig_dynamic, default_weight_observer, dump_tensor,\
|
||||
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
|
||||
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
|
||||
default_dynamic_qconfig, QuantWrapper, TensorObserver, MinMaxObserver, HistogramObserver
|
||||
default_dynamic_qconfig, MinMaxObserver, TensorObserver, QuantWrapper
|
||||
|
||||
from common_utils import run_tests
|
||||
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
|
||||
@ -777,8 +775,8 @@ class ObserverTest(QuantizationTestCase):
|
||||
self.assertEqual(qparams[1].item(), ref_zero_point)
|
||||
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
|
||||
|
||||
@given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
|
||||
def test_observer_scriptable(self, obs):
|
||||
def test_observer_scriptable(self):
|
||||
obs = torch.quantization.default_observer()()
|
||||
scripted = torch.jit.script(obs)
|
||||
|
||||
x = torch.rand(3, 4)
|
||||
@ -829,35 +827,5 @@ class QuantizationDebugTest(QuantizationTestCase):
|
||||
loaded = torch.jit.load(buf)
|
||||
self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))
|
||||
|
||||
@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
|
||||
qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
|
||||
reduce_range=st.booleans())
|
||||
def test_histogram_observer(self, qdtype, qscheme, reduce_range):
|
||||
myobs = HistogramObserver(bins=10, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
|
||||
x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
|
||||
myobs(x)
|
||||
myobs(y)
|
||||
self.assertEqual(myobs.min_val, -1.5)
|
||||
self.assertEqual(myobs.max_val, 8.5)
|
||||
self.assertEqual(myobs.histogram, [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.])
|
||||
qparams = myobs.calculate_qparams()
|
||||
if reduce_range:
|
||||
if qscheme == torch.per_tensor_symmetric:
|
||||
ref_scale = 0.066666 * 255 / 127
|
||||
ref_zero_point = 0 if qdtype is torch.qint8 else 128
|
||||
else:
|
||||
ref_scale = 0.0333333 * 255 / 127
|
||||
ref_zero_point = -64 if qdtype is torch.qint8 else 0
|
||||
else:
|
||||
if qscheme == torch.per_tensor_symmetric:
|
||||
ref_scale = 0.066666
|
||||
ref_zero_point = 0 if qdtype is torch.qint8 else 128
|
||||
else:
|
||||
ref_scale = 0.0333333
|
||||
ref_zero_point = -128 if qdtype is torch.qint8 else 0
|
||||
self.assertEqual(qparams[1].item(), ref_zero_point)
|
||||
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user