Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33626
For DDP we require the attributes to be registered as buffers. By doing this the value is broadcast from one device to the rest.
Test Plan: Tested on actual model on GPU. Imported from OSS.
Differential Revision: D20038839
fbshipit-source-id: 82e829fc3baca0b3262c3894a283c375eb08a4a4
from __future__ import absolute_import, division, print_function, unicode_literals

import math
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial

import torch
import torch.nn as nn
from torch._jit_internal import List, Optional


def _with_args(cls_or_self, **kwargs):
    r"""Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Example::

        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    class _PartialWrapper(object):
        def __init__(self, p):
            self.p = p

        def __call__(self, *args, **keywords):
            return self.p(*args, **keywords)

        def __repr__(self):
            return self.p.__repr__()

        with_args = _with_args

    r = _PartialWrapper(partial(cls_or_self, **kwargs))
    return r


ABC = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3


class ObserverBase(ABC, nn.Module):
    r"""Base observer Module.
    Any observer implementation should derive from this class.

    Concrete observers should follow the same API. In forward, they will update
    the statistics of the observed Tensor. And they should provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.

    Args:
        dtype: Quantized data type
    """
    def __init__(self, dtype):
        super(ObserverBase, self).__init__()
        self.dtype = dtype

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    # Returns all the quantization parameters that are needed
    # for a quantize function call.
    # For instance, a per channel observer will return
    # scales, zero_points and axis.
    @abstractmethod
    def get_qparams(self, **kwargs):
        pass

    with_args = classmethod(_with_args)
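

# Illustrative sketch (not part of the original file): a minimal custom observer
# built on ObserverBase. It shows the expected API: update statistics in
# ``forward`` and report (scale, zero_point) tensors from ``calculate_qparams``.
# The fixed [-1, 1] range is an assumption made purely for the example.
class _ExampleFixedRangeObserver(ObserverBase):
    r"""Toy observer that always maps the fixed range [-1, 1] onto quint8."""

    def __init__(self, dtype=torch.quint8):
        super(_ExampleFixedRangeObserver, self).__init__(dtype=dtype)

    def forward(self, x):
        # A real observer would update running statistics here.
        return x

    def calculate_qparams(self, **kwargs):
        scale = 2.0 / 255        # [-1, 1] spread over 256 quantized values
        zero_point = 128         # 0.0 maps to the middle of [0, 255]
        return torch.tensor([scale]), torch.tensor([zero_point])

    def get_qparams(self, **kwargs):
        return self.calculate_qparams()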


class _ObserverBase(ObserverBase):
    r"""Internal common base for all qint/quint8 observers.

    This base is for commonly used parameters that are used internally.
    Users should use `~torch.quantization.observer.ObserverBase` as a base class
    for custom observers.

    Args:
        dtype: Quantized data type.
        qscheme: Quantization scheme to be used.
        reduce_range: Reduces the range of the quantized data type by 1 bit.
                      This is sometimes required to avoid instruction overflow.

    .. warning::

        :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.

    .. warning::

        :attr:`qscheme` can only take one of the following options:

        - ``torch.per_tensor_affine``
        - ``torch.per_tensor_symmetric``
        - ``torch.per_channel_affine``
        - ``torch.per_channel_symmetric``
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False):
        super(_ObserverBase, self).__init__(dtype=dtype)
        self.qscheme = qscheme
        self.reduce_range = reduce_range

        self.eps = torch.finfo(torch.float32).eps
        assert self.qscheme in (
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
            torch.per_channel_affine,
            torch.per_channel_symmetric,
        ), ("Default Observer only works for per_tensor_affine, "
            "per_tensor_symmetric, per_channel_affine and "
            "per_channel_symmetric quantization scheme")
        assert self.dtype in (
            torch.qint8,
            torch.quint8,
        ), "Default Observer only works for qint8 and quint8 data type"

    def _calculate_per_channel_qparams(self, min_vals, max_vals):
        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
        r"""Calculates the per channel quantization parameters, given min and max
        value tensors.

        Args:
            min_vals: Minimum values per channel
            max_vals: Maximum values per channel

        Returns:
            scales: Per channel scales tensor of shape (#channels,)
            zero_points: Per channel zero points tensor of shape (#channels,)
        """
        if min_vals.numel() == 0 or max_vals.numel() == 0:
            warnings.warn(
                "Must run observer before calling calculate_qparams. "
                "Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        for i in range(len(min_vals)):
            assert (
                min_vals[i] <= max_vals[i]
            ), "min {} should be less than or equal to max {}".format(min_vals[i], max_vals[i])

        scales = torch.empty(min_vals.size(), dtype=torch.float32)
        zero_points = torch.empty(min_vals.size(), dtype=torch.int64)

        for i in range(len(scales)):
            qparam = self._calculate_qparams(
                min_vals[i], max_vals[i]
            )
            scales[i] = float(qparam[0])
            zero_points[i] = int(qparam[1])

        return scales, zero_points

    @torch.jit.export
    def _calculate_qparams(self, min_val, max_val):
        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
        r"""Calculates the per tensor quantization parameters, given the min/max.

        Args:
            min_val: Per tensor minimum value
            max_val: Per tensor maximum value

        Returns:
            scale: Scale as a tensor of shape (1,)
            zero_point: Zero point as a tensor of shape (1,)
        """

        if max_val.numel() == 0 or min_val.numel() == 0:
            warnings.warn("Must run observer before calling calculate_qparams. "
                          "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )

        if self.dtype == torch.qint8:
            if self.reduce_range:
                qmin, qmax = -64, 63
            else:
                qmin, qmax = -128, 127
        else:
            if self.reduce_range:
                qmin, qmax = 0, 127
            else:
                qmin, qmax = 0, 255

        max_val, min_val = float(max_val), float(min_val)
        min_val = min(0.0, min_val)
        max_val = max(0.0, max_val)
        if max_val == min_val:
            scale = 1.0
            zero_point = 0
        else:
            if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
                max_val = max(-min_val, max_val)
                scale = max_val / ((qmax - qmin) / 2)
                scale = max(scale, self.eps)
                zero_point = 0 if self.dtype == torch.qint8 else 128
            else:
                scale = (max_val - min_val) / float(qmax - qmin)
                scale = max(scale, self.eps)
                zero_point = qmin - round(min_val / scale)
                zero_point = max(qmin, zero_point)
                zero_point = min(qmax, zero_point)
                zero_point = int(zero_point)

        return torch.tensor([scale]), torch.tensor([zero_point])

    @torch.jit.export
    def get_qparams(self):
        r"""Get all quantization parameters needed for a quantize call."""
        return self.calculate_qparams()
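

# Worked example (added for illustration, not part of the original file): the affine
# branch of ``_calculate_qparams`` done by hand for ``torch.quint8`` without
# ``reduce_range`` and an observed range of [-1.0, 2.0]. The numbers are assumptions
# made purely for the example.
def _example_affine_qparams():
    qmin, qmax = 0, 255                                   # quint8 target range
    min_val, max_val = -1.0, 2.0                          # observed statistics
    scale = (max_val - min_val) / float(qmax - qmin)      # 3 / 255 ~= 0.01176
    zero_point = qmin - round(min_val / scale)            # 0 - (-85) = 85
    zero_point = int(min(qmax, max(qmin, zero_point)))    # clamp to [0, 255]
    return scale, zero_point                              # (~0.01176, 85)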


class MinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running min and max values.

    This observer uses the tensor min/max statistics to compute the quantization
    parameters. The module records the running minimum and maximum of incoming
    tensors, and uses this statistic to compute the quantization parameters.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
    scale :math:`s` and zero point :math:`z` are computed as described below.

    The running minimum/maximum :math:`x_\text{min/max}` is computed as:

    .. math::

        \begin{array}{ll}
        x_\text{min} &= \begin{cases}
            \min(X) & \text{if~}x_\text{min} = \text{None} \\
            \min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
        \end{cases}\\
        x_\text{max} &= \begin{cases}
            \max(X) & \text{if~}x_\text{max} = \text{None} \\
            \max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
        \end{cases}\\
        \end{array}

    where :math:`X` is the observed tensor.

    The scale :math:`s` and zero point :math:`z` are then computed as:

    .. math::

        \begin{aligned}
            \text{if Symmetric:}&\\
            &s = 2 \max(|x_\text{min}|, x_\text{max}) /
                \left( Q_\text{max} - Q_\text{min} \right) \\
            &z = \begin{cases}
                0 & \text{if dtype is qint8} \\
                128 & \text{otherwise}
            \end{cases}\\
            \text{Otherwise:}&\\
            &s = \left( x_\text{max} - x_\text{min} \right) /
                \left( Q_\text{max} - Q_\text{min} \right) \\
            &z = Q_\text{min} - \text{round}(x_\text{min} / s)
        \end{aligned}

    where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
    maximum of the quantized data type.

    .. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme

    .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.

    .. note:: If the running minimum equals the running maximum, the scale
              and zero_point are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 reduce_range=False):
        # For x86 quantized kernels, we need to ensure that the vpmaddubsw
        # instruction does not overflow. We allow for a reduce_range argument to
        # observers that reduces the quantized range to (0, 127) or (-64, 63).
        # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
        # This is not an optimal choice for non-x86 backends as it loses a bit
        # of precision for activations.

        super(MinMaxObserver, self).__init__(dtype=dtype,
                                             qscheme=qscheme,
                                             reduce_range=reduce_range)
        self.register_buffer('min_val', torch.tensor([]))
        self.register_buffer('max_val', torch.tensor([]))
        if self.qscheme == torch.per_tensor_symmetric and \
           self.reduce_range and \
           self.dtype == torch.quint8:
            raise NotImplementedError("Cannot reduce range for symmetric "
                                      "quantization for quint8")

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        x = x_orig.detach()  # avoid keeping autograd tape
        min_val = self.min_val
        max_val = self.max_val
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val = torch.min(x)
            max_val = torch.max(x)
        else:
            min_val = torch.min(torch.min(x), min_val)
            max_val = torch.max(torch.max(x), max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        r"""Calculates the quantization parameters."""
        return self._calculate_qparams(self.min_val, self.max_val)

    @torch.jit.export
    def extra_repr(self):
        return "min_val={}, max_val={}".format(self.min_val, self.max_val)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(MinMaxObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'min_val'] = self.min_val
        destination[prefix + 'max_val'] = self.max_val

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['min_val', 'max_val']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(MinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                          missing_keys, unexpected_keys, error_msgs)
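

# Usage sketch (added for illustration, not part of the original file): feed a few
# batches through a MinMaxObserver and quantize a tensor with the resulting
# parameters. The tensor shapes are assumptions made for the example.
def _example_minmax_observer_usage():
    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    for _ in range(4):
        obs(torch.randn(2, 3))          # forward() updates the running min_val / max_val
    scale, zero_point = obs.calculate_qparams()
    x = torch.randn(2, 3)
    xq = torch.quantize_per_tensor(x, float(scale), int(zero_point), torch.quint8)
    return xq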


class MovingAverageMinMaxObserver(MinMaxObserver):
    r"""Observer module for computing the quantization parameters based on the
    moving average of the min and max values.

    This observer computes the quantization parameters based on the moving
    averages of minimums and maximums of the incoming tensors. The module
    records the average minimum and maximum of incoming tensors, and uses this
    statistic to compute the quantization parameters.

    Args:
        averaging_constant: Averaging constant for min/max.
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The moving average min/max is computed as follows

    .. math::

        \begin{array}{ll}
        x_\text{min} = \begin{cases}
            \min(X) & \text{if~}x_\text{min} = \text{None} \\
            (1 - c) x_\text{min} + c \min(X) & \text{otherwise}
        \end{cases}\\
        x_\text{max} = \begin{cases}
            \max(X) & \text{if~}x_\text{max} = \text{None} \\
            (1 - c) x_\text{max} + c \max(X) & \text{otherwise}
        \end{cases}\\
        \end{array}

    where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
    the incoming tensor, and :math:`c` is the ``averaging_constant``.

    The scale and zero point are then computed as in
    :class:`~torch.quantization.observer.MinMaxObserver`.

    .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme.

    .. note:: If the running minimum equals the running maximum, the scale
              and zero_point are set to 1.0 and 0.
    """
    def __init__(self, averaging_constant=0.01, dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine, reduce_range=False):
        self.averaging_constant = averaging_constant
        super(MovingAverageMinMaxObserver, self).__init__(dtype=dtype,
                                                          qscheme=qscheme,
                                                          reduce_range=reduce_range)

    def forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        min_val = self.min_val
        max_val = self.max_val
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val = torch.min(x)
            max_val = torch.max(x)
        else:
            min_val = min_val + self.averaging_constant * (torch.min(x) - min_val)
            max_val = max_val + self.averaging_constant * (torch.max(x) - max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig
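

# Sketch (added for illustration, not part of the original file): with the default
# averaging_constant c = 0.01 the running statistics move only a small step towards
# each new batch, which damps the effect of outlier batches compared to
# MinMaxObserver. The tensors below are assumptions made purely for the example.
def _example_moving_average_update():
    obs = MovingAverageMinMaxObserver(averaging_constant=0.01)
    obs(torch.zeros(10))               # first batch initializes min_val = max_val = 0
    obs(torch.linspace(0, 100, 10))    # outlier batch with max 100
    # max_val is now 0 + 0.01 * (100 - 0) = 1.0 rather than 100.0; min_val stays 0.
    return obs.min_val, obs.max_val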


class PerChannelMinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        ch_axis: Channel axis
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The quantization parameters are computed the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`, with the difference
    that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, ch_axis=0, dtype=torch.quint8,
                 qscheme=torch.per_channel_affine, reduce_range=False):
        super(PerChannelMinMaxObserver, self).__init__(dtype=dtype,
                                                       qscheme=qscheme,
                                                       reduce_range=reduce_range)
        self.ch_axis = ch_axis
        self.register_buffer('min_vals', torch.tensor([]))
        self.register_buffer('max_vals', torch.tensor([]))
        if (
            self.qscheme == torch.per_channel_symmetric
            and self.reduce_range
            and self.dtype == torch.quint8
        ):
            raise NotImplementedError(
                "Cannot reduce range for symmetric quantization for quint8"
            )

    def forward(self, x_orig):
        return self._forward(x_orig)

    @torch.jit.ignore
    def _forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        min_vals = self.min_vals
        max_vals = self.max_vals
        x_dim = x.size()

        # Move the channel axis to the front so that each row of the flattened
        # tensor holds the values of one channel.
        new_axis_list = list(range(len(x_dim)))
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(tuple(new_axis_list))
        y = torch.flatten(y, start_dim=1)
        if min_vals.numel() == 0 or max_vals.numel() == 0:
            min_vals = torch.min(y, 1)[0]
            max_vals = torch.max(y, 1)[0]
        else:
            min_vals = torch.min(torch.min(y, 1)[0], min_vals)
            max_vals = torch.max(torch.max(y, 1)[0], max_vals)
        self.min_vals = min_vals
        self.max_vals = max_vals
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_per_channel_qparams(self.min_vals, self.max_vals)

    @torch.jit.export
    def get_qparams(self):
        scales, zero_points = self.calculate_qparams()
        return scales, zero_points, self.ch_axis

    def extra_repr(self):
        return "min_vals={}, max_vals={}".format(self.min_vals, self.max_vals)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(PerChannelMinMaxObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'min_vals'] = self.min_vals
        destination[prefix + 'max_vals'] = self.max_vals

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['min_vals', 'max_vals']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(PerChannelMinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                                    missing_keys, unexpected_keys, error_msgs)
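

# Usage sketch (added for illustration, not part of the original file): collect per
# channel statistics for a weight tensor and quantize it with
# ``torch.quantize_per_channel``. The shapes and the ``.double()`` / ``.long()``
# casts are assumptions made for the example.
def _example_per_channel_observer_usage():
    weight = torch.randn(8, 4)                     # 8 output channels along axis 0
    obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                                   qscheme=torch.per_channel_symmetric)
    obs(weight)                                    # records per channel min/max
    scales, zero_points, ch_axis = obs.get_qparams()
    wq = torch.quantize_per_channel(weight, scales.double(), zero_points.long(),
                                    ch_axis, torch.qint8)
    return wq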


class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        averaging_constant: Averaging constant for min/max.
        ch_axis: Channel axis
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The quantization parameters are computed the same way as in
    :class:`~torch.quantization.observer.MovingAverageMinMaxObserver`, with the
    difference that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, averaging_constant=0.01, ch_axis=0, dtype=torch.quint8,
                 qscheme=torch.per_channel_affine, reduce_range=False):
        super(MovingAveragePerChannelMinMaxObserver, self).__init__(
            ch_axis=ch_axis, dtype=dtype, qscheme=qscheme,
            reduce_range=reduce_range)
        self.averaging_constant = averaging_constant

    def forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        min_vals = self.min_vals
        max_vals = self.max_vals
        x_dim = x.size()

        new_axis_list = list(range(len(x_dim)))
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(tuple(new_axis_list))
        y = torch.flatten(y, start_dim=1)
        if min_vals.numel() == 0 or max_vals.numel() == 0:
            min_vals = torch.min(y, 1)[0]
            max_vals = torch.max(y, 1)[0]
        else:
            min_vals = min_vals + self.averaging_constant * (torch.min(y, 1)[0] - min_vals)
            max_vals = max_vals + self.averaging_constant * (torch.max(y, 1)[0] - max_vals)
        self.min_vals = min_vals
        self.max_vals = max_vals
        return x_orig


class HistogramObserver(_ObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        upsample_rate: Factor by which the histograms are upsampled; this is
                       used to interpolate histograms with varying ranges across observations
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
       The histogram is computed continuously, and the ranges per bin change
       with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
       The search for the min/max values ensures the minimization of the
       quantization error with respect to the floating point model.
    3. Compute the scale and zero point the same way as in the
       :class:`~torch.quantization.MinMaxObserver`.
    """

    def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine, reduce_range=False):
        # bins: The number of bins used for histogram calculation.
        super(HistogramObserver, self).__init__(dtype=dtype,
                                                qscheme=qscheme,
                                                reduce_range=reduce_range)
        self.bins = bins
        self.register_buffer('histogram', torch.zeros(self.bins))
        self.register_buffer('min_val', torch.tensor([]))
        self.register_buffer('max_val', torch.tensor([]))
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = upsample_rate

    @torch.jit.ignore
    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in the input distribution.
        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
        caffe2/quantization/server/norm_minimization.cc
        """
        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

            norm = 0.0
            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of the first dst_bin to the
                # beginning and end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width

                # which dst_bins do the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width))
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width))
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                density = self.histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )

                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert self.histogram.size()[0] == self.bins, "bins mismatch"
        bin_width = (self.max_val - self.min_val) / self.bins

        # cumulative sum
        total = sum(self.histogram)
        cSum = torch.cumsum(self.histogram, dim=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max

    @torch.jit.ignore
    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor, int, int]
        # We ensure that:
        #   (combined_max - combined_min) / (downsample_rate * Nbins) = (max - min) / (upsample_rate * Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram.
        # start_idx maps min_val to the histogram bin index.

        hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
        downsample_rate = torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).to(torch.int).item()
        e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = torch.round((self.min_val - combined_min) / hist_bin_width).to(torch.int).item()
        return combined_min, combined_max, downsample_rate, start_idx

    @torch.jit.ignore
    def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins):
        # type: (Tensor, Tensor, int, int, int, int) -> Tensor
        # First up-sample the histogram with the new data by a factor of L.
        # This creates an approximate probability density that's piecewise constant.
        upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index, as the output histogram can cover a wider range.
        histogram_with_output_range = torch.zeros((Nbins * downsample_rate))
        histogram_with_output_range[start_idx:Nbins * upsample_rate + start_idx] = upsampled_histogram
        # Compute the integral histogram; double precision is needed to ensure
        # that there are no overflows.
        integral_histogram = torch.cumsum(histogram_with_output_range, 0,
                                          dtype=torch.double)[downsample_rate - 1 :: downsample_rate]
        # Finally perform interpolation.
        shifted_integral_histogram = torch.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (integral_histogram - shifted_integral_histogram) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram.to(torch.float)
        return orig_hist
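
    # Worked example for the grid alignment above (added for illustration; all
    # numbers are assumptions): with min_val = 0, max_val = 1, bins = 2048 and
    # upsample_rate = 128, the fine grid has hist_bin_width = 1 / (2048 * 128).
    # If a new batch widens the range so that combined_min = 0 and combined_max = 2.5,
    # then downsample_rate = ceil(2.5 / (2048 * hist_bin_width)) = ceil(320) = 320,
    # i.e. each of the 2048 output bins covers 320 fine-grid bins, the padding
    # e = 320 * 2048 * hist_bin_width - 2.5 = 0, and start_idx = 0 places the old
    # histogram at the left edge of the combined range.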

    def forward(self, x_orig):
        # type: (Tensor) -> Tensor
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val = torch.min(x)
            max_val = torch.max(x)
            self.min_val = min_val
            self.max_val = max_val
            self.histogram = torch.histc(x, self.bins, min=min_val, max=max_val)
        else:
            new_min = torch.min(x)
            new_max = torch.max(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # Combine the existing histogram and the new histogram into one.
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently.
            combined_min, combined_max, downsample_rate, start_idx = \
                self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins)

            self.histogram = combined_histogram
            self.min_val = combined_min
            self.max_val = combined_max
        return x

    @torch.jit.export
    def calculate_qparams(self):
        if self.min_val.numel() == 0 or self.max_val.numel() == 0:
            warnings.warn(
                "Must run observer before calling calculate_qparams. "
                "Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])
        assert self.bins == len(self.histogram), (
            "The number of bins in histogram should be equal to the number of bins "
            "supplied while making this observer"
        )

        new_min, new_max = self._non_linear_param_search()

        return self._calculate_qparams(new_min, new_max)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(HistogramObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'min_val'] = self.min_val
        destination[prefix + 'max_val'] = self.max_val

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['min_val', 'max_val']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(HistogramObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                             missing_keys, unexpected_keys, error_msgs)
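

# Usage sketch (added for illustration, not part of the original file): calibrate a
# HistogramObserver on a few batches and read back the qparams chosen by the
# approximate L2-error search. The batch sizes are assumptions made for the example.
def _example_histogram_observer_usage():
    obs = HistogramObserver(bins=2048, upsample_rate=128)
    for _ in range(8):
        obs(torch.randn(128))           # accumulates the running histogram and min/max
    scale, zero_point = obs.calculate_qparams()
    return scale, zero_point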


class RecordingObserver(_ObserverBase):
    r"""
    The module is mainly for debugging and records the tensor values during runtime.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
    """
    __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]}

    def __init__(self, **kwargs):
        super(RecordingObserver, self).__init__(**kwargs)
        self.tensor_val = []

    def forward(self, x):
        self.tensor_val.append(x.clone())
        return x

    @torch.jit.export
    def calculate_qparams(self):
        raise Exception("calculate_qparams should not be called for RecordingObserver")

    @torch.jit.export
    def get_tensor_value(self):
        return self.tensor_val
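

# Sketch (added for illustration, not part of the original file): RecordingObserver
# simply stores every tensor it sees, which is handy when debugging the numerics of
# a quantized model.
def _example_recording_observer_usage():
    obs = RecordingObserver()
    obs(torch.tensor([1.0, 2.0]))
    obs(torch.tensor([3.0]))
    return obs.get_tensor_value()       # list containing the two recorded tensors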


class NoopObserver(ObserverBase):
    r"""
    Observer that doesn't do anything and just passes its configuration to the
    quantized module's ``.from_float()``.

    Primarily used for quantization to float16, which doesn't require determining
    ranges.

    Args:
        dtype: Quantized data type
    """
    def __init__(self, dtype=torch.float16):
        if dtype != torch.float16:
            raise ValueError("Only float16 quantization can be used without calibration process")
        super(NoopObserver, self).__init__(dtype=dtype)

    def forward(self, x):
        return x

    def calculate_qparams(self):
        raise Exception("calculate_qparams should not be called for NoopObserver")

    def get_qparams(self):
        return self.calculate_qparams()


# Restrict activations to be in the range (0, 127)
default_observer = MinMaxObserver.with_args(reduce_range=True)
default_debug_observer = RecordingObserver
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
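

# Usage sketch (added for illustration, not part of the original file): these
# factories are typically consumed by a QConfig (assumed here to be
# ``torch.quantization.QConfig``), which instantiates one observer per
# activation/weight when a model is prepared for quantization.
def _example_default_qconfig():
    from torch.quantization import QConfig
    return QConfig(activation=default_observer,
                   weight=default_weight_observer)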