Back out "[quant][observer] Add histogram observer" (#26236)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26236

Original diff broke oss CI. Reverting.

Original commit changeset: 0f047d3349cb
ghstack-source-id: 90125990

Test Plan: testinprod

Reviewed By: hx89

Differential Revision: D17385490

fbshipit-source-id: 4258502bbc0e3a6dd6852c8ce01ed05eee618b1a
This commit is contained in:
Sebastian Messmer
2019-09-14 12:47:18 -07:00
committed by Facebook Github Bot
parent 3051e36e05
commit 9f6b6b8101
2 changed files with 15 additions and 125 deletions

View File

@ -1,5 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
import torch
import torch.nn as nn
@ -11,7 +9,7 @@ from torch.quantization import \
QConfig_dynamic, default_weight_observer, dump_tensor,\
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
default_dynamic_qconfig, QuantWrapper, TensorObserver, MinMaxObserver, HistogramObserver
default_dynamic_qconfig, MinMaxObserver, TensorObserver, QuantWrapper
from common_utils import run_tests
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@ -777,8 +775,8 @@ class ObserverTest(QuantizationTestCase):
self.assertEqual(qparams[1].item(), ref_zero_point)
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
@given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
def test_observer_scriptable(self, obs):
def test_observer_scriptable(self):
obs = torch.quantization.default_observer()()
scripted = torch.jit.script(obs)
x = torch.rand(3, 4)
@ -829,35 +827,5 @@ class QuantizationDebugTest(QuantizationTestCase):
loaded = torch.jit.load(buf)
self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))
@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
reduce_range=st.booleans())
def test_histogram_observer(self, qdtype, qscheme, reduce_range):
myobs = HistogramObserver(bins=10, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
myobs(x)
myobs(y)
self.assertEqual(myobs.min_val, -1.5)
self.assertEqual(myobs.max_val, 8.5)
self.assertEqual(myobs.histogram, [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.])
qparams = myobs.calculate_qparams()
if reduce_range:
if qscheme == torch.per_tensor_symmetric:
ref_scale = 0.066666 * 255 / 127
ref_zero_point = 0 if qdtype is torch.qint8 else 128
else:
ref_scale = 0.0333333 * 255 / 127
ref_zero_point = -64 if qdtype is torch.qint8 else 0
else:
if qscheme == torch.per_tensor_symmetric:
ref_scale = 0.066666
ref_zero_point = 0 if qdtype is torch.qint8 else 128
else:
ref_scale = 0.0333333
ref_zero_point = -128 if qdtype is torch.qint8 else 0
self.assertEqual(qparams[1].item(), ref_zero_point)
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
if __name__ == '__main__':
run_tests()