from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node

__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]

# TODO: maybe remove torch.float32
SUPPORTED_DTYPES = [
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.float16,
    torch.float32,
]
SUPPORTED_QSCHEMES = [
    torch.per_tensor_affine,
    torch.per_tensor_symmetric,
    torch.per_channel_affine,
    torch.per_channel_symmetric,
    torch.per_channel_affine_float_qparams,
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for different types of quantization specs that allow users to
    specify how to quantize a Tensor (input/output of a Node) in the model.
    """

    pass


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows users to specify how to
    quantize a Tensor; this includes dtype, quant_min, quant_max, etc.
    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelHistogramObserver etc.
    # or we can attach some custom args to them
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # check dtype is one of the supported types
        if self.dtype not in SUPPORTED_DTYPES:
            raise TypeError(f"Unsupported dtype {self.dtype}.")

        # quant_min must be less than or equal to quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # check qscheme is one of the supported ones
        if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
            raise ValueError(f"Unsupported qscheme {self.qscheme}.")

        # ch_axis must be less than the number of channels, but there is no way
        # to check that here; just check that it is not negative
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError("Ch_axis is < 0.")
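
# A minimal illustrative sketch (not part of this module's API): constructing a
# QuantizationSpec for a uint8, per-tensor affine activation observed with
# MinMaxObserver from torch.ao.quantization.observer. The helper name
# `_example_act_qspec` and the eps value are hypothetical choices for
# illustration only.
def _example_act_qspec() -> QuantizationSpec:
    from torch.ao.quantization.observer import MinMaxObserver

    return QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        # custom observer args can be attached via with_args
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
    )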
@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for operators whose quantization parameters are fixed
    ahead of time, specified directly as scale and zero_point.
    """

    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None


"""
The way we refer to other points of quantization in the graph will be either
an input edge or an output value:
- an input edge is the connection between an input node and the node consuming
  the input, so it is a Tuple[Node, Node]
- an output value is an fx Node
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"


@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for the Tensors whose quantization parameters are shared
    with other Tensors.
    """

    edge_or_node: EdgeOrNode


@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for the Tensors whose quantization parameters are derived
    from other Tensors.
    """

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None


@dataclass
class QuantizationAnnotation:
    """How the input arguments or output of a node should be quantized, expressed
    as QuantizationSpec; this corresponds to how a Tensor in the operator Graph
    is observed (PTQ) or fake quantized (QAT).
    """

    # a map from torch.fx.Node to a type of QuantizationSpecBase
    input_qspec_map: Dict[Node, QuantizationSpecBase] = field(default_factory=dict)

    # How the output of this node is quantized, expressed as QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # whether the node is annotated or not
    _annotated: bool = False


class Quantizer(ABC):
    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
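
# A minimal sketch of a FixedQParamsQuantizationSpec for an op with a known
# output range, e.g. sigmoid, whose output lies in [0, 1]. The scale and
# zero_point follow the common uint8 convention for sigmoid; the helper name
# `_example_sigmoid_qspec` is hypothetical.
def _example_sigmoid_qspec() -> FixedQParamsQuantizationSpec:
    return FixedQParamsQuantizationSpec(
        dtype=torch.uint8,
        scale=1.0 / 256.0,
        zero_point=0,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )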
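
# A minimal sketch of a SharedQuantizationSpec: for a binary op such as `add`,
# the second input and the output can reuse the quantization parameters
# observed at the first input edge (input_node, add_node). The Node arguments
# are placeholders obtained from a captured graph; the helper name is
# hypothetical.
def _example_shared_qspec(input_node: Node, add_node: Node) -> SharedQuantizationSpec:
    # an input edge is a (producer, consumer) pair of nodes
    return SharedQuantizationSpec(edge_or_node=(input_node, add_node))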
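
# A minimal sketch of a DerivedQuantizationSpec that derives bias qparams from
# the observers of the input activation and the weight, following the common
# convention scale_bias = scale_input * scale_weight with a zero_point of 0.
# `derive_qparams_fn` receives the observers/fake-quants for the edges or nodes
# listed in `derived_from`, in order. All names here are illustrative.
def _example_bias_qspec(
    input_edge: EdgeOrNode, weight_edge: EdgeOrNode
) -> DerivedQuantizationSpec:
    def derive_qparams_fn(
        obs_or_fqs: List[ObserverOrFakeQuantize],
    ) -> Tuple[Tensor, Tensor]:
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        # derived scale and a zero_point of 0, both as Tensors
        return act_scale * weight_scale, torch.zeros_like(act_scale, dtype=torch.int32)

    return DerivedQuantizationSpec(
        derived_from=[input_edge, weight_edge],
        derive_qparams_fn=derive_qparams_fn,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )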
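
# A minimal sketch of a Quantizer subclass that annotates the output of every
# `aten.relu.default` call with the example activation spec above, attaching a
# QuantizationAnnotation under the "quantization_annotation" meta key used by
# the pt2e flow. The class name and the exact annotation pattern are
# illustrative, not normative.
class _ExampleQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.aten.relu.default
            ):
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    output_qspec=_example_act_qspec(),
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # no backend-specific constraints to check in this sketch
        pass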