Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: The planned end-to-end flow for quantization in PyTorch 2.0 export is the following:

float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...

Inside convert_pt2e, we will first produce a q/dq representation of the quantized model, similar to the previous output of convert_to_reference_fx in FX graph mode quantization:
```
torch.ops.quantized_decomposed.dequantize_per_tensor \
                                                       -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor /
```
Then we rewrite the above into a representation that expresses the intent more precisely: here we actually want to do int8 addition rather than simulate the int8 addition with fp32 operations. The representation for quantized add is:
```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```
Test Plan:
```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```
Reviewed By: kimishpatel

Differential Revision: D45628032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
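As a quick sanity check on the quantized_add formula in the summary above, here is a small hypothetical numeric comparison (not part of the PR) against the plain dequantize -> add -> quantize path. The helpers `dq` and `q`, the added round/clamp steps, and the example quantization parameters are made up for illustration.
```
# Hypothetical check (not from the PR): the rewritten quantized_add should
# agree with dequantize -> add -> quantize up to rounding order.
import torch

def dq(x_i8, scale, zero_point):
    # int8 -> float dequantization
    return (x_i8.to(torch.float32) - zero_point) * scale

def q(x_f, scale, zero_point):
    # float -> int8 quantization: scale, shift by zero point, round, clamp
    return torch.clamp(torch.round(x_f / scale) + zero_point, -128, 127).to(torch.int8)

def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
                  out_scale, out_zero_point):
    # Same arithmetic as quantized_add in the summary, plus round/clamp to int8.
    x = (x_scale / out_scale) * x_i8.to(torch.float32)
    y = (y_scale / out_scale) * y_i8.to(torch.float32)
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return torch.clamp(torch.round(out), -128, 127).to(torch.int8)

# Example tensors and quantization parameters (made up for illustration).
x_i8 = torch.randint(-128, 128, (8,), dtype=torch.int8)
y_i8 = torch.randint(-128, 128, (8,), dtype=torch.int8)
x_scale, x_zero_point = 0.05, 3
y_scale, y_zero_point = 0.08, -2
out_scale, out_zero_point = 0.1, 1

expected = q(dq(x_i8, x_scale, x_zero_point) + dq(y_i8, y_scale, y_zero_point),
             out_scale, out_zero_point)
actual = quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
                       out_scale, out_zero_point)
# Allow an off-by-one difference from floating-point rounding order.
assert (expected.to(torch.int32) - actual.to(torch.int32)).abs().max() <= 1
```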
94 lines · 3.1 KiB · Python
from torch.fx import GraphModule

from ._pt2e.prepare import prepare
from ._pt2e._propagate_annotation import propagate_annotation
from ._pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from ._pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_decomposed_linear,
)
from ._pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization._pt2e.quantizer import Quantizer
from torch.ao.quantization.backend_config import BackendConfig

from typing import Any, Tuple


def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model


# TODO: update this to prepare_pt2e after we have a usable quantizer
# implemented
def prepare_pt2e_quantizer(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    return model


# TODO: update this to prepare_qat_pt2e
def prepare_qat_pt2e_quantizer(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    return model
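For reference, a minimal usage sketch of the float_model -> prepare -> calibration -> convert flow described in the commit summary, using the quantizer-based entry points defined above. `MyQuantizer`, `exported_model`, and `calibration_inputs` are hypothetical placeholders, and the import path of this module is assumed.
```
# Hypothetical end-to-end sketch; names below are placeholders:
#   - exported_model: an aten-level GraphModule produced by a PT2 export/capture step
#   - MyQuantizer: some concrete Quantizer implementation (annotate/validate)
#   - calibration_inputs: an iterable of representative example inputs
from torch.ao.quantization._quantize_pt2e import (  # module path assumed
    prepare_pt2e_quantizer,
    convert_pt2e,
)

quantizer = MyQuantizer()

# Insert observers according to the quantizer's annotations.
prepared_model = prepare_pt2e_quantizer(exported_model, quantizer)

# Calibration: run representative inputs through the observed model.
for example_input in calibration_inputs:
    prepared_model(*example_input)

# Lower to the q/dq reference form; with use_reference_representation=True the
# graph is further rewritten into the integer representation shown in the summary.
quantized_model = convert_pt2e(prepared_model, use_reference_representation=True)
```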