# Owner(s): ["oncall: quantization"]
import copy
import unittest

import torch
import torch._export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    raise_on_run_directly,
    skipIfCrossRef,
)

class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

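# For orientation: with the (1, 3, 5, 5) example inputs used throughout this
# file, Conv2dWithObsSharingOps produces conv (3x3 kernel) -> (1, 3, 3, 3),
# adaptive_avg_pool2d -> (1, 3, 1, 1), hardtanh (shape-preserving),
# view(-1, 3) -> (1, 3), linear -> (1, 3).
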
def _tag_partitions(
    backend_name: str, op_name: str, annotated_partitions: list[list[Node]]
):
    for index, partition_nodes in enumerate(annotated_partitions):
        tag_name = backend_name + "_" + op_name + "_" + str(index)
        for node in partition_nodes:
            assert "quantization_tag" not in node.meta, f"{node} is already tagged"
            node.meta["quantization_tag"] = tag_name

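# For example, _tag_partitions("BackendA", "conv2d", annotated_partitions)
# stamps every node in annotated_partitions[0] with "BackendA_conv2d_0",
# every node in annotated_partitions[1] with "BackendA_conv2d_1", and so on.
# The expected tag sets in the tests below are built from exactly these strings.
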
_QUANT_OPS = {
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.choose_qparams.tensor,
}

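# The tests below hand `_test_metadata_porting` a `node_tags` mapping whose keys
# are either the literal string "get_attr" or one of the ops in _QUANT_OPS, and
# whose values are the sets of tags expected across nodes of that kind. An
# illustrative (hypothetical) entry:
#
#     node_tags = {
#         "get_attr": {"BackendA_conv2d_0"},
#         torch.ops.quantized_decomposed.quantize_per_tensor.default: {
#             "BackendA_conv2d_0"
#         },
#     }
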
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        ep = torch.export.export(model, example_inputs, strict=True)
        found_tags = True
        not_found_nodes = ""
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                for meta in from_node_meta:
                    node_target = meta.target
                    if node_target == str(from_node):
                        node_tag = n.meta.get("quantization_tag", None)
                        if node_tag is None or tag != node_tag:
                            not_found_nodes += str(n.target) + ", "
                            found_tags = False
                            break
                if not found_tags:
                    break
        self.assertTrue(
            found_tags,
            f"Decomposition did not preserve quantization tag for {not_found_nodes}",
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = torch.export.export_for_training(m, example_inputs, strict=True).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        m(*example_inputs)
        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            if key not in recorded_node_tags:
                recorded_node_tags[key] = set()

            if (
                n.op == "call_function"
                and n.meta["quantization_tag"] in recorded_node_tags[key]
            ):
                raise ValueError(
                    f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
                    "is associated with another node of the same type"
                )
            recorded_node_tags[key].add(n.meta["quantization_tag"])

        self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
        for k, v in recorded_node_tags.items():
            self.assertEqual(v, node_tags[k])
        return m

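    # Each test below drives _test_metadata_porting with a different quantizer.
    # A failure means a q/dq node either lost its "quantization_tag" or picked
    # up a tag already recorded for another node of the same op type.
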
    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph's stored lineno.
    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check that quantization tags on conv2d, avgpool and linear are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize avgpool.
        Check that quantization tags on conv2d and linear are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize all except linear statically.
        Quantize linear with dynamic quantization.
        Check that quantization tags on conv2d, avgpool and linear are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization.
        Check that quantization tags on conv2d, avgpool and linear are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize anything except linear.
        Quantize linear with dynamic quantization.
        Check that quantization tags on conv2d, avgpool and linear are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_no_metadata_porting_through_unknown_ops(self):
        """
        Model under test
        matmul -> add -> relu
        matmul has get_attr as its first input, but the quantization_tag should
        not be propagated to add, even though add is part of a chain that ends
        at get_attr.
        """

        class MatmulWithConstInput(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))

            def forward(self, x, y):
                x = torch.matmul(self.w, x)
                z = x + y
                return torch.nn.functional.relu(z)

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                qconfig = get_symmetric_quantization_config()
                for n in gm.graph.nodes:
                    if n.op != "call_function":
                        continue

                    n.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={n.args[0]: qconfig.input_activation},
                        output_qspec=qconfig.output_activation,
                    )

                    tag = str(n.target)
                    n.meta["quantization_tag"] = tag
                    for arg in n.args:
                        if arg.op == "get_attr":
                            arg.meta["quantization_tag"] = tag

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
        get_attr_tags = {"aten.matmul.default"}
        quantize_per_tensor_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        dequantize_per_tensor_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            MatmulWithConstInput(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

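# A minimal end-to-end sketch of the flow exercised above, kept for reference
# only: it is not collected by the test runner (no `test_` prefix) and asserts
# nothing. The backend name "SketchBackend" is illustrative, not used by any
# test in this file.
def _sketch_metadata_porting_flow():
    class _SketchQuantizer(Quantizer):
        def annotate(self, gm: torch.fx.GraphModule) -> None:
            partitions = OP_TO_ANNOTATOR["linear"](
                gm, get_symmetric_quantization_config(is_per_channel=True)
            )
            _tag_partitions("SketchBackend", "linear", partitions)

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass

    example_inputs = (torch.randn(1, 3, 5, 5),)
    m = torch.export.export_for_training(
        TestHelperModules.Conv2dWithObsSharingOps().eval(), example_inputs, strict=True
    ).module()
    m = prepare_pt2e(m, _SketchQuantizer())
    m(*example_inputs)  # calibrate observers
    m = convert_pt2e(m)
    # convert_pt2e ports "quantization_tag" onto the q/dq nodes it inserts,
    # so tags such as "SketchBackend_linear_0" now appear in the graph.
    return {
        n.meta["quantization_tag"]
        for n in m.graph.nodes
        if "quantization_tag" in n.meta
    }
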
if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")