Summary: The planned end-to-end flow for quantization in PyTorch 2.0 export is:

float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...

Inside convert_pt2e, we will first produce a q/dq representation of the quantized model, similar to the previous output of convert_to_reference_fx in FX graph mode quantization:

```
torch.ops.quantized_decomposed.dequantize_per_tensor -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor /
```

We then rewrite the above into a representation that expresses the intent more precisely: we actually want to do int8 addition here, rather than simulate it with fp32 operations. The representation for quantized add is:

```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```

Test Plan:

```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```

Reviewed By: kimishpatel

Differential Revision: D45628032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
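To make the q/dq form above concrete, here is a minimal sketch of calling the decomposed ops directly (this assumes the module below has been imported so the `quantized_decomposed` library is registered; the import path and all qparam values are made up for illustration):

```
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- assumed path of this file; registers the ops

qd = torch.ops.quantized_decomposed

# Pretend these qparams came from calibration.
x_scale, x_zp = 0.02, 0
y_scale, y_zp = 0.03, 0
out_scale, out_zp = 0.05, 0

x_i8 = qd.quantize_per_tensor(torch.randn(4), x_scale, x_zp, -128, 127, torch.int8)
y_i8 = qd.quantize_per_tensor(torch.randn(4), y_scale, y_zp, -128, 127, torch.int8)

# q/dq representation of a quantized add: dequantize both inputs, add in fp32, re-quantize.
x_fp = qd.dequantize_per_tensor(x_i8, x_scale, x_zp, -128, 127, torch.int8)
y_fp = qd.dequantize_per_tensor(y_i8, y_scale, y_zp, -128, 127, torch.int8)
out_i8 = qd.quantize_per_tensor(x_fp + y_fp, out_scale, out_zp, -128, 127, torch.int8)
```

This is the form convert_pt2e starts from; the `quantized_add` rewrite above folds the dequantize/add/quantize chain into a single expression over the int8 inputs.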
import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple


# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")

_DTYPE_TO_QVALUE_BOUNDS = {
    torch.uint8: (0, 255),
    torch.int8: (-128, 127),
    torch.int32: (-(2**31), 2**31 - 1)
}

# Helper to check the passed in quant min and max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype: {dtype}")
    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]

    assert quant_min >= quant_min_lower_bound, \
        "quant_min out of bound for dtype, " \
        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"

    assert quant_max <= quant_max_upper_bound, \
        "quant_max out of bound for dtype, " \
        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"

quantized_decomposed_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
        input: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Args:
       input (torch.Tensor): original float32 Tensor
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
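
# Worked example (an illustrative sketch added for this doc, not part of the library): with
# scale=0.01 and zero_point=0, quantize_per_tensor computes clamp(round(x / 0.01), -128, 127)
# and casts to the requested dtype, so out-of-range values saturate at quant_max.
def _example_quantize_per_tensor() -> None:
    x = torch.tensor([-1.0, 0.0, 0.5, 5.0])
    q = quantize_per_tensor(x, 0.01, 0, -128, 127, torch.int8)
    # -1.0 -> -100, 0.0 -> 0, 0.5 -> 50, 5.0 -> 500 which clamps to quant_max=127
    assert q.tolist() == [-100, 0, 50, 127]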

quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values
    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensors instead of
    scalar values
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)
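
# Illustrative sketch (added for this doc): the "Meta" registrations above never touch real
# data; they only propagate shape and dtype, which is what tracing with fake/meta tensors
# (e.g. during pt2e export) needs.
def _example_quantize_per_tensor_tensor_meta() -> None:
    x = torch.empty(2, 3, device="meta", dtype=torch.float32)
    scale = torch.empty(1, device="meta")
    zero_point = torch.empty(1, device="meta", dtype=torch.int64)
    out = quantize_per_tensor_tensor_meta(x, scale, zero_point, -128, 127, torch.int8)
    assert out.shape == (2, 3) and out.dtype == torch.int8 and out.device.type == "meta"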

# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
        input: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
       quantization parameters in the argument of this function (scale/zero_point)

       scale (float): quantization parameter for affine quantization

       zero_point (int): quantization parameter for affine quantization

       quant_min (int): minimum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       quant_max (int): maximum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       dtype (torch.dtype): dtype for input Tensor (not used in computation,
       reserved for pattern matching)

    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
        return (input.to(torch.float32) - zero_point) * scale
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
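
# Illustrative sketch (added for this doc): a quantize -> dequantize round trip. With
# zero_point=0 and no clamping, each element is reconstructed to within scale / 2.
def _example_quant_dequant_round_trip() -> None:
    x = torch.tensor([0.123, -0.456, 0.789])
    scale, zero_point = 0.01, 0
    q = quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
    dq = dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)
    assert torch.all((x - dq).abs() <= scale / 2)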

quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values
    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensors instead of
    scalar values
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got: {input.dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        return torch.empty_like(input, dtype=torch.float32)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)")

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
        input: torch.Tensor,
        qmin: int,
        qmax: int,
        eps: float,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Given an input Tensor, derive the per tensor affine quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       qmin (int): minimum quantized value for the target quantized Tensor
       qmax (int): maximum quantized value for the target quantized Tensor
       eps (float): smallest allowed value for the derived scale
       dtype (torch.dtype): dtype for the target quantized Tensor

    Returns:
       scale (torch.Tensor): quantization parameter for the target quantized Tensor
       zero_point (torch.Tensor): quantization parameter for the target quantized Tensor
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert dtype == torch.int8 or dtype == torch.uint8 or dtype == torch.int32, \
        f"Expecting target dtype to be int8, uint8 or int32, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)

    return determine_qparams(
        min_val, max_val, qmin, qmax, dtype, torch.Tensor([eps]), has_customized_qrange=False)
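
# Illustrative sketch (added for this doc): derive qparams from data, then quantize with them.
# Roughly, determine_qparams picks scale ~= (max(x.max(), 0) - min(x.min(), 0)) / (qmax - qmin)
# (floored at eps) and a zero_point that makes real 0.0 exactly representable, so the
# round-trip error stays on the order of one scale step.
def _example_choose_qparams() -> None:
    x = torch.randn(100)
    scale, zero_point = choose_qparams_tensor(x, -128, 127, torch.finfo(torch.float32).eps, torch.int8)
    q = quantize_per_tensor(x, scale.item(), int(zero_point.item()), -128, 127, torch.int8)
    dq = dequantize_per_tensor(q, scale.item(), int(zero_point.item()), -128, 127, torch.int8)
    assert torch.all((x - dq).abs() <= 2 * scale.item())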

quantized_decomposed_lib.define(
    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)")

@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "CompositeExplicitAutograd")
def choose_qparams_symmetric_tensor(
        input: torch.Tensor,
        qmin: int,
        qmax: int,
        eps: float,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Given an input Tensor, derive the per tensor symmetric quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       qmin (int): minimum quantized value for the target quantized Tensor
       qmax (int): maximum quantized value for the target quantized Tensor
       eps (float): smallest allowed value for the derived scale
       dtype (torch.dtype): dtype for the target quantized Tensor

    Returns:
       scale (torch.Tensor): quantization parameter for the target quantized Tensor
       zero_point (torch.Tensor): quantization parameter for the target quantized Tensor
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert dtype == torch.int8 or dtype == torch.uint8 or dtype == torch.int32, \
        f"Expecting target dtype to be int8, uint8 or int32, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)
    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([eps]),
        has_customized_qrange=False,
        qscheme=torch.per_tensor_symmetric
    )
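
# Illustrative sketch (added for this doc): symmetric qparams derive the scale from
# max(|min|, |max|) of the input, and for a signed int8 range the zero_point is expected to
# sit at 0, so real 0.0 always quantizes exactly to the zero_point.
def _example_choose_qparams_symmetric() -> None:
    x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
    scale, zero_point = choose_qparams_symmetric_tensor(
        x, -127, 127, torch.finfo(torch.float32).eps, torch.int8)
    q = quantize_per_tensor(x, scale.item(), int(zero_point.item()), -127, 127, torch.int8)
    assert q[2].item() == int(zero_point.item())  # x[2] is 0.0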

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
        input: torch.Tensor,
        quant_min: int,
        quant_max: int,
        eps: float,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: \
        {quant_min} max: {quant_max}"
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(1, dtype=torch.int64, device=input.device)

@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
def choose_qparams_symmetric_tensor_meta(
        input: torch.Tensor,
        quant_min: int,
        quant_max: int,
        eps: float,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(1, dtype=torch.int64, device=input.device)

# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
    new_axis_list = list(range(x.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = x.permute(tuple(new_axis_list))
    return y, new_axis_list
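
# Illustrative sketch (added for this doc): _permute_to_axis_zero swaps `axis` with dim 0.
# A swap is its own inverse, so permuting with the same list again restores the original
# layout; that is how the per-channel ops below undo the permute before returning.
def _example_permute_to_axis_zero() -> None:
    x = torch.randn(2, 3, 4)
    y, permute_axis_list = _permute_to_axis_zero(x, 2)
    assert y.shape == (4, 3, 2) and permute_axis_list == [2, 1, 0]
    assert torch.equal(y.permute(tuple(permute_axis_list)), x)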

quantized_decomposed_lib.define(
    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine per channel quantization for the Tensor using per-channel quantization
    parameters (one scale/zero_point per slice along `axis`) to map from floating point
    to quantized values

    Args:
       input (torch.Tensor): original float32 Tensor
       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel
       axis (int): the dimension of `input` that the per-channel parameters index into
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    res = torch.zeros_like(input)

    for i in range(input.size(0)):
        res[i] = torch.clamp(
            torch.round(input[i] * (1.0 / scales[i])) + zero_points[i],
            quant_min,
            quant_max
        )

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)
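
# Illustrative sketch (added for this doc): per-channel quantization along axis=0, where each
# row has its own scale, so the same real value maps to different int8 values per row.
def _example_quantize_per_channel() -> None:
    w = torch.tensor([[0.5, 1.0], [0.5, 1.0]])
    scales = torch.tensor([0.1, 0.01])
    zero_points = torch.tensor([0, 0])
    q = quantize_per_channel(w, scales, zero_points, 0, -128, 127, torch.int8)
    # row 0 divides by 0.1, row 1 divides by 0.01
    assert q.tolist() == [[5, 10], [50, 100]]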

@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)

# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine per channel dequantization for the Tensor using per-channel quantization
    parameters (one scale/zero_point per slice along `axis`) to map from quantized values
    to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
       quantization parameters in the argument of this function (scales/zero_points/axis)

       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel

       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel

       axis (int): the dimension of `input` that the per-channel parameters index into

       quant_min (int): minimum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)

       quant_max (int): maximum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)

       dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
       reserved for pattern matching)

    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    res = torch.zeros_like(input, dtype=torch.float32)

    for i in range(input.size(0)):
        # TODO: investigate why
        # (input[i] - zero_points[i]).to(torch.float32) * scales[i]
        # failed the test
        res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]

    out = res.permute(tuple(permute_axis_list))
    return out
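
# Illustrative sketch (added for this doc): a per-channel round trip. Each row is quantized
# and reconstructed with its own scale, so the per-row error is bounded by that row's
# scale / 2 (plus a little float slack), as long as nothing clamps.
def _example_per_channel_round_trip() -> None:
    w = torch.rand(3, 8) - 0.5  # values in [-0.5, 0.5), well inside every row's int8 range
    scales = torch.tensor([0.1, 0.02, 0.05])
    zero_points = torch.tensor([0, 0, 0])
    q = quantize_per_channel(w, scales, zero_points, 0, -128, 127, torch.int8)
    dq = dequantize_per_channel(q, scales, zero_points, 0, -128, 127, torch.int8)
    assert torch.all((w - dq).abs() <= scales.unsqueeze(1) / 2 + 1e-6)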

@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=torch.float32)