[ao] fix incorrect integer cast on histogram observer bounds (#90355)

Summary: A cast to int was added in
https://github.com/pytorch/pytorch/pull/45630 to silence a mypy
complaint. However, truncating the min/max bounds to integers means the
histogram doesn't actually capture the full range of activation values:
torch.histc silently ignores elements outside the truncated bounds.
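A minimal sketch of the truncation (illustrative only, not code from
this PR; the sizes and bin count are arbitrary):

    import torch

    x = torch.randn(1000) * 3             # activations roughly in (-9, 9)
    min_val, max_val = torch.aminmax(x)

    # Old behavior: int() truncates toward zero, shrinking the bounds,
    # and torch.histc silently drops elements outside [min, max].
    h_old = torch.histc(x, 2048, min=int(min_val), max=int(max_val))
    # Fixed behavior: pass the float bounds through unchanged.
    h_new = torch.histc(x, 2048, min=min_val.item(), max=max_val.item())

    print(h_old.sum().item())  # typically < 1000: tail elements dropped
    print(h_new.sum().item())  # == 1000: every element is counted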

note1: the test_histogram_observer_against_reference test was silently
broken on master. The random parameters that normally get drawn don't
happen to trigger a failure, but running the test repeatedly in a loop
eventually makes it fail. The cause is that in some cases
sum(<tensor>) != torch.sum(<tensor>).item(): the builtin sum and
torch.sum can accumulate in different orders, and since float addition
is not associative the totals can disagree in the last few bits. I was
not able to reproduce this with a toy example, but running this test in
a loop and editing either observer to print the calculation for 'total'
would break the test and show the two values diverging. Fixing this
test was necessary to land this PR, since the changed histogram bounds
perturb the values enough that the test would otherwise error.
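As the note says, a toy repro is elusive, so the sketch below (sizes and
scale made up) only illustrates the mechanism; the printed difference
may well be 0.0 for any given draw:

    import torch

    torch.manual_seed(0)
    h = torch.rand(4096) * 1e6   # float32, large same-sign values

    # Builtin sum() adds elements one at a time, left to right, while
    # torch.sum() may reduce in a different (vectorized/pairwise) order.
    # Float addition is not associative, so the totals can differ.
    seq_total = sum(h)           # 0-d tensor, sequential accumulation
    lib_total = torch.sum(h)
    print((seq_total - lib_total).abs().item())  # can be nonzero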

note2: updating the histogram observer breaks some backward-compatibility
(BC) tests unless the saved model is regenerated using the
HistogramObserver from this PR.

Test Plan: python test/test_quantization.py TestHistogramObserver.test_histogram_observer_correct_numel

python test/test_quantization.py -k histogram

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90355
Approved by: https://github.com/vkuzo
Author: HDCharles
Date: 2022-12-12 08:54:54 -08:00
Committed by: PyTorch MergeBot
Parent: 60e196c241
Commit: e11650887e
6 changed files with 18 additions and 3 deletions


@@ -503,7 +503,7 @@ class _ReferenceHistogramObserver(HistogramObserver):
         bin_width = (self.max_val - self.min_val) / self.bins
         # cumulative sum
-        total = sum(self.histogram)
+        total = torch.sum(self.histogram).item()
         cSum = torch.cumsum(self.histogram, dim=0)
         stepsize = 1e-5  # granularity
@@ -706,10 +706,19 @@ class TestHistogramObserver(QuantizationTestCase):
             X = torch.randn(N)
             my_obs(X)
             ref_obs(X)
+            self.assertEqual(my_obs.histogram, ref_obs.histogram)
+            self.assertEqual(my_obs.min_val, ref_obs.min_val)
+            self.assertEqual(my_obs.max_val, ref_obs.max_val)
             ref_qparams = ref_obs.calculate_qparams()
             my_qparams = my_obs.calculate_qparams()
+            for i in range(0, bins, 200):
+                for j in range(i + 5, bins, 200):
+                    ref_qe = ref_obs._compute_quantization_error(i, j)
+                    qe = my_obs._compute_quantization_error(i, j)
+                    self.assertEqual(ref_qe, qe)
             self.assertEqual(ref_qparams, my_qparams)

     def test_histogram_observer_extreme_inputs(self):
@@ -726,6 +735,12 @@ class TestHistogramObserver(QuantizationTestCase):
         obs(test_input)
         obs(test_input)

+    def test_histogram_observer_correct_numel(self):
+        for i in range(1, 10):
+            obs = HistogramObserver()
+            obs(torch.randn(i, i))
+            self.assertEqual(obs.histogram.sum().item(), i**2)

 class TestFakeQuantize(TestCase):
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),


@@ -1154,7 +1154,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
                 min_val.numel() == 1 and max_val.numel() == 1
             ), "histogram min/max values must be scalar."
             torch.histc(
-                x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
+                x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
             )
         else:
             new_min, new_max = torch.aminmax(x)
@@ -1173,7 +1173,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
                 combined_min.numel() == 1 and combined_max.numel() == 1
             ), "histogram min/max values must be scalar."
             combined_histogram = torch.histc(
-                x, self.bins, min=int(combined_min), max=int(combined_max)
+                x, self.bins, min=combined_min, max=combined_max  # type: ignore[arg-type]
             )
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram