Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[torch quantization] raise exception when OOM during combine histogram in observer (#123309)

Summary: Even with the changes in D55347133, it is still possible to OOM in the histogram observer, because the size of the allocated histogram tensor also depends on the downsample_rate. For example, we still see OOM from an attempt to allocate a 10 GB+ histogram tensor in a multi-task model. To better avoid OOM, we add a sanity check on the size of the histogram tensor before combining histograms and fall back to resetting the histogram when the combined tensor would be too large. Empirically, we set the max size of a single histogram tensor to 1 GB.

Test Plan: Tested the change on the Multi-Task model (depth + segmentation).

Differential Revision: D55567292

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123309
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: d3596cf004
Commit: 77643ed2eb
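For intuition about the 1 GB cap mentioned in the summary: the guard added in this PR estimates the would-be combined histogram as bins * downsample_rate float32 entries, i.e. self.bins * downsample_rate * 4 bytes, and falls back to resetting the histogram once that estimate exceeds 1 GB. A minimal sketch of that arithmetic follows; the bin count and downsample rate are illustrative stand-ins, since the real values come from HistogramObserver.bins and the downsample rate computed inside forward().

# Sketch of the size estimate behind the 1 GB cap (illustrative values, not from the PR).
MAX_HISTOGRAM_SIZE = 1e9  # 1 GB, the same constant the diff introduces

def combined_histogram_bytes(bins: int, downsample_rate: int) -> int:
    # float32 histogram: 4 bytes per bin, scaled by the downsample rate
    return bins * downsample_rate * 4

# Example: 2048 bins (the observer's default) with a downsample rate on the
# order of the 8000**2 upsample rate used in the new test would need ~524 GB.
size = combined_histogram_bytes(bins=2048, downsample_rate=8000 ** 2)
print(size / 1e9, "GB")            # ~524.3 GB
print(size > MAX_HISTOGRAM_SIZE)   # True -> warn and fall back to reset_histogram()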
@@ -357,6 +357,31 @@ class TestObserver(QuantizationTestCase):
         result = torch.softmax(input + dequant_mask, dim=1)
         self.assertEqual(result, ref_result)
 
+    def test_histogram_observer_handle_OOM_due_to_close_min_max_value(self):
+        obser = HistogramObserver.with_args(reduce_range=False)()
+        # close min and max value in the 1st forward() pass of observer tends
+        # to cause OOM in the following pass.
+        # This is due to the allocation of histogram tensor during _combine_histograms().
+        # With sanity check on the size of histogram tensor, we expect the histogram observer
+        # can still work by resetting the histogram
+        x1 = torch.tensor([0, 1e-9])
+        obser(x1)
+
+        x2 = torch.tensor([2.0, 3.0])
+        obser(x2)
+
+    def test_histogram_observer_handle_OOM_due_to_large_upsample_rate(self):
+        # a large upsample rate leads to OOM due to the allocation of histogram tensor
+        # during _combine_histograms(). With sanity check on the size of histogram tensor,
+        # we expect the histogram observer can still work by resetting the histogram
+        obser = HistogramObserver.with_args(upsample_rate=(8000**2), reduce_range=False)()
+
+        x1 = torch.tensor([0, 1.0])
+        obser(x1)
+
+        x2 = torch.tensor([2, 2 + 1e-9])
+        obser(x2)
+
     def test_histogram_observer_save_load_state_dict(self):
         """
         Smoke test on saving/loading state_dict
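As a usage note, the scenario these tests exercise boils down to calling the observer twice and then asking it for quantization parameters. The sketch below only illustrates that call pattern; it is not part of the PR, and it assumes the observer is imported from torch.ao.quantization.observer.

# Minimal sketch of the call pattern the new tests exercise (not part of the PR).
import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver.with_args(reduce_range=False)()

# A nearly degenerate range on the first pass previously made the second pass
# try to allocate an enormous combined histogram.
obs(torch.tensor([0.0, 1e-9]))
obs(torch.tensor([2.0, 3.0]))

# With the fallback in place, the observer should still yield usable qparams.
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)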
@@ -1189,6 +1189,18 @@ class HistogramObserver(UniformQuantizationObserverBase):
         orig_hist = orig_hist + interpolated_histogram.to(torch.float)
         return orig_hist
 
+    def reset_histogram(self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> None:
+        self.min_val.resize_(min_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.resize_(max_val.shape)
+        self.max_val.copy_(max_val)
+        assert (
+            min_val.numel() == 1 and max_val.numel() == 1
+        ), "histogram min/max values must be scalar."
+        torch.histc(
+            x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
+        )
+
     def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
         if x_orig.numel() == 0:
             return x_orig
@@ -1212,16 +1224,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
         is_uninitialized = min_val == float("inf") and max_val == float("-inf")
         if is_uninitialized or same_values or close_values:
             min_val, max_val = x_min, x_max
-            self.min_val.resize_(min_val.shape)
-            self.min_val.copy_(min_val)
-            self.max_val.resize_(max_val.shape)
-            self.max_val.copy_(max_val)
-            assert (
-                min_val.numel() == 1 and max_val.numel() == 1
-            ), "histogram min/max values must be scalar."
-            torch.histc(
-                x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
-            )
+            self.reset_histogram(x, min_val, max_val)
         else:
             new_min, new_max = x_min, x_max
             combined_min = torch.min(new_min, min_val)
@@ -1249,21 +1252,29 @@ class HistogramObserver(UniformQuantizationObserverBase):
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
             else:
-                combined_histogram = self._combine_histograms(
-                    combined_histogram,
-                    self.histogram,
-                    self.upsample_rate,
-                    downsample_rate,
-                    start_idx,
-                    self.bins,
-                )
+                MAX_HISTOGRAM_SIZE = 1e9  # 1 GB
+                histogram_size = self.bins * downsample_rate * 4
+                if histogram_size > MAX_HISTOGRAM_SIZE:
+                    warnings.warn(
+                        "Fail to combine histograms. Fall back to reset histogram."
+                    )
+                    self.reset_histogram(x, x_min, x_max)
+                else:
+                    combined_histogram = self._combine_histograms(
+                        combined_histogram,
+                        self.histogram,
+                        self.upsample_rate,
+                        downsample_rate,
+                        start_idx,
+                        self.bins,
+                    )
+                    self.histogram.detach_().resize_(combined_histogram.shape)
+                    self.histogram.copy_(combined_histogram)
+                    self.min_val.detach_().resize_(combined_min.shape)
+                    self.min_val.copy_(combined_min)
+                    self.max_val.detach_().resize_(combined_max.shape)
+                    self.max_val.copy_(combined_max)
 
-            self.histogram.detach_().resize_(combined_histogram.shape)
-            self.histogram.copy_(combined_histogram)
-            self.min_val.detach_().resize_(combined_min.shape)
-            self.min_val.copy_(combined_min)
-            self.max_val.detach_().resize_(combined_max.shape)
-            self.max_val.copy_(combined_max)
         return x_orig
 
     @torch.jit.export
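To see the new fallback path in action, one can capture the warning the observer now emits when combining histograms would exceed the cap. The sketch below mirrors the large-upsample-rate test above; whether the warning actually fires for these exact inputs depends on the internally computed downsample rate, so treat it as illustrative only.

# Sketch: surfacing the fallback warning added in this PR (illustrative, not part of the diff).
import warnings
import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver.with_args(upsample_rate=8000 ** 2, reduce_range=False)()
obs(torch.tensor([0.0, 1.0]))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    obs(torch.tensor([2.0, 2.0 + 1e-9]))  # combining here would need a huge histogram
    print([str(w.message) for w in caught])  # expect "Fail to combine histograms. ..."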