Files
pytorch/torch/ao/quantization/_quantize_pt2e.py
Jerry Zhang c98896b76f [quant][pt2e] Add more precise representation for quantized add (#104130)
Summary:
The planned end-to-end (e2e) flow for quantization in PyTorch 2.0 export is the following:

float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...
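A minimal usage sketch of that flow, assuming the model has already been captured to a GraphModule and using the quantizer-based entry point (exported_model, quantizer, and calibration_data are placeholders, not from this PR):

```
from torch.ao.quantization._quantize_pt2e import (
    prepare_pt2e_quantizer,
    convert_pt2e,
)

# exported_model: GraphModule captured from the float model (placeholder)
# quantizer: a backend-provided Quantizer instance (placeholder)
prepared = prepare_pt2e_quantizer(exported_model, quantizer)

# calibration: run representative inputs through the prepared model
for data in calibration_data:  # placeholder iterable of example inputs
    prepared(*data)

# convert; use_reference_representation=True additionally rewrites the
# q/dq patterns into the more precise representation described below
quantized = convert_pt2e(prepared, use_reference_representation=True)
```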

Inside convert_pt2e, we will first produce a q/dq representation of the quantized model, similar to the previous output of
convert_to_reference_fx in FX graph mode quantization:

```
torch.ops.quantized_decomposed.dequantize_per_tensor -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor   /
```
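For illustration, the q/dq form above computes the sum in fp32. A minimal sketch of that pattern, assuming the decomposed ops take (tensor, scale, zero_point, quant_min, quant_max, dtype) arguments and int8 ranges:

```
import torch

def qdq_add(x_i8, x_scale, x_zp, y_i8, y_scale, y_zp, out_scale, out_zp):
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor
    q = torch.ops.quantized_decomposed.quantize_per_tensor
    x_fp32 = dq(x_i8, x_scale, x_zp, -128, 127, torch.int8)
    y_fp32 = dq(y_i8, y_scale, y_zp, -128, 127, torch.int8)
    out_fp32 = torch.ops.aten.add.Tensor(x_fp32, y_fp32)  # addition happens in fp32
    return q(out_fp32, out_scale, out_zp, -128, 127, torch.int8)
```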

Then we'll rewrite the above into a representation that expresses the intent more precisely: here we actually want to do
int8 addition rather than simulate it with fp32 operations. The representation for quantized add is:

```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    # derived from: out = ((x_i8 - x_zp) * x_scale + (y_i8 - y_zp) * y_scale) / out_scale + out_zp
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```
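A quick numeric sanity check of the formula (not from the PR; assumes the usual affine convention x_fp32 = (x_i8 - zero_point) * scale, and ignores rounding/clamping to int8):

```
# placeholder quantization parameters, chosen to be exactly representable
x_i8, x_scale, x_zero_point = 20, 0.5, 4
y_i8, y_scale, y_zero_point = 10, 0.25, 2
out_scale, out_zero_point = 0.5, 3

# reference path: dequantize, add in fp32, requantize
x_fp32 = (x_i8 - x_zero_point) * x_scale               # 8.0
y_fp32 = (y_i8 - y_zero_point) * y_scale               # 2.0
ref = (x_fp32 + y_fp32) / out_scale + out_zero_point   # 23.0

# integer-domain path from the representation above
out = (x_scale / out_scale) * x_i8 + (y_scale / out_scale) * y_i8
out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
out += out_zero_point                                   # 23.0

assert abs(out - ref) < 1e-6
```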

Test Plan:
```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```

Reviewed By: kimishpatel

Differential Revision: D45628032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
2023-06-27 20:11:30 +00:00

94 lines
3.1 KiB
Python

from torch.fx import GraphModule
from ._pt2e.prepare import prepare
from ._pt2e._propagate_annotation import propagate_annotation
from ._pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from ._pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_decomposed_linear,
)
from ._pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization._pt2e.quantizer import Quantizer
from torch.ao.quantization.backend_config import BackendConfig
from typing import Any, Tuple


def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model


# TODO: update this to prepare_pt2e after we have a usable quantizer
# implemented
def prepare_pt2e_quantizer(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    return model


# TODO: update this to prepare_qat_pt2e
def prepare_qat_pt2e_quantizer(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    # Perform fusion after annotate to avoid quantizing ops in the new
    # subgraph that don't need to be quantized
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model


def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    return model
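For completeness, a hypothetical QAT-path usage of the functions above (exported_model, quantizer, and train are placeholders, not part of this file):

```
from torch.ao.quantization._quantize_pt2e import (
    prepare_qat_pt2e_quantizer,
    convert_pt2e,
)

# exported_model: GraphModule captured from the float model (placeholder)
# quantizer: a backend-provided Quantizer instance (placeholder)
prepared = prepare_qat_pt2e_quantizer(exported_model, quantizer)

train(prepared)  # placeholder QAT training loop, runs with fake quant inserted

# convert_pt2e folds the fused conv-bn QAT subgraphs; passing
# use_reference_representation=True would also rewrite q/dq patterns
quantized = convert_pt2e(prepared)
```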