Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[quant][pt2e] Enable constant folding for quantize ops (#109343)
Summary: This PR added constant folding for quantize ops so that instead of storing fp32 weight in the quantized model, we'll get int8/int16 etc. weight.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_fold_quantize
also will verify in executorch later

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343
Approved by: https://github.com/kimishpatel, https://github.com/jgong5
committed by PyTorch MergeBot
parent 6138750ab1
commit 1b51d29b66
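For context, a minimal usage sketch of the new `fold_quantize` flag, following the flow exercised by the updated tests. The toy module `M`, the XNNPACK quantizer config, and the exact import paths are illustrative and may vary across PyTorch versions:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 2),)
m = M().eval()

# program capture
m = capture_pre_autograd_graph(m, example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate

# With fold_quantize=True the get_attr(fp32 weight) -> quantize_per_* chain is
# constant folded, so the exported model stores the already-quantized weight and
# only the dequantize op remains in the graph.
m = convert_pt2e(m, fold_quantize=True)
```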
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest

import torch
import torch._export as export
@@ -9,7 +10,6 @@ from torch.ao.quantization.observer import (
MinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
@@ -26,6 +26,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
)

from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
@@ -81,6 +82,14 @@ class TestHelperModules:
return w, add_output, extra_output


_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestDuplicateDQPass(QuantizationTestCase):
def _test_duplicate_dq(
self,
@@ -106,19 +115,9 @@ class TestDuplicateDQPass(QuantizationTestCase):
for n in m.graph.nodes:
annotation = n.meta.get("quantization_annotation", None)
if annotation is not None:
input_qspec_map = annotation.input_qspec_map
for input_node, qspec in input_qspec_map.items():
if (
qspec is not None
and hasattr(qspec, "dtype")
and qspec.dtype != torch.float
):
q_node, dq_node = _find_q_dq_node_for_user(input_node, n)
if dq_node is None:
raise ValueError(
f"No dq node found for {n}, even though {n} annotated for quantization."
)
self.assertEqual(len(dq_node.users.keys()), 1)
for arg in n.args:
if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
self.assertEqual(len(arg.users.keys()), 1)

def test_no_need_for_duplicate_dq(self):
"""
@@ -16,6 +16,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTA
from torch.fx import Node

from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
@@ -57,6 +58,8 @@ _QUANT_OPS = {
}


# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
def _test_metadata_porting(
self,
@@ -77,23 +80,33 @@ class TestMetaDataPorting(QuantizationTestCase):
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
recorded_node_tags = {}
for n in m.graph.nodes:
if "quantization_tag" not in n.meta:
continue
if n.op == "call_function" and n.target in _QUANT_OPS:
key = n.target
elif n.op == "get_attr":
key = "get_attr"
else:
continue

if key not in recorded_node_tags:
recorded_node_tags[key] = set()

if (
n.op == "call_function"
and n.target in _QUANT_OPS
and "quantization_tag" in n.meta
and n.meta["quantization_tag"] in recorded_node_tags[key]
):
if n.target not in recorded_node_tags:
recorded_node_tags[n.target] = set()
if n.meta["quantization_tag"] in recorded_node_tags[n.target]:
raise ValueError(
f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type"
)
recorded_node_tags[n.target].add(n.meta["quantization_tag"])
raise ValueError(
f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
"is associated with another node of the same type"
)
recorded_node_tags[key].add(n.meta["quantization_tag"])

self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
for k, v in recorded_node_tags.items():
self.assertEqual(v, node_tags[k])
@@ -130,6 +143,10 @@ class TestMetaDataPorting(QuantizationTestCase):
pass

example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_0",
"BackendA_linear_0",
}
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
@@ -142,6 +159,7 @@ class TestMetaDataPorting(QuantizationTestCase):
}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@@ -180,10 +198,12 @@ class TestMetaDataPorting(QuantizationTestCase):
pass

example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@@ -236,6 +256,8 @@ class TestMetaDataPorting(QuantizationTestCase):
pass

example_inputs = (torch.randn(1, 3, 5, 5),)
# TODO: add get_attr_tags when the test is re-enabled
get_attr_tags = {}
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
@@ -252,6 +274,7 @@ class TestMetaDataPorting(QuantizationTestCase):
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
@@ -295,6 +318,10 @@ class TestMetaDataPorting(QuantizationTestCase):
pass

example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
choose_qparams_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
@@ -312,6 +339,7 @@ class TestMetaDataPorting(QuantizationTestCase):
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@@ -349,11 +377,13 @@ class TestMetaDataPorting(QuantizationTestCase):
pass

example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_linear_dynamic_0"}
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@@ -91,6 +91,31 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
}

def _quantize(self, m, quantizer, example_inputs):
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)
return m

def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
return self._quantize(m, quantizer, example_inputs)

def _test_quantizer(
self,
@@ -118,7 +143,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
node_occurrence = {
@@ -209,7 +234,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# two for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
}
node_list = [
@@ -277,18 +302,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = torch.nn.Conv2d(2, 2, 1)
x = torch.rand(1, 2, 14, 14)
example_inputs = (x,)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
# Ensure the conv has no observer inserted at output
node_occurrence = {
# two for input of conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
}
node_list = [
@@ -369,19 +387,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvMaxPool2d()
x = torch.rand(1, 2, 14, 14)
example_inputs = (x,)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
node_occurrence = {
# two for input of conv
# one for input of maxpool
# one for output of maxpool
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4,
}
node_list = [
@@ -469,19 +480,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
node_occurrence = {
# input, weight, bias, output for the conv
# note: quantize op for weight and bias are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 4,
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 4,
@@ -575,15 +580,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)
m = self._quantize(m, BackendAQuantizer(), example_inputs)

node_occurrence = {
# input, output for the conv
@@ -594,9 +591,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
# weight and bias for conv
# note: quantize op for weight and bias are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 2,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
@@ -653,14 +651,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
fixed_scale = 1.0 / 256.0
fixed_zero_point = 0
for n in m.graph.nodes:
@@ -805,13 +796,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
assert conv_output_obs[0] == conv_output_obs[1]

m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)

node_occurrence = {
# two for input of the first conv, one for output for the first conv
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 7,
): 5,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 7,
@@ -873,8 +864,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

quantizer = Int16ActQuantizer()
node_occurrence = {
# two for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
# one for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
}
node_list = [
@@ -892,6 +883,127 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_list,
)

def test_fold_quantize(self):
"""Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)
"""
m = self._get_pt2e_quantized_linear()
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_fold_quantize_per_channel(self):
"""Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)
"""
m = self._get_pt2e_quantized_linear(is_per_channel=True)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_dont_fold_other_constant(self):
"""Make sure the constant propagation does not apply to things unrelated to
quantization
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2))

def forward(self, x):
t = self.dont_fold_me.t()
return self.linear(x) + t

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
# only quantize linear, so add is not quantized and the constant Tensor
# should not be folded
quantizer.set_module_type(torch.nn.Linear, operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = self._quantize(m, quantizer, example_inputs)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
# transpose op not folded
ns.call_function(torch.ops.aten.t.default): 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_fold_all_ops_before_quantize(self):
"""Test folding all ops that's before quantized operator:
Before:
get_attr(weight) -> transpose -> quantize -> dequantize
After:
get_attr(folded_weight) -> dequantize
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(2, 2)

def forward(self, x):
t = self.weight.t()
return torch.nn.functional.linear(x, t)

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = self._quantize(m, quantizer, example_inputs)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_constant_prop_preserve_metadata(self):
"""Test to make sure the get_attr node for const propagated weight Tensor gets the correct
metadata (from original get_attr node from weight)
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = capture_pre_autograd_graph(
m,
example_inputs,
)
weight_meta = None
for n in m.graph.nodes:
if n.op == "get_attr" and list(n.users)[0].target == torch.ops.aten.linear.default:
weight_meta = n.meta
break
assert weight_meta is not None, "Expect to find metadata for weight node"

m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)

for n in m.graph.nodes:
if n.op == "get_attr" and "frozen_param" in n.target:
self.assertIn("stack_trace", n.meta)
for key in n.meta:
self.assertEqual(n.meta[key], weight_meta[key])

def test_add_and_inplace_add(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
@@ -922,27 +1034,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def test_save_load(self):
"""Test save/load a quantized model
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
m = self._get_pt2e_quantized_linear()
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
ref_res = m(*example_inputs)

with TemporaryFileName() as fname:
@@ -991,7 +1084,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -1021,7 +1115,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@@ -1045,7 +1140,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@@ -1072,7 +1168,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@@ -1099,7 +1196,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@@ -1125,7 +1223,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -1252,7 +1351,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
self.assertEqual(
id(m.activation_post_process_3), id(m.activation_post_process_4)
)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
ns.call_function(
@@ -1261,9 +1360,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 5,
# note: quantize op for weights are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 2,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
@@ -1282,7 +1382,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@@ -1324,7 +1425,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@@ -1372,9 +1474,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@@ -1473,11 +1577,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

quantizer = EmbeddingQuantizer()
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.embedding.default,
]
@@ -1577,7 +1681,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
self._test_quantizer(
@@ -1644,7 +1749,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m.train()

# After convert: still not OK
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
@@ -1705,7 +1810,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
model_graph = convert_pt2e(model_graph)
model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))


@@ -1767,7 +1872,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
model_graph = convert_pt2e(model_graph)
model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))


@@ -1798,7 +1903,7 @@ class TestQuantizePT2EModels(PT2EQuantizationTestCase):
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)

after_quant_result = m(*example_inputs)
@@ -29,7 +29,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
ref_node_occurrence: Dict[ns, int],
non_ref_node_occurrence: Dict[ns, int],
fixed_output_tol: float = None,
output_scale_idx: int = 3,
output_scale_idx: int = 2,
) -> torch.nn.Module:
# resetting dynamo cache
torch._dynamo.reset()
@@ -42,7 +42,9 @@ class TestPT2ERepresentation(QuantizationTestCase):
model = prepare_pt2e(model, quantizer)
# Calibrate
model(*example_inputs)
model = convert_pt2e(model, use_reference_representation=True)
model = convert_pt2e(
model, use_reference_representation=True, fold_quantize=True
)
self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
# make sure it runs
pt2e_quant_output = model(*example_inputs)
@@ -52,7 +54,9 @@ class TestPT2ERepresentation(QuantizationTestCase):
model_copy = prepare_pt2e(model_copy, quantizer)
# Calibrate
model_copy(*example_inputs)
model_copy = convert_pt2e(model_copy, use_reference_representation=False)
model_copy = convert_pt2e(
model_copy, use_reference_representation=False, fold_quantize=True
)
self.checkGraphModuleNodes(
model_copy, expected_node_occurrence=non_ref_node_occurrence
)
@@ -253,9 +257,10 @@ class TestPT2ERepresentation(QuantizationTestCase):
): 0,
}
non_ref_node_occurrence = {
# quantize_per_channel is folded
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 1,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 1,
@@ -253,7 +253,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
convert_model = copy.deepcopy(m)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
@@ -284,7 +284,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -321,7 +322,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -361,19 +363,21 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# one for input and weight of the conv
# one for input and weight of another conv
# one for input of the conv
# one for input of another conv
# one for output for the add
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
@@ -410,24 +414,26 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
).eval()
if conv2d_type != Conv2DType.both:
node_occurrence = {
# one for input and weight of the conv
# one for input for conv
# one for output for the relu
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# one for input and weight of the conv
# one for input and weight of another conv
# one for input of the conv
# one for input of another conv
# one for output for the relu
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
@@ -458,7 +464,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
@@ -498,7 +505,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, two for input/output for the maxpool2d
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -558,7 +566,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
node_list = [
@@ -632,7 +641,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -688,7 +698,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -739,7 +750,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -799,7 +811,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight, one for output
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -841,7 +854,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@@ -83,6 +83,8 @@ except ImportError as e:
try:
# To be moved to compiler side later
from quantization.pt2e.test_graph_utils import TestGraphUtils # noqa: F401
from quantization.pt2e.test_duplicate_dq import TestDuplicateDQPass # noqa: F401
from quantization.pt2e.test_metadata_porting import TestMetaDataPorting # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels # noqa: F401
@@ -1,5 +1,5 @@
import collections
from typing import Any, Dict
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree
@@ -51,11 +51,14 @@ class ConstantFolder(torch.fx.Interpreter):
self.user_to_last_uses = self.node_to_last_non_output_use()

def is_impure(self, node: torch.fx.node.Node):
if node.target == torch.ops.quantized_decomposed.dequantize_per_channel.default:
# For the pattern fp32_weight -> quantized_decomposed.quantize_per_channel.default
# -> quantized_decomposed.dequantize_per_channel.default
# We only folding fp32_weight -> quantized_decomposed.quantize_per_channel.default into
# int8_weight and leave quantized_decomposed.dequantize_per_channel.default in graph to be fused
if node.target in [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For the pattern fp32_weight -> q -> dq
# We only folding fp32_weight -> q
# int8_weight and leave dq in graph to be fused
return True
return False

@@ -163,11 +166,13 @@ class ConstantFolder(torch.fx.Interpreter):


@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm):
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
cf = ConstantFolder(gm, skip_constructors=True)
cf.run()

for node, constant in cf.node_replacements.items():
if constraint_fn is not None and not constraint_fn(node):
continue
replace_node_with_constant(gm, node, constant)

erased_params = []
@@ -98,6 +98,19 @@ def _port_metadata_for_input_quant_nodes(
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
if q_node is None or dq_node is None:
return
# add metadata for all the node between q_node and get_attr node
# if the q_node can be traced back to get_attr node
q_to_get_attr_nodes = [q_node]
q_node_input = q_node.args[0]
while isinstance(q_node_input, torch.fx.Node) and q_node_input.op not in [
"placeholder",
"get_attr",
]:
q_to_get_attr_nodes.append(q_node_input)
q_node_input = q_node_input.args[0]
if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
for n in q_to_get_attr_nodes:
_add_metadata(n, q_node_input)
_add_metadata(dq_node, node)
@@ -56,7 +56,7 @@ def _find_q_dq_node_for_user(
produer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
"""
Find d, dq pair corresponding to [producer ... -> q -> dq -> user]
Find q, dq pair corresponding to [producer -> q -> dq -> user]
Utils works by finding dq arg of user and ensuring it is connected to
producer
"""
@@ -1,4 +1,6 @@
import torch
from torch.fx import GraphModule
from torch.fx import Node

from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
@@ -24,6 +26,7 @@ from torch.ao.quantization.quantizer import ( # noqa: F401
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch._inductor.constant_folding import constant_fold

__all__ = [
"prepare_pt2e",
@@ -66,10 +69,40 @@ def prepare_qat_pt2e(
model = _disallow_eval_train(model)
return model

_QUANT_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
def _quant_node_constraint(n: Node) -> bool:
"""If there is any pure ops between get_attr and quantize op they will be const propagated
e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
(Note: dequantize op is not going to be constant propagated)

This filter is added because we don't want to constant fold the things that are not
related to quantization
"""
return n.op == "call_function" and n.target in _QUANT_OPS

def convert_pt2e(
model: GraphModule,
use_reference_representation: bool = False,
fold_quantize: bool = False,
) -> GraphModule:
"""Convert a calibrated/trained model to a quantized model

Args:
model: calibrated/trained model
use_reference_representation: boolean flag to indicate whether to produce referece representation or not
fold_quantize: boolean flag to indicate whether fold the quantize op or not

Note: please set `fold_quantize` to True whenever you can, we'll deprecate this flag and
make True the default option in the future, to make sure the change doesn't break BC for you, it's
better to set the flag to True now.

Returns:
quantized model, either in q/dq representation or reference representation
"""
original_graph_meta = model.meta
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
@@ -78,6 +111,10 @@ def convert_pt2e(

pm = PassManager([PortNodeMetaForQDQ()])
model = pm(model).graph_module

if fold_quantize:
constant_fold(model, _quant_node_constraint)

if use_reference_representation:
model = reference_representation_rewrite(model)
@@ -201,7 +201,8 @@ def _get_module_name_filter(module_name: str):
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
# get_attr nodes doesn't have nn_module_stack?
nn_module_stack = n.meta.get("nn_module_stack", {})
names = [
n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
]
@@ -228,7 +229,7 @@ def _get_module_type_filter(tp: Callable):
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
nn_module_stack = n.meta.get("nn_module_stack", {})
types = [t for _, t in nn_module_stack.values()]
return tp in types