[ao] fix incorrect integer cast on histogram observer bounds (#90355)
Summary: A cast to int was added in https://github.com/pytorch/pytorch/pull/45630 to silence mypy. However, it leads to unexpected behavior where the histogram does not actually capture the full range of activation values.

note1: the test_histogram_observer_against_reference test was silently broken on master. The random parameters it normally draws apparently don't cause a failure, but running the test repeatedly in a loop eventually makes it fail. The cause is that, in some cases, sum(<tensor>) != torch.sum(<tensor>).item(). I was not able to reproduce this with a toy example, but running the test in a loop and editing either observer to print the calculation of `total` would break the test and show the two expressions diverging. Fixing this test was necessary to land this PR, since changing the histogram bounds perturbed things enough that the test would error.

note2: updating HistogramObserver breaks some backward-compatibility tests unless the serialized models are regenerated with the HistogramObserver from this PR.

Test Plan:
python test/test_quantization.py TestHistogramObserver.test_histogram_observer_correct_numel
python test/test_quantization.py -k histogram

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90355
Approved by: https://github.com/vkuzo
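To see concretely why the int cast loses range, here is a minimal standalone sketch (not part of the commit; the tensor contents and bin count are arbitrary illustration values). int() truncates toward zero, so fractional bounds collapse, and torch.histc silently ignores elements outside [min, max]:

import torch

torch.manual_seed(0)
x = torch.randn(1000) * 0.8          # activations with sub-integer tails, roughly (-2.7, 2.7)
min_val, max_val = torch.aminmax(x)

# Old behavior: int() truncates toward zero, e.g. (-2.7, 2.7) -> (-2, 2),
# and torch.histc drops every element outside the truncated window.
h_cast = torch.histc(x, bins=256, min=int(min_val), max=int(max_val))

# Fixed behavior: keep the exact float bounds.
h_full = torch.histc(x, bins=256, min=min_val.item(), max=max_val.item())

print(h_cast.sum().item())  # typically < 1000: tail observations were dropped
print(h_full.sum().item())  # == 1000: the histogram covers the full range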
Commit e11650887e (parent 60e196c241), committed by PyTorch MergeBot.
@@ -503,7 +503,7 @@ class _ReferenceHistogramObserver(HistogramObserver):
         bin_width = (self.max_val - self.min_val) / self.bins
 
         # cumulative sum
-        total = sum(self.histogram)
+        total = torch.sum(self.histogram).item()
         cSum = torch.cumsum(self.histogram, dim=0)
 
         stepsize = 1e-5  # granularity
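On note1's sum(<tensor>) != torch.sum(<tensor>).item() observation, a plausible mechanism (an assumption here, since the author could not reproduce it in a toy example either) is accumulation order: Python's builtin sum adds elements one at a time in float32, while the torch.sum kernel uses blocked accumulation, and the two orders can round differently. A sketch of the comparison:

import torch

torch.manual_seed(0)
t = torch.randn(100_000, dtype=torch.float32) * 1000

seq = sum(t)        # builtin sum: sequential 0-dim tensor adds, float32 throughout
vec = torch.sum(t)  # torch kernel: blocked accumulation, different rounding order

# The results may differ in the low-order bits; .item() then exposes the
# discrepancy of the kind that broke test_histogram_observer_against_reference.
print(seq.item(), vec.item(), (seq - vec).abs().item())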
@@ -706,10 +706,19 @@ class TestHistogramObserver(QuantizationTestCase):
         X = torch.randn(N)
         my_obs(X)
         ref_obs(X)
+        self.assertEqual(my_obs.histogram, ref_obs.histogram)
+        self.assertEqual(my_obs.min_val, ref_obs.min_val)
+        self.assertEqual(my_obs.max_val, ref_obs.max_val)
 
         ref_qparams = ref_obs.calculate_qparams()
         my_qparams = my_obs.calculate_qparams()
 
+        for i in range(0, bins, 200):
+            for j in range(i + 5, bins, 200):
+                ref_qe = ref_obs._compute_quantization_error(i, j)
+                qe = my_obs._compute_quantization_error(i, j)
+                self.assertEqual(ref_qe, qe)
+
         self.assertEqual(ref_qparams, my_qparams)
 
     def test_histogram_observer_extreme_inputs(self):
@@ -726,6 +735,12 @@ class TestHistogramObserver(QuantizationTestCase):
         obs(test_input)
         obs(test_input)
 
+    def test_histogram_observer_correct_numel(self):
+        for i in range(1, 10):
+            obs = HistogramObserver()
+            obs(torch.randn(i, i))
+            self.assertEqual(obs.histogram.sum().item(), i**2)
+
 
 class TestFakeQuantize(TestCase):
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
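A standalone version of the invariant the new test pins down (a sketch assuming the post-fix observer; the 7x7 input is arbitrary): with float bounds preserved, every observed element must land in some bin, so the histogram mass equals the element count.

import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver()
x = torch.randn(7, 7)
obs(x)  # forward pass records x into obs.histogram

# No element falls outside the histogram range once the int cast is gone.
assert obs.histogram.sum().item() == x.numel()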
4 binary files changed (backward-compatibility test models regenerated per note2; contents not shown).
@@ -1154,7 +1154,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
                 min_val.numel() == 1 and max_val.numel() == 1
             ), "histogram min/max values must be scalar."
             torch.histc(
-                x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
+                x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
             )
         else:
             new_min, new_max = torch.aminmax(x)
@@ -1173,7 +1173,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
                 combined_min.numel() == 1 and combined_max.numel() == 1
             ), "histogram min/max values must be scalar."
             combined_histogram = torch.histc(
-                x, self.bins, min=int(combined_min), max=int(combined_max)
+                x, self.bins, min=combined_min, max=combined_max  # type: ignore[arg-type]
             )
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
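The reason for the `# type: ignore[arg-type]` in both hunks: torch.histc's type stubs annotate min/max as numbers, but at runtime 0-dim tensors are accepted as scalars, which is what lets the fix preserve the fractional bounds. A small check of that behavior (not from the PR; the tensor and bin count are arbitrary):

import torch

x = torch.randn(512)
lo, hi = torch.aminmax(x)  # 0-dim float tensors

# Passing the tensors directly keeps the exact bounds; mypy's stubs expect
# numbers, hence the ignore comment mirroring the fix.
hist = torch.histc(x, bins=256, min=lo, max=hi)  # type: ignore[arg-type]
assert hist.sum().item() == x.numel()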