[quant] Implemented input and weight equalization observers for Linear inputs

Summary: Implemented two observers (_InputEqualizationObserver and _WeightEqualizationObserver) which will be inserted into the graph during prepare_fx().
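
For orientation, a minimal sketch of how the two observers are exercised (mirroring the new test added in this diff; the module path is taken from the test's import):

import torch
from torch.quantization.fx._equalize import (
    _InputEqualizationObserver, _WeightEqualizationObserver, calculate_equalization_scale
)

input_obs = _InputEqualizationObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
weight_obs = _WeightEqualizationObserver(dtype=torch.qint8, qscheme=torch.per_channel_affine)

x = torch.randn(4, 3)   # (batch, in_features) input to a Linear layer
w = torch.randn(5, 3)   # (out_features, in_features) Linear weight

input_obs(x)            # records per-column input min/max
weight_obs(w)           # records per-column and per-row weight min/max

scale = calculate_equalization_scale(input_obs, weight_obs)
input_obs.set_equalization_scale(scale)
weight_obs.set_equalization_scale(scale)

input_scale, input_zp = input_obs.calculate_qparams()
weight_scales, weight_zps = weight_obs.calculate_qparams()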

Test Plan: python test/test_quantization.py TestEqualizeFx

Reviewed By: supriyar

Differential Revision: D28836954

fbshipit-source-id: 25517dc82ae67698ed8b2dc334e3323286976104
Author: Angela Yi
Date: 2021-06-07 11:18:02 -07:00
Committed by: Facebook GitHub Bot
Parent: c51abf8fca
Commit: cc03ea2c47
3 changed files with 349 additions and 1 deletion

test/quantization/fx/test_equalize_fx.py

@@ -0,0 +1,112 @@
import torch
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.quantization.fx._equalize import (
_InputEqualizationObserver, _WeightEqualizationObserver, calculate_equalization_scale
)
# Standard Libraries
import numpy as np
# Testing utils
from hypothesis import given
from hypothesis import strategies as st
class TestEqualizeFx(QuantizationTestCase):
@given(input_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
input_qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
weight_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
weight_qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams)))
def test_input_weight_observer(self, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
input_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
weight_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)
width = np.random.randint(1, 10)
x_height = np.random.randint(2, 10)
w_height = np.random.randint(2, 10)
x = (np.random.random(size=(x_height, width)) * 10).round(decimals=2).astype(np.float32)
w = (np.random.random(size=(w_height, width)) * 10).round(decimals=2).astype(np.float32)
ret_x = input_obs(torch.tensor(x))
ret_w = weight_obs(torch.tensor(w))
self.assertEqual((ret_x, ret_w), (x, w))
# Check the min/max input columns are correct
ref_min_inputs = x.min(axis=0)
ref_max_inputs = x.max(axis=0)
self.assertEqual(input_obs.get_input_minmax(), (ref_min_inputs, ref_max_inputs))
# Check the min/max weight columns are correct
ref_min_weights_col = w.min(axis=0)
ref_max_weights_col = w.max(axis=0)
self.assertEqual(weight_obs.get_weight_col_minmax(), (ref_min_weights_col, ref_max_weights_col))
# Check the min/max weight rows are correct
ref_min_weights_row = w.min(axis=1)
ref_max_weights_row = w.max(axis=1)
self.assertEqual(weight_obs.get_weight_row_minmax(), (ref_min_weights_row, ref_max_weights_row))
# Check the column indices of the min/max weight rows are correct
ref_min_weights_ind = w.argmin(axis=1)
ref_max_weights_ind = w.argmax(axis=1)
self.assertEqual((weight_obs.min_weights_ind, weight_obs.max_weights_ind),
(ref_min_weights_ind, ref_max_weights_ind))
# Check the equalization scale is correct
equalization_scale = calculate_equalization_scale(input_obs, weight_obs)
ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
(ref_max_inputs - ref_min_inputs))
self.assertEqual(equalization_scale, ref_equalization_scale)
input_obs.set_equalization_scale(equalization_scale)
weight_obs.set_equalization_scale(equalization_scale)
# check the input scale/zero-point values
input_qparams = input_obs.calculate_qparams()
min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
min_input_scaled = min(0, min_input_scaled)
max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
max_input_scaled = max(0, max_input_scaled)
if input_qscheme == torch.per_tensor_symmetric:
ref_scale = 2 * max(abs(min_input_scaled), max_input_scaled) / 255
ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
else:
ref_scale = (max_input_scaled - min_input_scaled) / 255
ref_zero_point = -128 if input_qdtype is torch.qint8 else 0
self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
self.assertEqual(input_qparams[1].item(), ref_zero_point)
# check the weight scale/zero-point values
weight_qparams = weight_obs.calculate_qparams()
min_weights_scaled = ref_min_weights_row * (1 / ref_equalization_scale[ref_min_weights_ind])
max_weights_scaled = ref_max_weights_row * (1 / ref_equalization_scale[ref_max_weights_ind])
if weight_qscheme == torch.per_channel_symmetric:
min_weights_scaled = np.minimum(np.zeros(min_weights_scaled.shape), min_weights_scaled)
max_weights_scaled = np.maximum(np.zeros(max_weights_scaled.shape), max_weights_scaled)
ref_scales = 2 * np.maximum(np.abs(min_weights_scaled), max_weights_scaled) / 255
ref_zero_points = np.zeros_like(
ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
elif weight_qscheme == torch.per_channel_affine_float_qparams:
ref_scales = (max_weights_scaled - min_weights_scaled) / 255
ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
ref_zero_points = -1 * min_weights_scaled / ref_scales
else:
min_weights_scaled = np.minimum(np.zeros_like(min_weights_scaled), min_weights_scaled)
max_weights_scaled = np.maximum(np.zeros_like(max_weights_scaled), max_weights_scaled)
ref_scales = (max_weights_scaled - min_weights_scaled) / 255
ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
ref_zero_points = ref_zero_points - np.round(min_weights_scaled / ref_scales)
self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor(
ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001))
self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor(
ref_zero_points, dtype=weight_qparams[1].dtype), atol=1))

test/test_quantization.py

@@ -79,10 +79,15 @@ try:
except ImportError:
pass
# Equalization for FX mode
try:
from quantization.fx.test_equalize_fx import TestEqualizeFx # noqa: F401
except ImportError:
pass
# Backward Compatibility. Tests serialization and BC for quantized modules.
from quantization.bc.test_backward_compatibility import TestSerialization # noqa: F401
# JIT Graph Mode Quantization
from quantization.jit.test_quantize_jit import TestQuantizeJit # noqa: F401
from quantization.jit.test_quantize_jit import TestQuantizeJitPasses # noqa: F401

torch/quantization/fx/_equalize.py

@@ -0,0 +1,231 @@
import torch
import torch.nn as nn
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import warnings
class _InputEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of input columns, and
computing the quantization parameters for the overall min/max input values.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
output_obs: Optional observer to use for the layer's output; if
unspecified, a MinMaxObserver is used by default
The running minimum/maximum :math:`x_\text{min/max}` are computed in the
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
with the difference that the running min/max values are stored per column.
The qparams are calculated by multiplying the min/max input column values
with the equalization scale, reducing to find the global min/max input
values, and then calculating in the same way as in
:class:`~torch.quantization.observer.MinMaxObserver`
.. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
quant_min=None, quant_max=None, output_obs=None,
factory_kwargs=None) -> None:
super(_InputEqualizationObserver, self).__init__()
if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
raise TypeError("Input qscheme must be per-tensor")
self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
if output_obs is None:
self.output_obs = MinMaxObserver(dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
else:
self.output_obs = output_obs
self.equalization_scale = torch.empty(0)
def forward(self, x_orig):
# TODO: Allow for convolutional layers
if x_orig.ndim != 2:
raise ValueError("InputEqualizationObserver only supports Linear layers")
return self.input_obs(x_orig)
def get_input_minmax(self):
return (self.input_obs.min_vals, self.input_obs.max_vals)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
def calculate_qparams(self):
r"""
Returns the scale/zero_point for the scaled input
"""
if self.equalization_scale.nelement() == 0:
warnings.warn(
"Must call calculate_scale before calling calculate_qparams.\
Returning default scale and zero point. "
)
return torch.tensor([1.0]), torch.tensor([0])
# Calculate qparams for the scaled min/max inputs
# Scale the input by the equalization scale located at the same column
# index
(min_inputs, max_inputs) = self.get_input_minmax()
min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))
(scale_input, zero_point_input) = self.input_obs._calculate_qparams(min_input_scaled, max_input_scaled)
return scale_input, zero_point_input
class _WeightEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of weight columns and
rows, and computing the quantization parameters for the weight rows.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
This observer is made up of 2 PerChannelMinMaxObservers
- weight_col_obs: Used to record the running minimum and maximum of
columns of incoming weight tensors
- weight_row_obs: Used to record the running minimum and maximum of
rows of incoming weight tensors
The running minimum/maximum :math:`w_\text{min/max}` are computed in the
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
The qparams are calculated by multiplying the min/max weight row values
with the inverse of the equalization scale, and then calculating in the same
way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
.. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
quant_max=None, factory_kwargs=None) -> None:
super(_WeightEqualizationObserver, self).__init__()
self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
self.weight_row_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
self.equalization_scale = torch.empty(0)
self.min_weights_ind = None
self.max_weights_ind = None
def forward(self, w_orig):
# TODO: Allow for convolutional layers
if w_orig.ndim != 2:
raise ValueError("WeightEqualizationObserver only supports Linear layers")
return self._forward(w_orig)
def _forward(self, w_orig):
r"""
Calculates the min/max values of each weight column and weight row.
"""
w_orig = self.weight_col_obs(w_orig)
w_orig = self.weight_row_obs(w_orig)
# Calculate the column indices of the min/max weight in each row
num_row, _ = w_orig.shape
min_weights_ind = []
max_weights_ind = []
for i in range(num_row):
min_weights_ind.append(torch.nonzero(w_orig[i] == self.weight_row_obs.min_vals[i])[0][0])
max_weights_ind.append(torch.nonzero(w_orig[i] == self.weight_row_obs.max_vals[i])[0][0])
self.min_weights_ind = torch.tensor(min_weights_ind)
self.max_weights_ind = torch.tensor(max_weights_ind)
return w_orig
def get_weight_col_minmax(self):
return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)
def get_weight_row_minmax(self):
return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
def calculate_qparams(self):
r"""
Returns the scale/zero_point for the scaled weight rows
"""
if self.equalization_scale.nelement() == 0:
warnings.warn(
"Must call calculate_scale before calling calculate_qparams.\
Returning default scale and zero point. "
)
return torch.tensor([1.0]), torch.tensor([0])
if self.min_weights_ind is None or self.max_weights_ind is None:
warnings.warn(
"Must find the column indicies of the minimum of each row in the \
weights in order to calculate the qparams calculate the \
qparams. Returning default scale and zero point. "
)
return torch.tensor([1.0]), torch.tensor([0]), torch.tensor([1.0]), torch.tensor([0])
# Calculate the qparams for weights by using the rows
# Scale the weight rows by the reciprocal of the equalization scale
# located at the same column index
(min_weights, max_weights) = self.get_weight_row_minmax()
min_weights_scaled = torch.mul(min_weights, torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
max_weights_scaled = torch.mul(max_weights, torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
(scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(min_weights_scaled, max_weights_scaled)
return scale_weight, zero_point_weight
def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
r""" Calculates the equalization scale and sets the equalization_scale value
in the observers.
Args:
input_obs: Observer that tracks the ranges for the input columns
weight_obs: Observer that tracks the ranges for the weight columns
"""
(min_inputs, max_inputs) = input_obs.get_input_minmax()
(min_weights, max_weights) = weight_obs.get_weight_col_minmax()
if min_inputs.shape != min_weights.shape:
raise ValueError(
"Input and Weight must have the same column dimension. " +
f"Found {min_inputs.shape} and {min_weights.shape} instead."
)
equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
return equalization_scale
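
For context (not part of this diff), a small numeric sketch of the identity the equalization scale is built around: scaling each input column by the equalization scale and each weight column by its reciprocal leaves the Linear layer's output unchanged, which is what allows the scaled min/max ranges above to be quantized more evenly.

import torch

x = torch.randn(8, 4)                    # (batch, in_features)
w = torch.randn(3, 4)                    # (out_features, in_features)
s = torch.rand(4) + 0.5                  # stand-in per-column equalization scale

y_ref = x @ w.t()
y_eq = (x * s) @ (w * (1.0 / s)).t()     # scaled inputs, inversely scaled weight columns
assert torch.allclose(y_ref, y_eq, atol=1e-5)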