[torch quantization] avoid OOM when combining histograms in observer (#123309)

Summary:
Even with the changes in D55347133, it is still possible to OOM in the histogram observer, because the size of the allocated tensor also depends on the *downsample_rate*.

For example, I still see an OOM caused by an attempt to allocate a 10 GB+ histogram tensor in a multi-task model.

To fix the OOM issue more robustly, we add a sanity check on the size of the histogram tensor before it is allocated in *_combine_histograms()*, and fall back to resetting the histogram when the allocation would be too large.
Empirically, we cap the size of a single histogram tensor at 1 GB.
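
For intuition, the guard in the diff below estimates the tensor as bins * downsample_rate * 4 bytes (4 bytes per float32 element). A back-of-the-envelope calculation with the observer's default of 2048 bins (illustrative numbers, not taken from the model above):

    # Rough arithmetic behind the 1 GB cap; values are illustrative.
    bins = 2048                   # HistogramObserver default
    MAX_HISTOGRAM_SIZE = 1e9      # 1 GB, as in this patch

    # Smallest downsample_rate that trips the fallback:
    print(MAX_HISTOGRAM_SIZE / (bins * 4))   # ~122_070

    # A downsample_rate on the order of 8000**2 (cf. the upsample_rate used
    # in the new unit test) would imply an allocation of roughly:
    print(bins * (8000 ** 2) * 4 / 1e9)      # ~524 GB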

Test Plan: Tested the change on a multi-task model (depth + segmentation).
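
The recovery behavior can also be smoke-tested in isolation; the following sketch mirrors the unit tests added in this PR (it assumes only the public torch.ao.quantization observer API):

    import torch
    from torch.ao.quantization.observer import HistogramObserver

    obs = HistogramObserver.with_args(reduce_range=False)()

    # 1st pass with a nearly degenerate range; per the test comments below,
    # the next pass could previously attempt a huge allocation in
    # _combine_histograms().
    obs(torch.tensor([0.0, 1e-9]))

    # 2nd pass must not OOM; with the size check it can fall back to
    # resetting the histogram instead of attempting that allocation.
    obs(torch.tensor([2.0, 3.0]))
    print(obs.min_val, obs.max_val)  # observer state remains usable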

Differential Revision: D55567292

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123309
Approved by: https://github.com/jerryzh168
Author: Zhicheng Yan
Date: 2024-04-06 03:15:02 +00:00
Committer: PyTorch MergeBot
Parent: d3596cf004
Commit: 77643ed2eb

2 changed files with 60 additions and 24 deletions


@@ -357,6 +357,31 @@ class TestObserver(QuantizationTestCase):
         result = torch.softmax(input + dequant_mask, dim=1)
         self.assertEqual(result, ref_result)
 
+    def test_histogram_observer_handle_OOM_due_to_close_min_max_value(self):
+        obser = HistogramObserver.with_args(reduce_range=False)()
+
+        # close min and max value in the 1st forward() pass of observer tends
+        # to cause OOM in the following pass.
+        # This is due to the allocation of histogram tensor during _combine_histograms().
+        # With sanity check on the size of histogram tensor, we expect the histogram observer
+        # can still work by resetting the histogram
+        x1 = torch.tensor([0, 1e-9])
+        obser(x1)
+
+        x2 = torch.tensor([2.0, 3.0])
+        obser(x2)
+
+    def test_histogram_observer_handle_OOM_due_to_large_upsample_rate(self):
+        # a large upsample rate leads to OOM due to the allocation of histogram tensor
+        # during _combine_histograms(). With sanity check on the size of histogram tensor,
+        # we expect the histogram observer can still work by resetting the histogram
+        obser = HistogramObserver.with_args(upsample_rate=(8000**2), reduce_range=False)()
+        x1 = torch.tensor([0, 1.0])
+        obser(x1)
+
+        x2 = torch.tensor([2, 2 + 1e-9])
+        obser(x2)
+
     def test_histogram_observer_save_load_state_dict(self):
         """
         Smoke test on saving/loading state_dict

@@ -1189,6 +1189,18 @@ class HistogramObserver(UniformQuantizationObserverBase):
         orig_hist = orig_hist + interpolated_histogram.to(torch.float)
         return orig_hist
 
+    def reset_histogram(self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> None:
+        self.min_val.resize_(min_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.resize_(max_val.shape)
+        self.max_val.copy_(max_val)
+        assert (
+            min_val.numel() == 1 and max_val.numel() == 1
+        ), "histogram min/max values must be scalar."
+        torch.histc(
+            x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
+        )
+
     def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
         if x_orig.numel() == 0:
             return x_orig
@@ -1212,16 +1224,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
             is_uninitialized = min_val == float("inf") and max_val == float("-inf")
             if is_uninitialized or same_values or close_values:
                 min_val, max_val = x_min, x_max
-                self.min_val.resize_(min_val.shape)
-                self.min_val.copy_(min_val)
-                self.max_val.resize_(max_val.shape)
-                self.max_val.copy_(max_val)
-                assert (
-                    min_val.numel() == 1 and max_val.numel() == 1
-                ), "histogram min/max values must be scalar."
-                torch.histc(
-                    x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
-                )
+                self.reset_histogram(x, min_val, max_val)
             else:
                 new_min, new_max = x_min, x_max
                 combined_min = torch.min(new_min, min_val)
@@ -1249,21 +1252,29 @@ class HistogramObserver(UniformQuantizationObserverBase):
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
             else:
-                combined_histogram = self._combine_histograms(
-                    combined_histogram,
-                    self.histogram,
-                    self.upsample_rate,
-                    downsample_rate,
-                    start_idx,
-                    self.bins,
-                )
+                MAX_HISTOGRAM_SIZE = 1e9  # 1 GB
+                histogram_size = self.bins * downsample_rate * 4
+                if histogram_size > MAX_HISTOGRAM_SIZE:
+                    warnings.warn(
+                        "Fail to combine histograms. Fall back to reset histogram."
+                    )
+                    self.reset_histogram(x, x_min, x_max)
+                else:
+                    combined_histogram = self._combine_histograms(
+                        combined_histogram,
+                        self.histogram,
+                        self.upsample_rate,
+                        downsample_rate,
+                        start_idx,
+                        self.bins,
+                    )
 
-            self.histogram.detach_().resize_(combined_histogram.shape)
-            self.histogram.copy_(combined_histogram)
-            self.min_val.detach_().resize_(combined_min.shape)
-            self.min_val.copy_(combined_min)
-            self.max_val.detach_().resize_(combined_max.shape)
-            self.max_val.copy_(combined_max)
+                    self.histogram.detach_().resize_(combined_histogram.shape)
+                    self.histogram.copy_(combined_histogram)
+                    self.min_val.detach_().resize_(combined_min.shape)
+                    self.min_val.copy_(combined_min)
+                    self.max_val.detach_().resize_(combined_max.shape)
+                    self.max_val.copy_(combined_max)
         return x_orig
 
     @torch.jit.export
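
For readers unfamiliar with torch.histc: the reset_histogram() helper added above fills the preallocated bins-sized buffer in place via out=, so the reset path never allocates anything proportional to downsample_rate. A standalone illustration (arbitrary values):

    import torch

    bins = 2048                    # HistogramObserver's default bin count
    histogram = torch.zeros(bins)  # fixed-size buffer, reused across passes

    x = torch.tensor([0.0, 0.25, 0.5, 1.0])
    # Fills `histogram` in place; the only allocation is the fixed buffer.
    torch.histc(x, bins, min=0.0, max=1.0, out=histogram)
    print(int(histogram.sum()))    # 4 -- every sample landed in a bin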