mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant] Implemented InputWeightObserver for Linear inputs
Summary: Implemented two observers (InputEqualObserver and WeightEqualObserver) which will be inserted into the graph during prepare_fx(). Test Plan: python test/test_quantization.py TestEqualizeFx Reviewed By: supriyar Differential Revision: D28836954 fbshipit-source-id: 25517dc82ae67698ed8b2dc334e3323286976104
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c51abf8fca
commit
cc03ea2c47
231
torch/quantization/fx/_equalize.py
Normal file
231
torch/quantization/fx/_equalize.py
Normal file
@ -0,0 +1,231 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme; must be ``torch.per_tensor_affine`` or
            ``torch.per_tensor_symmetric``
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: For the user to specify what kind of output observer they
            would like to use. If unspecified, a
            :class:`~torch.quantization.observer.MinMaxObserver` is created
            with the same arguments as the input observer.
        factory_kwargs: Forwarded unchanged to the underlying observers.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.

    Raises:
        TypeError: if ``qscheme`` is not a per-tensor scheme.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, output_obs=None,
                 factory_kwargs=None) -> None:
        super().__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        # ch_axis=1 makes the observer keep one running min/max per input column.
        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                  qscheme=qscheme,
                                                  quant_min=quant_min,
                                                  quant_max=quant_max,
                                                  factory_kwargs=factory_kwargs)

        # The output observer is only stored here; forward() does not call it.
        if output_obs is None:
            self.output_obs = MinMaxObserver(dtype=dtype,
                                             qscheme=qscheme,
                                             quant_min=quant_min,
                                             quant_max=quant_max,
                                             factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs

        # An empty tensor marks "equalization scale not computed yet".
        self.equalization_scale = torch.empty(0)

    def forward(self, x_orig):
        r"""Records the per-column min/max of ``x_orig`` and returns the
        result of the underlying input observer.

        Raises:
            ValueError: if ``x_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if x_orig.ndim != 2:
            raise ValueError("InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        r"""Returns the ``(min, max)`` values tracked by the input observer."""
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by :meth:`calculate_qparams`."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""Returns the ``(scale, zero_point)`` for the equalized inputs.

        If no equalization scale has been set yet, a warning is issued and the
        default ``(tensor([1.0]), tensor([0]))`` pair is returned, so that the
        result always has two elements.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must set the equalization scale (see "
                "calculate_equalization_scale) before calling "
                "calculate_qparams. Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        # Scale each column's min/max by the equalization scale at the same
        # column index, then reduce to a single global min/max.
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))
        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(
            min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input
|
||||
|
||||
|
||||
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        factory_kwargs: Forwarded unchanged to the underlying observers.

    This observer is made up of 2 PerChannelMinMaxObservers
        - weight_col_obs: Used to record the running minimum and maximum of
          columns of incoming weight tensors
        - weight_row_obs: Used to record the running minimum and maximum of
          rows of incoming weight tensors

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the same
    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
                 quant_max=None, factory_kwargs=None) -> None:
        super().__init__()

        # ch_axis=1 -> one running min/max per weight column.
        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        # ch_axis=0 -> one running min/max per weight row.
        self.weight_row_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        # An empty tensor marks "equalization scale not computed yet".
        self.equalization_scale = torch.empty(0)

        # Column index of the min/max weight in each row; filled in by
        # forward(). They must exist up front because calculate_qparams()
        # tests them against None, and nn.Module raises AttributeError for
        # attributes that were never assigned.
        self.min_weights_ind = None
        self.max_weights_ind = None

    def forward(self, w_orig):
        r"""Records the weight statistics of ``w_orig`` and returns it.

        Raises:
            ValueError: if ``w_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if w_orig.ndim != 2:
            raise ValueError("WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)

    def _forward(self, w_orig):
        r"""
        Calculates the min/max values of each weight column and weight row,
        and records the column index of the min/max weight in each row.
        """
        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Column index of the min/max weight in each row of the tensor just
        # observed. argmin/argmax return the first occurrence on ties, which
        # matches taking the first match of an equality search, but they do
        # not fail when the running extreme comes from an earlier tensor, and
        # the indices stay on the same device as the weights.
        self.min_weights_ind = torch.argmin(w_orig, dim=1)
        self.max_weights_ind = torch.argmax(w_orig, dim=1)

        return w_orig

    def get_weight_col_minmax(self):
        r"""Returns the ``(min, max)`` values tracked by the column observer."""
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        r"""Returns the ``(min, max)`` values tracked by the row observer."""
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by :meth:`calculate_qparams`."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""Returns the ``(scale, zero_point)`` for the equalized weight rows.

        If the equalization scale has not been set, or no weight has been
        observed yet, a warning is issued and the default
        ``(tensor([1.0]), tensor([0]))`` pair is returned, so that the result
        always has two elements.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must set the equalization scale (see "
                "calculate_equalization_scale) before calling "
                "calculate_qparams. Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the minimum and maximum of "
                "each row in the weights in order to calculate the qparams. "
                "Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        # Scale each row's min/max by the reciprocal of the equalization scale
        # located at the column where that min/max was found.
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(
            min_weights, torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(
            max_weights, torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
        (scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(
            min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight
|
||||
|
||||
|
||||
def calculate_equalization_scale(input_obs: "_InputEqualizationObserver",
                                 weight_obs: "_WeightEqualizationObserver") -> torch.Tensor:
    r""" Calculates and returns the per-column equalization scale.

    The scale for each column is
    ``sqrt((max_weight - min_weight) / (max_input - min_input))``. The
    observers are not modified; callers pass the result to each observer's
    ``set_equalization_scale``.

    Columns whose scale would be 0, infinite or NaN (the input or weight range
    of that column is zero) get a neutral scale of 1, so that neither the
    scale nor its reciprocal is degenerate.

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns

    Raises:
        ValueError: if the input and weight column statistics differ in shape.
    """

    (min_inputs, max_inputs) = input_obs.get_input_minmax()
    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()

    if min_inputs.shape != min_weights.shape:
        raise ValueError(
            "Input and Weight must have the same column dimension. "
            f"Found {min_inputs.shape} and {min_weights.shape} instead."
        )

    equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))

    # Replace degenerate entries (0, inf, nan) with the neutral scale 1.
    degenerate = ~torch.isfinite(equalization_scale) | (equalization_scale == 0)
    equalization_scale = torch.where(
        degenerate, torch.ones_like(equalization_scale), equalization_scale)

    return equalization_scale
|
Reference in New Issue
Block a user