[quant][pt2e] Propagate get_attr meta through known ops only (#124415)

Summary: Avoid situation where the graph traversal finds a matmul node with a `get_attr` as its `args[0]`, and incorrectly propagate the `get_attr`'s meta to everything downstream.

Test Plan: CI

Differential Revision: D56219120

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124415
Approved by: https://github.com/jerryzh168
This commit is contained in:
Shen Xu
2024-04-24 20:55:56 +00:00
committed by PyTorch MergeBot
parent 355dc34f86
commit 8885638f95
2 changed files with 84 additions and 5 deletions

View File

@ -7,7 +7,7 @@ from typing import List
import torch
import torch._export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
@ -456,3 +456,67 @@ class TestMetaDataPorting(QuantizationTestCase):
self._test_quant_tag_preservation_through_decomp(
m, example_inputs, from_node_to_tags
)
def test_no_metadata_porting_through_unknown_ops(self):
"""
Model under test
matmul -> add -> relu
matmul has get_attr as first input, but the quantization_tag should not be
propagated to add even if it's part of a chain that ends at get_attr
"""
class MatmulWithConstInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))
def forward(self, x, y):
x = torch.matmul(self.w, x)
z = x + y
return torch.nn.functional.relu(z)
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
qconfig = get_symmetric_quantization_config()
for n in gm.graph.nodes:
if n.op != "call_function":
continue
n.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={n.args[0]: qconfig.input_activation},
output_qspec=qconfig.output_activation,
)
tag = str(n.target)
n.meta["quantization_tag"] = tag
for arg in n.args:
if arg.op == "get_attr":
arg.meta["quantization_tag"] = tag
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
get_attr_tags = {"aten.matmul.default"}
quantize_per_tensor_tensor_tags = {
"aten.matmul.default",
"aten.add.Tensor",
"aten.relu.default",
}
dequantize_per_tensor_tensor_tags = {
"aten.matmul.default",
"aten.add.Tensor",
"aten.relu.default",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tensor_tags,
}
m = self._test_metadata_porting(
MatmulWithConstInput(),
example_inputs,
BackendAQuantizer(),
node_tags,
)