Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add support for prototype affine quantization in pt2e flow (#141421)
Summary:
Duplicates the affine quantization functionality, including the
observer (https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py)
and some quant_primitives ops (7c3c51fd0d/torchao/quantization/quant_primitives.py (L26-L30)),
to allow a per-group-quantization min/max observer in the pt2e flow.
Next: we can follow up and add a moving-average min/max observer.
Test Plan:
python test/test_quantization.py -k test_channel_group_quantization
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141421
Approved by: https://github.com/cccclai
Committed by: PyTorch MergeBot
Parent: 60a0d53c13
Commit: ace645a017
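As context for the diffs below, here is a minimal sketch (illustrative only, not part of the PR) of how an observer like this is exercised in the pt2e flow; `BackendAQuantizer` stands in for the quantizer defined in the test further down, and the model and shapes are assumptions:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Linear(128, 20)).eval()
example_inputs = (torch.randn(5, 128),)

exported = export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, BackendAQuantizer())  # inserts AffineQuantizedMinMaxObserver
prepared(*example_inputs)                               # calibration populates min_val/max_val
quantized = convert_pt2e(prepared)                      # lowers to quantize_affine/dequantize_affine ops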
@ -250,6 +250,18 @@ the values observed during calibration (PTQ) or training (QAT).

    default_per_channel_weight_observer
    default_dynamic_quant_observer
    default_float_qparams_observer
    AffineQuantizedObserverBase
    Granularity
    MappingType
    PerAxis
    PerBlock
    PerGroup
    PerRow
    PerTensor
    PerToken
    TorchAODType
    ZeroPointDomain
    get_block_size

torch.ao.quantization.fake_quantize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
mypy.ini (3 changed lines)
@ -79,6 +79,9 @@ ignore_missing_imports = True

[mypy-torch.ao.quantization.experimental.fake_quantize]
ignore_missing_imports = True

[mypy-torch.ao.quantization.pt2e._affine_quantization]
ignore_errors = True

#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.
@ -42,7 +42,6 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
)
from torch.export import export_for_training
from torch.fx import Node
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    PT2EQuantizationTestCase,
@ -1865,6 +1864,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                torch.ops.aten.batch_norm.default,
            )

    @parametrize(
        "device",
        ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []),
    )
    def test_move_exported_model_bn(self, device):
        """
        Test switching batch_norm behavior between train and eval modes using
@ -2477,9 +2480,90 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
            check_nn_module(node)


instantiate_parametrized_tests(TestQuantizePT2E)

@skipIfNoQNNPACK
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
    def test_channel_group_quantization(self):
        from torch.ao.quantization.observer import MappingType, PerGroup, PerToken
        from torch.ao.quantization.pt2e._affine_quantization import (
            AffineQuantizedMinMaxObserver,
        )

devices = ["cpu", "cuda"]
if TEST_HPU:
    devices.append("hpu")
instantiate_device_type_tests(TestQuantizePT2E, globals(), only_for=devices)
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.linear.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)

                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                # TODO: maybe align the arg name here
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerToken(),
                            ),
                        )

                        weight_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerGroup(group_size=128),
                            ),
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 20)

            def forward(self, x):
                return self.linear(x)

        node_occurrence = {
            torch.ops.quant.quantize_affine: 2,
            torch.ops.quant.dequantize_affine: 2,
        }
        node_list = [
            torch.ops.quant.quantize_affine,
            torch.ops.quant.dequantize_affine,
            torch.ops.quant.quantize_affine,
            torch.ops.quant.dequantize_affine,
        ]
        example_inputs = (torch.randn(5, 128),)
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
            is_debug_mode=True,
        )


instantiate_parametrized_tests(TestQuantizePT2E)
@ -87,6 +87,7 @@ try:
    from quantization.pt2e.test_metadata_porting import TestMetaDataPorting  # noqa: F401
    from quantization.pt2e.test_numeric_debugger import TestNumericDebugger  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EAffineQuantization  # noqa: F401
    from quantization.pt2e.test_representation import TestPT2ERepresentation  # noqa: F401
    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizer  # noqa: F401
    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizerModels  # noqa: F401
@ -168,6 +168,20 @@ __all__ = [
    "prepare_for_propagation_comparison",
    "extract_results_from_loggers",
    "compare_results",
    # from torchao, should be merged with torchao
    # in the future
    "AffineQuantizedObserverBase",
    "Granularity",
    "MappingType",
    "PerAxis",
    "PerBlock",
    "PerGroup",
    "PerRow",
    "PerTensor",
    "PerToken",
    "TorchAODType",
    "ZeroPointDomain",
    "get_block_size",
]
@ -1,5 +1,8 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# temporarily skip RUF for this file for now, we can re-enable
# after we move the affine quantization related things to torchao
# noqa: RUF
"""
This module implements observers which are used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
@ -54,6 +57,18 @@ __all__ = [
    "RecordingObserver",
    "ReuseInputObserver",
    "UniformQuantizationObserverBase",
    "AffineQuantizedObserverBase",
    "Granularity",
    "MappingType",
    "PerAxis",
    "PerBlock",
    "PerGroup",
    "PerRow",
    "PerTensor",
    "PerToken",
    "TorchAODType",
    "ZeroPointDomain",
    "get_block_size",
]
@ -1584,6 +1599,258 @@ class ReuseInputObserver(ObserverBase):
        )


"""
# Experimental Affine Quantization Feature START
We plan to merge the following with torchao repo after we move pt2e flow to torchao
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
"""
from dataclasses import dataclass
from enum import auto, Enum


class MappingType(Enum):
    """How floating point number is mapped to integer number

    symmetric mapping means floating point range is symmetrically mapped to integer range
    let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))

    SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin
    and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating
    smin and smax individually, there can be less round error on negative values, and no out-of-range
    of all floating point values.

    asymmetric mapping means we just directly map the floating point range to integer range,
    for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
    based on this mapping
    e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
    """

    SYMMETRIC = auto()
    SYMMETRIC_NO_CLIPPING_ERR = auto()
    ASYMMETRIC = auto()
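As a quick numeric check of the mapping described in the docstring above (illustrative only, not part of the diff), with float range (-3.5, 10.2) and the int4 range (-8, 7):

min_val, max_val = -3.5, 10.2
quant_min, quant_max = -8, 7

# SYMMETRIC: widen to (-10.2, 10.2) so the range is symmetric around zero
scale_symmetric = (10.2 - (-10.2)) / (quant_max - quant_min)      # 20.4 / 15 = 1.36

# ASYMMETRIC: map (-3.5, 10.2) directly onto (-8, 7)
scale_asymmetric = (max_val - min_val) / (quant_max - quant_min)  # 13.7 / 15 ~= 0.913
zero_point = quant_min - round(min_val / scale_asymmetric)        # -8 - (-4) = -4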
class ZeroPointDomain(Enum):
    """Enum that indicate whether zero_point is in integer domain or floating point domain

    integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
    float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
    none domain: quantized_val = (float_val / scale)
    """

    INT = auto()
    FLOAT = auto()
    NONE = auto()


class TorchAODType(Enum):
    """
    Placeholder for dtypes that do not exist in PyTorch core yet.
    """

    # torch.int1 to torch.int7 will be added to PyTorch 2.6
    # These will remain here for BC with older PyTorch versions
    INT1 = auto()
    INT2 = auto()
    INT3 = auto()
    INT4 = auto()
    INT5 = auto()
    INT6 = auto()
    INT7 = auto()
@dataclass(frozen=True)
class Granularity:
    """
    Base class for representing the granularity of quantization.

    This class serves as a parent for specific granularity types used in
    quantization operations, such as per-tensor or per-axis quantization.
    """


@dataclass(frozen=True)
class PerBlock(Granularity):
    """
    Represents per-block granularity in quantization. See
    :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
    `block_size`

    Attributes:
        block_size (Tuple[int, ...]): The size of each quantization group
    """

    block_size: Tuple[int, ...]


@dataclass(frozen=True)
class PerTensor(Granularity):
    """
    Represents per-tensor granularity in quantization.

    This granularity type calculates the quantization parameters
    based off the entire tensor.

    """


@dataclass(frozen=True)
class PerAxis(Granularity):
    """
    Represents per-axis granularity in quantization.

    This granularity type calculates different quantization parameters
    along a specified axis of the tensor.

    For example if the input tensor is shape [8, 16] and axis=0, then
    the quantization parameters are calculated for each row of the tensor,
    giving a total of 8 quantization parameters.

    Attributes:
        axis (int): The axis along which reduction is performed.
    """

    axis: int


@dataclass(frozen=True)
class PerGroup(Granularity):
    """
    Represents per-channel group granularity in quantization.

    This granularity type calculates different quantization parameters
    for each group of <group_size> elements.

    For example if the input tensor is shape [8, 16], and the group size is 4, then
    the input tensor is reshaped to [32, 4] and quantization parameters are
    calculated for each group of 4 elements, giving a total of 32 quantization parameters.

    Attributes:
        group_size (int): The size of each quantization group

    """

    group_size: int


class PerRow(Granularity):
    """
    Represents row-wise granularity in quantization.

    This is a special case of per-axis quantization and is unique to Float8 matmuls
    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
    is quantized with a block_size of (1, weight.shape[1]).
    """


class PerToken(Granularity):
    """
    Represents per-token granularity in quantization.

    This granularity type calculates a different set of quantization parameters
    for each token, which is represented as the last dimension of the tensor.

    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
    with 4 elements each, and we will calculate 6 sets of quantization parameters,
    one for each token.

    If the input tensor has only two dimensions, e.g. [8, 16], then this is
    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
    """
def get_block_size(
    input_shape: Tuple[int, ...], granularity: Granularity
) -> Tuple[int, ...]:
    """Get the block size based on the input shape and granularity type.

    Args:
        input_shape: The input tensor shape possibly more than 2 dimensions
        granularity: The granularity type of the quantization
    """
    assert isinstance(
        granularity, Granularity
    ), "Please provide an instance of Granularity, not subclass of it"
    if isinstance(granularity, PerTensor):
        return input_shape
    elif isinstance(granularity, PerAxis):
        block_size = list(input_shape)
        block_size[granularity.axis] = 1
        return tuple(block_size)
    elif isinstance(granularity, PerRow):
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    elif isinstance(granularity, PerGroup):
        assert (
            len(input_shape) == 2
        ), f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
        return (1, granularity.group_size)
    elif isinstance(granularity, PerToken):
        block_size = list(input_shape)
        block_size[-1] = input_shape[-1]
        return tuple(block_size)
    raise ValueError(f"Unsupported Granularity: {granularity}")
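To make the mapping from granularity to block_size concrete, a small illustrative snippet (assumed shapes, not part of the diff):

weight_shape = (20, 128)

get_block_size(weight_shape, PerTensor())               # (20, 128): one qparam for the whole tensor
get_block_size(weight_shape, PerAxis(axis=0))           # (1, 128): one qparam per row / output channel
get_block_size(weight_shape, PerGroup(group_size=32))   # (1, 32): one qparam per group of 32 elements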
class AffineQuantizedObserverBase(ABC, torch.nn.Module):
    """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

    Args:
        `granularity` and `block_size`: The granularity of the quantization,
        must specify at least one, if both are specified `block_size` takes precedence
        Currently supported granularity types are `PerTensor` and `PerAxis`
        other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
    """

    with_args = classmethod(_with_args)

    def __init__(
        self,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        granularity: Granularity,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        # there could be some extra args that are ignored
        **kwargs,
    ):
        super().__init__()
        assert granularity is not None, "granularity is None"

        self.mapping_type = mapping_type
        self.target_dtype = target_dtype
        self.granularity = granularity
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.eps = eps
        self.scale_dtype = scale_dtype
        self.zero_point_dtype = zero_point_dtype
        self.preserve_zero = preserve_zero
        self.zero_point_domain = zero_point_domain
        # populated during forward
        self.block_size = None
        self.original_dtype = None

    @abstractmethod
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """forward function should take the input tensor,
        update internal stats and return the original input Tensor
        """

    @abstractmethod
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate quantization parameters based on the stats attached to the observer module
        and return a tuple of scale and zero_point Tensor
        """
def _is_observer_script_module(mod, obs_type_name):
    """Returns true if given mod is an instance of Observer script module."""
    if isinstance(mod, torch.jit.RecursiveScriptModule):
@ -1594,10 +1861,17 @@ def _is_observer_script_module(mod, obs_type_name):
    return False


# Experimental Affine Quantization Feature END


def _is_activation_post_process(module):
    return isinstance(
        module,
        (torch.ao.quantization.ObserverBase, torch.ao.quantization.FakeQuantizeBase),
        (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
            AffineQuantizedObserverBase,
        ),
    ) or _is_observer_script_module(module, "quantization.observer")
torch/ao/quantization/pt2e/_affine_quantization.py (new file, 775 lines)
@ -0,0 +1,775 @@
# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC
import logging
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.ao.quantization.observer import (
    AffineQuantizedObserverBase,
    get_block_size,
    MappingType,
    TorchAODType,
    ZeroPointDomain,
)
from torch.fx import Node


ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:

logger = logging.getLogger(__name__)
FP8_TYPES = {
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
}
_SUB_BYTE_UINT_BOUNDS = {
    torch.uint1: (0, 2**1 - 1),
    torch.uint2: (0, 2**2 - 1),
    torch.uint3: (0, 2**3 - 1),
    torch.uint4: (0, 2**4 - 1),
    torch.uint5: (0, 2**5 - 1),
    torch.uint6: (0, 2**6 - 1),
    torch.uint7: (0, 2**7 - 1),
}

"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo
"""
_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
    torch.uint8: (0, 255),
    torch.int8: (-128, 127),
    torch.int16: (-(2**15), 2**15 - 1),
    torch.int32: (-(2**31), 2**31 - 1),
}
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)


def _is_float8_type(dtype: torch.dtype) -> bool:
    fp8_types = {
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
    }
    return dtype in fp8_types


# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
    """Get quant_min and quant_max args based on dtype and also
    verify that they are within the range of possible quant_min/quant_max
    for dtype
    """
    if dtype in FP8_TYPES:
        quant_min_lower_bound, quant_max_upper_bound = (
            torch.finfo(dtype).min,
            torch.finfo(dtype).max,
        )
    elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype: {dtype}")
    else:
        quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
    if quant_min is None:
        quant_min = quant_min_lower_bound
    if quant_max is None:
        quant_max = quant_max_upper_bound

    assert quant_min >= quant_min_lower_bound, (
        "quant_min out of bound for dtype, "
        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
    )

    assert quant_max <= quant_max_upper_bound, (
        "quant_max out of bound for dtype, "
        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
    )
    return quant_min, quant_max


def _get_reduction_params(block_size, input_size):
    """Given block_size and input size find the parameters for reduction:

    Output:
        shape_for_reduction: the shape we use to `view` input to prepare it for reduction
        reduction_dims: the dims we'll do reduction over

    Example::
        Input:
          block_size: (3, 3, 2, 10)
          input_size: (3, 3, 10, 10)

        Output:
          shape_for_reduction: (3, 3, 5, 2, 10)
          reduction_dim: [0, 1, 3, 4]
    """
    assert len(block_size) == len(input_size)
    shape_for_reduction = []
    reduction_dims = []
    cur_dim = 0
    for i in range(len(block_size)):
        if block_size[i] != input_size[i] and block_size[i] > 1:
            assert input_size[i] % block_size[i] == 0, (
                f"Expecting input size at {i} dimension: "
                f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
            )
            shape_for_reduction.append(input_size[i] // block_size[i])
            shape_for_reduction.append(block_size[i])
            # reduce over the block_size[i] dim
            reduction_dims.append(cur_dim + 1)
            cur_dim += 2
        else:
            # block_size[i] == input_size[i] or block_size[i] == 1
            shape_for_reduction.append(input_size[i])
            # we only need to reduce over the dimension if block_size is greater than 1
            # otherwise it's already the same as reduced dimension
            if block_size[i] != 1:
                reduction_dims.append(cur_dim)
            cur_dim += 1
    return shape_for_reduction, reduction_dims
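The docstring example above can be checked directly; a quick illustrative call (not part of the diff):

shape, dims = _get_reduction_params(block_size=(3, 3, 2, 10), input_size=(3, 3, 10, 10))
# shape == [3, 3, 5, 2, 10]; dims == [0, 1, 3, 4]
# i.e. dimension 2 is split into 5 blocks of 2, and min/max are reduced over dims 0, 1, 3 and 4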
def _register_custom_op(lib):
    """This decorator is used to preserve some high level operators for torch.export.export
    while still allowing them to be decomposed for the inductor path

    requirement: make sure `fn.__name__[1:]` is the operator name you want to register

    NOTE: This should be applied at the top, after all other decorators have been applied
    NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input,
    e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make
    sense for downstream system (like executorch) to accept as well

    Example:
        lib = torch.library.Library("my_namespace", "FRAGMENT")

        register_custom_op = _register_custom_op(lib)

        @register_custom_op
        def _the_op_that_needs_to_be_preserved(...):
            ...

        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
        # torch.export.export / torch._export.export_for_training

    """
    from torch._inductor.decomposition import register_decomposition

    def decorator(fn):
        from torch._library.infer_schema import infer_schema

        # expecting fn.__name__ to start with `_`; we take the rest
        # to be the name of the custom op
        assert (
            fn.__name__[0] == "_"
        ), f"Expecting function name starts with `_`, got {fn.__name__}"
        assert not any(
            c in fn.__name__ for c in ".<>"
        ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
        op_name = fn.__name__[1:]
        schema = op_name + infer_schema(fn, mutates_args={})
        lib.define(schema)
        lib.impl(op_name, fn, "CompositeImplicitAutograd")

        lib_namespace = lib.ns
        op = getattr(getattr(torch.ops, lib_namespace), op_name)
        register_decomposition([op])(fn)
        return op

    return decorator


quant_lib = torch.library.Library("quant", "FRAGMENT")  # noqa: TOR901

register_custom_op = _register_custom_op(quant_lib)


def choose_qparams_affine_with_min_max(
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    mapping_type: MappingType,
    block_size: Tuple[int, ...],
    target_dtype: torch.dtype,
    quant_min: Optional[int] = None,
    quant_max: Optional[int] = None,
    eps: Optional[float] = None,
    scale_dtype: Optional[torch.dtype] = None,
    zero_point_dtype: Optional[torch.dtype] = None,
    preserve_zero: bool = True,
    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
    operator that passes in min_val and max_val directly instead of deriving these from a single input.
    This is used for observers in static quantization where min_val and max_val may be obtained through
    tracking all the data in the calibration data set.

    Args:
        Mostly the same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`, with one
        difference: instead of passing in an `input` Tensor and using that to calculate min_val/max_val
        and then scale/zero_point, we pass in min_val/max_val directly
    """
    return _choose_qparams_affine(
        None,
        mapping_type.name,
        block_size,
        target_dtype,
        quant_min,
        quant_max,
        eps,
        scale_dtype,
        zero_point_dtype,
        preserve_zero,
        zero_point_domain.name if zero_point_domain is not None else None,
        min_val,
        max_val,
    )
@register_custom_op
def _choose_qparams_affine(
    input: Optional[torch.Tensor],
    mapping_type: str,
    block_size: List[int],
    target_dtype: torch.dtype,
    quant_min: Optional[Union[int, float, bool]] = None,
    quant_max: Optional[Union[int, float, bool]] = None,
    eps: Optional[float] = None,
    scale_dtype: Optional[torch.dtype] = None,
    zero_point_dtype: Optional[torch.dtype] = None,
    preserve_zero: bool = True,
    zero_point_domain: Optional[str] = "INT",
    min_val: Optional[torch.Tensor] = None,
    max_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """op definition that has compatible signatures with custom op library

    The op does the following:
    1. figure out the dimension for reduction based on block_size
    2. find min_val/max_val based on the dimension for reduction
    3. calculate quantization parameters based on min_val/max_val and args like `preserve_zero`
       and `zero_point_domain`
    """
    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
    assert mapping_type in [
        MappingType.SYMMETRIC.name,
        MappingType.SYMMETRIC_NO_CLIPPING_ERR.name,
        MappingType.ASYMMETRIC.name,
    ], f"Unsupported mapping type: {mapping_type}"
    if target_dtype in FP8_TYPES:
        assert (
            mapping_type == MappingType.SYMMETRIC.name
        ), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"

    if input is not None:
        if scale_dtype is None:
            scale_dtype = input.dtype
        if zero_point_dtype is None:
            zero_point_dtype = input.dtype
        if eps is None:
            eps = torch.finfo(input.dtype).eps

        assert (
            len(block_size) == input.dim()
        ), f"Got input dim:{input.dim()}, block_size: {block_size}"
        shape_for_reduction, reduction_dims = _get_reduction_params(
            block_size, input.size()
        )
        input = input.view(shape_for_reduction)

        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
    else:
        assert (
            min_val is not None and max_val is not None
        ), "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
        assert (
            min_val.dtype == max_val.dtype
        ), "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"

        if scale_dtype is None:
            scale_dtype = min_val.dtype
        if zero_point_dtype is None:
            zero_point_dtype = min_val.dtype
        if eps is None:
            eps = torch.finfo(min_val.dtype).eps

    if preserve_zero:
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    else:
        min_val_neg = min_val
        max_val_pos = max_val

    if (
        mapping_type == MappingType.SYMMETRIC.name
        or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
    ):
        # scales
        if mapping_type == MappingType.SYMMETRIC.name:
            max_val_pos = torch.max(-min_val_neg, max_val_pos)
            scale = max_val_pos / (float(quant_max - quant_min) / 2)
        else:
            assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
            # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
            # quant_max = 7.
            # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding
            #   error than the existing SYMMETRIC case.
            # - If smax is bigger: it covers the positive values up to 7. The round
            #   error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after
            #   quantization.
            smin = min_val_neg / float(quant_min)
            smax = max_val_pos / float(quant_max)
            mask = smin > smax
            scale = torch.where(mask, smin, smax)
        # zeros
        if not preserve_zero:
            raise ValueError(
                "preserve_zero == False is not supported for symmetric quantization"
            )
        if (
            zero_point_domain is not None
            and zero_point_domain != ZeroPointDomain.INT.name
        ):
            raise ValueError(
                "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
            )
        scale = torch.clamp(scale, min=eps)
        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
    else:
        assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.clamp(scale, min=eps)
        if zero_point_domain == ZeroPointDomain.NONE.name:
            zero_point = None
        else:
            if preserve_zero:
                zero_point = quant_min - torch.round(min_val_neg / scale)
                zero_point = torch.clamp(zero_point, quant_min, quant_max)
            else:
                assert (
                    zero_point_domain == ZeroPointDomain.FLOAT.name
                ), "if not preserve_zero, zero_point must be in FLOAT domain"
                mid_point = (quant_max + quant_min + 1) / 2
                zero_point = min_val_neg + scale * mid_point

    if zero_point is not None:
        zero_point = zero_point.to(dtype=zero_point_dtype)
    return scale.to(dtype=scale_dtype), zero_point
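A small illustrative call (values assumed, not part of the diff), computing asymmetric uint8 qparams from pre-computed min/max the way the observer below does:

min_val = torch.tensor([-1.0, 0.0])
max_val = torch.tensor([3.0, 2.0])
scale, zero_point = choose_qparams_affine_with_min_max(
    min_val, max_val, MappingType.ASYMMETRIC, [], torch.uint8
)
# scale is roughly [4/255, 2/255] and zero_point roughly [64, 0]
# (returned as float32 here, since zero_point_dtype defaults to the min/max dtype)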
@torch.no_grad()
def quantize_affine(
    input: torch.Tensor,
    block_size: Tuple[int, ...],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    output_dtype: torch.dtype,
    quant_min: Optional[Union[int, float]] = None,
    quant_max: Optional[Union[int, float]] = None,
    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
) -> torch.Tensor:
    """
    Args:
      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
      block_size: (Tuple[int, ...]): granularity of quantization,
        this means the size of the tensor elements that's sharing the same qparam
        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
      scale (float): quantization parameter for affine quantization
      zero_point (int): quantization parameter for affine quantization
      output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
      quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
        if zero_point is in integer domain, zero point is added to the quantized integer value during
        quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
        value during quantization
        default is ZeroPointDomain.INT

    Note:
      How can block_size represent different granularities?
      let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
      granularities:

        granularity type                      | block_size
        per_tensor                            | (3, 3, 10, 10)
        per_axis (axis=0)                     | (1, 3, 10, 10)
        per_axis (axis=1)                     | (3, 1, 10, 10)
        per_group (groupsize=2)               | (3, 3, 10, 2)
        per_group (groupsize=2) for axis = 3  | (3, 3, 2, 10)

    Output:
      quantized tensor with requested dtype
    """
    return _quantize_affine(
        input,
        block_size,
        scale,
        zero_point,
        output_dtype,
        quant_min,
        quant_max,
        zero_point_domain.name if zero_point_domain is not None else None,
    )


@register_custom_op
def _quantize_affine(
    input: torch.Tensor,
    block_size: List[int],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    output_dtype: torch.dtype,
    quant_min: Optional[Union[int, float, bool]] = None,
    quant_max: Optional[Union[int, float, bool]] = None,
    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
) -> torch.Tensor:
    """op definition that has compatible signatures with custom op library

    Note:
      zero_point_domain optionally specifies how we quantize the floating point to quantized data:
        INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
        FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
        None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization
          where we do not want to round values to nearest integer and instead scale and cast.
    """
    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
    # workaround for uintx dtypes, since we don't have native Uintx dtype connected with
    # torch.uintx dtypes yet
    if output_dtype in _SUB_BYTE_UINT_BOUNDS:
        output_dtype = torch.uint8
    return _quantize_affine_no_dtype_cast(
        input,
        block_size,
        scale,
        zero_point,
        quant_min,
        quant_max,
        zero_point_domain,
    ).to(output_dtype)


def _quantize_affine_no_dtype_cast(
    input: torch.Tensor,
    block_size: List[int],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    quant_min: Union[int, float],
    quant_max: Union[int, float],
    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
) -> torch.Tensor:
    """
    The op does the following:
    1. figure out the dimension for reduction based on block_size, also reshape the input to align with
       the shape after reduction
    2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
    3. reshape the quantized result to original shape
    """
    # TODO: validations
    # TODO: validate scale/zero_point dimensions are compatible with block_size
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Unsupported input dtype: {input.dtype}"
    assert (
        len(block_size) == input.dim()
    ), f"Got input dim:{input.dim()}, block_size: {block_size}"
    shape_for_reduction, reduction_dims = _get_reduction_params(
        block_size, input.size()
    )
    original_shape = input.shape
    input = input.view(shape_for_reduction)
    shape_after_reduction = shape_for_reduction
    for i in reduction_dims:
        shape_after_reduction[i] = 1
    scale = scale.view(shape_after_reduction)
    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

    if zero_point_domain == ZeroPointDomain.INT.name:
        quant = torch.clamp(
            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
        )
    elif zero_point_domain == ZeroPointDomain.NONE.name:
        assert (
            zero_point is None
        ), "zero_point should be None when zero_point_domain is NONE"
        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
    elif zero_point_domain is None:
        # This case handles quantization for float8; we expect no zero point and no zero point domain
        assert (
            zero_point is None
        ), "zero_point should be None when zero_point_domain is None"
        quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
    else:
        assert zero_point_domain == ZeroPointDomain.FLOAT.name
        mid_point = (quant_max + quant_min + 1) / 2
        min_val = zero_point - scale * mid_point
        quant = torch.clamp(
            torch.round((input - min_val) / scale), quant_min, quant_max
        )
    quant = quant.view(original_shape)

    return quant


def dequantize_affine(
    input: torch.Tensor,
    block_size: Tuple[int, ...],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    input_dtype: torch.dtype,
    quant_min: Optional[Union[int, float]] = None,
    quant_max: Optional[Union[int, float]] = None,
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
    *,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Args:
      input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
      block_size: (List[int]): granularity of quantization,
        this means the size of the tensor elements that's sharing the same qparam
        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
      scale (Tensor): quantization parameter for affine quantization
      zero_point (Tensor): quantization parameter for affine quantization
      input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
      quant_min (Optional[int]): minimum quantized value for input Tensor
      quant_max (Optional[int]): maximum quantized value for input Tensor
      output_dtype (torch.dtype): dtype for output Tensor, default is fp32
      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
        if zero_point is in integer domain, zero point is added to the quantized integer value during
        quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
        value during quantization
        default is ZeroPointDomain.INT

    Output:
      dequantized Tensor, with requested dtype or fp32
    """
    return _dequantize_affine(
        input,
        block_size,
        scale,
        zero_point,
        input_dtype,
        quant_min,
        quant_max,
        zero_point_domain.name if zero_point_domain is not None else None,
        output_dtype=output_dtype,
    )


@register_custom_op
def _dequantize_affine(
    input: torch.Tensor,
    block_size: List[int],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    input_dtype: torch.dtype,
    quant_min: Optional[Union[int, float, bool]] = None,
    quant_max: Optional[Union[int, float, bool]] = None,
    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """op definition that has compatible signatures with custom op library"""
    # TODO: validate scale/zero_point dimensions are compatible with block_size
    if input_dtype not in _SUB_BYTE_UINT_BOUNDS:
        assert (
            input.dtype == input_dtype
        ), f"Expected: {input_dtype}, got: {input.dtype}"
    assert output_dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Unsupported output dtype: {output_dtype}"
    quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
    return _dequantize_affine_no_dtype_check(
        input,
        block_size,
        scale,
        zero_point,
        quant_min,
        quant_max,
        zero_point_domain,
        output_dtype,
    )


def _dequantize_affine_no_dtype_check(
    input: torch.Tensor,
    block_size: List[int],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    quant_min: Union[int, float],
    quant_max: Union[int, float],
    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """This function converts AQT tensors to their high precision floating point representation

    The op does the following:
    1. figure out the dimension for reduction based on block_size, also reshape the input to align with
       the shape after reduction
    2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
    3. reshape the quantized result to original shape and change dtype to the output_dtype
    """
    assert (
        len(block_size) == input.dim()
    ), f"Got input dim:{input.dim()}, block_size: {block_size}"
    shape_for_reduction, reduction_dims = _get_reduction_params(
        block_size, input.size()
    )
    original_shape = input.shape
    input = input.view(shape_for_reduction)
    shape_after_reduction = shape_for_reduction
    for i in reduction_dims:
        shape_after_reduction[i] = 1
    scale = scale.view(shape_after_reduction)

    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

    if zero_point_domain == ZeroPointDomain.INT.name:
        # Force a copy to avoid input modification due
        # to upcoming in-place operations.
        dequant = input.to(torch.int32, copy=True)
        if zero_point is not None:
            dequant = dequant - zero_point.to(torch.int32)
        dequant = dequant.to(output_dtype)
        dequant = dequant * scale
    elif zero_point_domain == ZeroPointDomain.NONE.name:
        assert (
            zero_point is None
        ), "zero_point should be None when zero_point_domain is NONE"
        dequant = input.to(output_dtype)
        dequant = dequant * scale
    elif zero_point_domain is None:
        # This case handles dequantization for float8; we expect no zero point and no zero point domain
        assert (
            zero_point is None
        ), "zero_point should be None when zero_point_domain is None"
        assert _is_float8_type(
            input.dtype
        ), f"dequantization with no zero point domain is only supported with FP8 types, got {input.dtype}"
        dequant = input.to(output_dtype)
        dequant = dequant * scale
    else:
        assert (
            zero_point_domain == ZeroPointDomain.FLOAT.name
        ), f"Unexpected zero point domain: {zero_point_domain}"
        # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
        mid_point = (quant_max + quant_min + 1) / 2
        # This should allocate new memory and avoid input modification
        dequant = input - mid_point
        dequant = dequant.to(output_dtype)
        dequant *= scale
        if zero_point is not None:
            dequant += zero_point

    return dequant.view(original_shape).to(output_dtype)
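To make the block_size convention concrete, a minimal round trip through these ops (shapes and values are assumptions, not part of the diff); with block_size (1, 4) each group of 4 elements along the last dim gets its own scale/zero_point:

x = torch.randn(2, 8)
block_size = (1, 4)                # per-group quantization, group_size=4 along the last dim
x_grouped = x.reshape(2, 2, 4)     # 4 groups of 4 elements
min_val = x_grouped.amin(dim=-1).reshape(-1)
max_val = x_grouped.amax(dim=-1).reshape(-1)
scale, zero_point = choose_qparams_affine_with_min_max(
    min_val, max_val, MappingType.ASYMMETRIC, list(block_size), torch.uint8
)
q = quantize_affine(x, block_size, scale, zero_point, torch.uint8)
xq = dequantize_affine(q, block_size, scale, zero_point, torch.uint8)
# q is a torch.uint8 tensor of shape (2, 8); xq is a float32 approximation of x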
class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
    def forward(self, input: torch.Tensor):
        if input.numel() == 0:
            return input

        input_detached = input.detach()
        self.original_dtype = input_detached.dtype
        assert self.granularity is not None, "granularity is None"
        self.block_size = get_block_size(input_detached.shape, self.granularity)

        shape_for_reduction, reduction_dims = _get_reduction_params(
            self.block_size, input_detached.size()
        )
        input_detached = input_detached.view(shape_for_reduction)
        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
            self.min_val = min_val
            self.max_val = max_val
        else:
            assert (
                self.min_val.shape == min_val.shape
            ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
            assert (
                self.max_val.shape == max_val.shape
            ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
            min_val = torch.min(self.min_val, min_val)
            max_val = torch.max(self.max_val, max_val)
            self.min_val.copy_(min_val)
            self.max_val.copy_(max_val)
        # returning original input
        return input

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert hasattr(self, "min_val") and hasattr(
            self, "max_val"
        ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            [],  # BlockSize is not needed because the min/max are already reduced
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain,
        )

    def convert(self, model: torch.fx.GraphModule, observer_node: Node):
        print("calling convert")
        from torch.ao.quantization.fx.utils import create_getattr_from_value

        scale, zero_point = self.calculate_qparams()
        with model.graph.inserting_before(observer_node):
            assert self.block_size is not None, "Expecting block_size to be populated"
            assert (
                self.original_dtype is not None
            ), "Expecting original_dtype to be populated"
            scale_node = create_getattr_from_value(model, model.graph, "_scale", scale)
            zero_point_node = create_getattr_from_value(
                model, model.graph, "_zero_point", zero_point
            )
            q_node = model.graph.call_function(
                torch.ops.quant.quantize_affine,
                (
                    observer_node.args[0],
                    self.block_size,
                    scale_node,
                    zero_point_node,
                    self.target_dtype,
                    self.quant_min,
                    self.quant_max,
                    self.zero_point_domain.name,
                ),
                {},
            )
            dq_node = model.graph.call_function(
                torch.ops.quant.dequantize_affine,
                (
                    q_node,
                    self.block_size,
                    scale_node,
                    zero_point_node,
                    self.target_dtype,
                    self.quant_min,
                    self.quant_max,
                    self.zero_point_domain.name,
                ),
                {"output_dtype": self.original_dtype},
            )
            observer_node.replace_all_uses_with(dq_node)
            model.graph.erase_node(observer_node)
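A standalone illustration of the observer outside the pt2e flow (shapes assumed, not part of the diff): per-group observation of a weight followed by qparam calculation. `PerGroup` and `MappingType` are assumed to be imported from torch.ao.quantization.observer as in the test above.

obs = AffineQuantizedMinMaxObserver(
    mapping_type=MappingType.SYMMETRIC,
    target_dtype=torch.uint8,
    granularity=PerGroup(group_size=128),
)
w = torch.randn(20, 128)
obs(w)                                    # records per-group min/max; block_size becomes (1, 128)
scale, zero_point = obs.calculate_qparams()
# scale and zero_point each have one entry per group (20 groups here)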
@ -1305,6 +1305,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
            m = prepare_qat_pt2e(m, quantizer)
        else:
            m = prepare_pt2e(m, quantizer)
        if is_debug_mode:
            print("prepared model:", m)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)