mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][pt2e] Propagate get_attr meta through known ops only (#124415)
Summary: Avoid situation where the graph traversal finds a matmul node with a `get_attr` as its `args[0]`, and incorrectly propagate the `get_attr`'s meta to everything downstream. Test Plan: CI Differential Revision: D56219120 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124415 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
355dc34f86
commit
8885638f95
@ -7,7 +7,7 @@ from typing import List
|
||||
import torch
|
||||
import torch._export
|
||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||
from torch.ao.quantization.quantizer import Quantizer
|
||||
from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||
get_symmetric_quantization_config,
|
||||
)
|
||||
@ -456,3 +456,67 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||
self._test_quant_tag_preservation_through_decomp(
|
||||
m, example_inputs, from_node_to_tags
|
||||
)
|
||||
|
||||
def test_no_metadata_porting_through_unknown_ops(self):
|
||||
"""
|
||||
Model under test
|
||||
matmul -> add -> relu
|
||||
matmul has get_attr as first input, but the quantization_tag should not be
|
||||
propagated to add even if it's part of a chain that ends at get_attr
|
||||
"""
|
||||
|
||||
class MatmulWithConstInput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))
|
||||
|
||||
def forward(self, x, y):
|
||||
x = torch.matmul(self.w, x)
|
||||
z = x + y
|
||||
return torch.nn.functional.relu(z)
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
qconfig = get_symmetric_quantization_config()
|
||||
for n in gm.graph.nodes:
|
||||
if n.op != "call_function":
|
||||
continue
|
||||
|
||||
n.meta["quantization_annotation"] = QuantizationAnnotation(
|
||||
input_qspec_map={n.args[0]: qconfig.input_activation},
|
||||
output_qspec=qconfig.output_activation,
|
||||
)
|
||||
|
||||
tag = str(n.target)
|
||||
n.meta["quantization_tag"] = tag
|
||||
for arg in n.args:
|
||||
if arg.op == "get_attr":
|
||||
arg.meta["quantization_tag"] = tag
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
|
||||
get_attr_tags = {"aten.matmul.default"}
|
||||
quantize_per_tensor_tensor_tags = {
|
||||
"aten.matmul.default",
|
||||
"aten.add.Tensor",
|
||||
"aten.relu.default",
|
||||
}
|
||||
dequantize_per_tensor_tensor_tags = {
|
||||
"aten.matmul.default",
|
||||
"aten.add.Tensor",
|
||||
"aten.relu.default",
|
||||
}
|
||||
node_tags = {
|
||||
"get_attr": get_attr_tags,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tensor_tags,
|
||||
}
|
||||
m = self._test_metadata_porting(
|
||||
MatmulWithConstInput(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
|
Reference in New Issue
Block a user