diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 83071b195483..10a6521e17f3 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1686,6 +1686,59 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             M(), example_inputs, is_per_channel=True, verify_convert=True,
         )
 
+    def test_representation_add(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x + y
+
+        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+
+        quantizer = QNNPackQuantizer()
+        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(operator_config)
+        m_eager = M().eval()
+
+        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
+        # program capture
+        m = m_eager
+        # m_copy = copy.deepcopy(m)
+        m, guards = torchdynamo.export(
+            m,
+            *copy.deepcopy(example_inputs),
+            aten_graph=True,
+        )
+
+        m = prepare_pt2e_quantizer(m, quantizer)
+        # Calibrate
+        m(*example_inputs)
+        m = convert_pt2e(m, use_reference_representation=True)
+        # make sure it runs
+        pt2_quant_output = m(*example_inputs)
+
+        # TODO: torchdynamo times out when we do this; we can enable numerical checking
+        # after that is fixed
+        # m_copy = prepare_pt2e_quantizer(m_copy, quantizer)
+        # # Calibrate
+        # m_copy(*example_inputs)
+        # m_copy = convert_pt2e(m_copy, use_reference_representation=False)
+        # pt2_quant_output_copy = m_copy(*example_inputs)
+
+        # output_scale = None
+        # idx = 0
+        # for n in m_copy.graph.nodes:
+        #     if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
+        #         idx += 1
+        #         if idx == 3:
+        #             output_scale = n.args[1]
+        # assert output_scale is not None
+
+        # # make sure the result is off by one at most in the quantized integer representation
+        # self.assertTrue(
+        #     torch.max(torch.abs(pt2_quant_output_copy - pt2_quant_output)) <= (2 * output_scale + 1e-5)
+        # )
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EOps(QuantizationTestCase):
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index 357288981629..1fd3e94f6832 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -141,6 +141,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
 FILENAME_ALLOWLIST |= {
     _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
     _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
 }
 
 # TODO (zhxchen17) Make exportdb importable here.
diff --git a/torch/ao/quantization/_pt2e/qat_utils.py b/torch/ao/quantization/_pt2e/qat_utils.py
index 0edc91ec7789..25a42659a9e3 100644
--- a/torch/ao/quantization/_pt2e/qat_utils.py
+++ b/torch/ao/quantization/_pt2e/qat_utils.py
@@ -1,4 +1,3 @@
-import copy
 import dataclasses
 import itertools
 import operator
@@ -16,6 +15,7 @@ from .quantizer import (
     QuantizationSpecBase,
 )
 from .utils import _fold_bn_weights_into_conv_node
+from .utils import _get_aten_graph_module
 
 # Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
 _conv2d_bn_pattern_example_inputs = (
@@ -299,27 +299,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
         return x
     return _folded_quantized_qat_conv2d_bn_pattern
 
-def _get_aten_graph_module(
-    pattern: Callable,
-    example_inputs: Tuple[Any, ...],
-    **kwargs,
-) -> GraphModule:
-    """
-    Convert the pattern to an FX graph with decomposed aten ops.
-    """
-    # Avoid circular imports
-    import torch._dynamo
-    aten_pattern, _ = torch._dynamo.export(
-        pattern,
-        *copy.deepcopy(example_inputs),
-        aten_graph=True,
-        tracing_mode="real",
-        **kwargs,
-    )
-    aten_pattern.graph.eliminate_dead_code()
-    aten_pattern.recompile()
-    return aten_pattern
-
 def _has_conv_bias_filter(
     match: "InternalMatch",  # type: ignore[name-defined]
     original_graph: Graph,
diff --git a/torch/ao/quantization/_pt2e/representation/__init__.py b/torch/ao/quantization/_pt2e/representation/__init__.py
new file mode 100644
index 000000000000..9ddac64c04fa
--- /dev/null
+++ b/torch/ao/quantization/_pt2e/representation/__init__.py
@@ -0,0 +1,5 @@
+from .rewrite import reference_representation_rewrite
+
+__all__ = [
+    "reference_representation_rewrite",
+]
diff --git a/torch/ao/quantization/_pt2e/representation/rewrite.py b/torch/ao/quantization/_pt2e/representation/rewrite.py
new file mode 100644
index 000000000000..9561f13e0daf
--- /dev/null
+++ b/torch/ao/quantization/_pt2e/representation/rewrite.py
@@ -0,0 +1,74 @@
+import torch
+from torch.fx import GraphModule
+from ..utils import _get_aten_graph_module
+from ..utils import _remove_tensor_overload_for_qdq_ops
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.fx.subgraph_rewriter import replace_pattern
+
+__all__ = [
+    "reference_representation_rewrite",
+]
+
+_QUANTIZED_ADD_EXAMPLE_INPUTS = (
+    torch.randn(1, 3, 3, 3).to(torch.int8),
+    torch.randn(1).to(torch.float),
+    torch.zeros(1).to(torch.int),
+    torch.randn(1, 3, 3, 3).to(torch.int8),
+    torch.randn(1).to(torch.float),
+    torch.zeros(1).to(torch.int),
+    torch.randn(1).to(torch.float),
+    torch.zeros(1).to(torch.int),
+)
+
+def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
+    quant_min = -128
+    quant_max = 127
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
+    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
+    out_fp32 = x_fp32 + y_fp32
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+def _reference_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
+    """
+    # How to derive the formula for out_i8 based on x_i8 and y_i8
+    # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produces an out_i8)
+
+    # out_i8 is the quantized output, so we can write down the formula for it first:
+    out_i8 = out_fp32 / out_scale + out_zero_point    (1)
+
+    # then out_fp32 is computed from x_fp32 + y_fp32, where x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
+    out_fp32 = x_fp32 + y_fp32    (2)
+    x_fp32 = (x_i8 - x_zero_point) * x_scale    (3)
+    y_fp32 = (y_i8 - y_zero_point) * y_scale    (4)
+
+    # applying the above formulas to the out_i8 equation, we get:
+    out_i8 = out_fp32 / out_scale + out_zero_point    # (1)
+           = (x_fp32 + y_fp32) / out_scale + out_zero_point    # apply (2) to substitute out_fp32 with x_fp32 + y_fp32
+           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point    # apply (3) and (4)
+    """
+    x_i32 = x_i8.to(torch.int32)
+    y_i32 = y_i8.to(torch.int32)
+    # TODO: use out_dtype op
+    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
+    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
+    out_i32 = x_i32 + y_i32 + out_zero_point
+    quant_min = -128
+    quant_max = 127
+    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
+    return out_i8
+
+_EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
+    (_QUANTIZED_ADD_EXAMPLE_INPUTS, _qdq_quantized_add, _reference_quantized_add)
+]
+
+def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+    _remove_tensor_overload_for_qdq_ops(model)
+    for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
+        pattern_graph = _get_aten_graph_module(pattern, example_inputs)
+        _remove_tensor_overload_for_qdq_ops(pattern_graph)
+        replacement_graph = _get_aten_graph_module(replacement, example_inputs)
+        matches = replace_pattern(model, pattern_graph, replacement_graph)
+    return model
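Note: as a sanity check on the derivation in the `_reference_quantized_add` docstring above, the following standalone sketch (not part of the patch; the helpers `dq`, `q`, and `reference_add_int` are made-up names for illustration) compares the integer-only formula against the dequantize -> add -> quantize path for one arbitrary set of quantization parameters:

```python
import torch

def dq(x_i8, scale, zp):
    # dequantize: (q - zero_point) * scale
    return (x_i8.to(torch.float32) - zp) * scale

def q(x_fp32, scale, zp, qmin=-128, qmax=127):
    # quantize: round(x / scale) + zero_point, clamped to the int8 range
    return torch.clamp(torch.round(x_fp32 / scale) + zp, qmin, qmax).to(torch.int8)

def reference_add_int(x_i8, x_s, x_zp, y_i8, y_s, y_zp, out_s, out_zp):
    # integer-domain version mirroring _reference_quantized_add above
    x_i32 = torch.round((x_s / out_s) * (x_i8.to(torch.int32) - x_zp)).to(torch.int32)
    y_i32 = torch.round((y_s / out_s) * (y_i8.to(torch.int32) - y_zp)).to(torch.int32)
    return torch.clamp(x_i32 + y_i32 + out_zp, -128, 127).to(torch.int8)

x_i8 = torch.randint(-128, 127, (8,), dtype=torch.int8)
y_i8 = torch.randint(-128, 127, (8,), dtype=torch.int8)
x_s, x_zp, y_s, y_zp, out_s, out_zp = 0.02, 0, 0.03, 0, 0.05, 0

fp_path = q(dq(x_i8, x_s, x_zp) + dq(y_i8, y_s, y_zp), out_s, out_zp)
int_path = reference_add_int(x_i8, x_s, x_zp, y_i8, y_s, y_zp, out_s, out_zp)
print((fp_path.to(torch.int32) - int_path.to(torch.int32)).abs().max())
```

Because rounding happens in different places, the two paths can differ by at most one quantized step, which is consistent with the tolerance used in the commented-out numerical check in the test.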
diff --git a/torch/ao/quantization/_pt2e/utils.py b/torch/ao/quantization/_pt2e/utils.py
index 1d30dbb852ef..7103e2c7b9f0 100644
--- a/torch/ao/quantization/_pt2e/utils.py
+++ b/torch/ao/quantization/_pt2e/utils.py
@@ -9,7 +9,8 @@ from torch.ao.quantization.fx.prepare import (
     _is_activation_post_process_node,
 )
 import operator
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Callable, Any
+import copy
 
 
 def _get_tensor_constant_from_node(node, m):
@@ -172,3 +173,42 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
             current_scope = (bt[0].split(".")[-1], bt[1])
         node_name_to_scope[n.name] = current_scope
     return node_name_to_scope
+
+def _get_aten_graph_module(
+    pattern: Callable,
+    example_inputs: Tuple[Any, ...],
+    **kwargs,
+) -> GraphModule:
+    """
+    Convert the pattern to an FX graph with decomposed aten ops.
+    """
+    # Avoid circular imports
+    import torch._dynamo
+    aten_pattern, _ = torch._dynamo.export(
+        pattern,
+        *copy.deepcopy(example_inputs),
+        aten_graph=True,
+        tracing_mode="real",
+        **kwargs,
+    )
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+    return aten_pattern
+
+def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+    """ Remove .tensor overload for quantize/dequantize ops so that we can
+    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
+    """
+    _MAP = {
+        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel
+    }
+    for n in match_pattern.graph.nodes:
+        if n.op != "call_function":
+            continue
+        if n.target in _MAP:
+            n.target = _MAP[n.target]
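Note: to make the intent of `_remove_tensor_overload_for_qdq_ops` concrete, here is an illustrative sketch (not part of the patch; the pattern function and example inputs are invented) that exports a one-op dequantize pattern the same way `reference_representation_rewrite` does, then normalizes the recorded overload to the overload packet so that `replace_pattern` can match graphs regardless of whether they recorded `.default` or `.tensor`:

```python
import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization._pt2e.utils import (
    _get_aten_graph_module,
    _remove_tensor_overload_for_qdq_ops,
)

def dq_pattern(x_i8, scale, zero_point):
    # scale/zero_point are passed as tensors, so the export is expected to
    # record the .tensor overload of dequantize_per_tensor
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, scale, zero_point, -128, 127, torch.int8
    )

example_inputs = (
    torch.randint(-128, 127, (4,), dtype=torch.int8),
    torch.randn(1),
    torch.zeros(1, dtype=torch.int32),
)
gm = _get_aten_graph_module(dq_pattern, example_inputs)
print([n.target for n in gm.graph.nodes if n.op == "call_function"])
_remove_tensor_overload_for_qdq_ops(gm)
# after the rewrite, the call_function targets are the overload packets
# (e.g. quantized_decomposed.dequantize_per_tensor), not specific overloads
print([n.target for n in gm.graph.nodes if n.op == "call_function"])
```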
diff --git a/torch/ao/quantization/_quantize_pt2e.py b/torch/ao/quantization/_quantize_pt2e.py
index c6097f2df4fd..f865099cc2ff 100644
--- a/torch/ao/quantization/_quantize_pt2e.py
+++ b/torch/ao/quantization/_quantize_pt2e.py
@@ -11,6 +11,7 @@ from ._pt2e.utils import (
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_decomposed_linear,
 )
+from ._pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
@@ -82,8 +83,11 @@ def prepare_qat_pt2e_quantizer(
     return model
 
 def convert_pt2e(
-    model: GraphModule
+    model: GraphModule,
+    use_reference_representation: bool = False,
 ) -> GraphModule:
     model = _convert_to_reference_decomposed_fx(model)
     model = _fold_conv_bn_qat(model)
+    if use_reference_representation:
+        model = reference_representation_rewrite(model)
     return model
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index 80f83b74e70c..35e234b1512e 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -133,7 +133,7 @@ def dequantize_per_tensor(
     Returns:
        dequantized float32 Tensor
     """
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
     if dtype in [torch.uint8, torch.int8, torch.int32]:
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
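For quick reference, the end-to-end flow enabled by the new `use_reference_representation` flag mirrors the added test; the condensed sketch below is illustrative only, and the import locations are assumptions based on the files this patch touches:

```python
import copy
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    QNNPackQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e_quantizer,
)

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3))
# capture an aten-level graph, as in the test
m, _ = torchdynamo.export(
    Add().eval(), *copy.deepcopy(example_inputs), aten_graph=True
)

quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e_quantizer(m, quantizer)
m(*example_inputs)  # calibrate with representative inputs
# with the new flag, the dq -> float add -> q chain is rewritten into the
# integer-only reference pattern defined in rewrite.py
m = convert_pt2e(m, use_reference_representation=True)
out = m(*example_inputs)
```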