[quant][pt2e] Enable constant folding for quantize ops (#109343)

Summary:
This PR adds constant folding for quantize ops, so that instead of storing fp32 weights in the
quantized model we get int8/int16 etc. weights.
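
A minimal usage sketch (hypothetical linear module; the import paths and quantizer setup mirror the tests changed below) showing where the folding kicks in:

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    m = torch.nn.Linear(2, 2).eval()
    example_inputs = (torch.randn(2, 2),)
    m = capture_pre_autograd_graph(m, example_inputs)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibrate

    # with fold_quantize=True the quantize op on the weight is constant folded,
    # so the graph stores an int8 weight get_attr followed only by a dequantize op
    m = convert_pt2e(m, fold_quantize=True)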

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_fold_quantize

Will also verify in ExecuTorch later.

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343
Approved by: https://github.com/kimishpatel, https://github.com/jgong5
Jerry Zhang
2023-09-26 19:09:30 -07:00
committed by PyTorch MergeBot
parent 6138750ab1
commit 1b51d29b66
11 changed files with 359 additions and 148 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest
import torch
import torch._export as export
@ -9,7 +10,6 @@ from torch.ao.quantization.observer import (
MinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
@ -26,6 +26,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
)
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
class TestHelperModules:
@ -81,6 +82,14 @@ class TestHelperModules:
return w, add_output, extra_output
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestDuplicateDQPass(QuantizationTestCase):
def _test_duplicate_dq(
self,
@ -106,19 +115,9 @@ class TestDuplicateDQPass(QuantizationTestCase):
for n in m.graph.nodes:
annotation = n.meta.get("quantization_annotation", None)
if annotation is not None:
input_qspec_map = annotation.input_qspec_map
for input_node, qspec in input_qspec_map.items():
if (
qspec is not None
and hasattr(qspec, "dtype")
and qspec.dtype != torch.float
):
q_node, dq_node = _find_q_dq_node_for_user(input_node, n)
if dq_node is None:
raise ValueError(
f"No dq node found for {n}, even though {n} annotated for quantization."
)
self.assertEqual(len(dq_node.users.keys()), 1)
for arg in n.args:
if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
self.assertEqual(len(arg.users.keys()), 1)
def test_no_need_for_duplicate_dq(self):
"""

View File

@ -16,6 +16,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTA
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
class TestHelperModules:
@ -57,6 +58,8 @@ _QUANT_OPS = {
}
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
def _test_metadata_porting(
self,
@ -77,23 +80,33 @@ class TestMetaDataPorting(QuantizationTestCase):
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
pt2_quant_output = m(*example_inputs)
recorded_node_tags = {}
for n in m.graph.nodes:
if "quantization_tag" not in n.meta:
continue
if n.op == "call_function" and n.target in _QUANT_OPS:
key = n.target
elif n.op == "get_attr":
key = "get_attr"
else:
continue
if key not in recorded_node_tags:
recorded_node_tags[key] = set()
if (
n.op == "call_function"
and n.target in _QUANT_OPS
and "quantization_tag" in n.meta
and n.meta["quantization_tag"] in recorded_node_tags[key]
):
raise ValueError(
f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
"is associated with another node of the same type"
)
recorded_node_tags[key].add(n.meta["quantization_tag"])
self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
for k, v in recorded_node_tags.items():
self.assertEqual(v, node_tags[k])
@ -130,6 +143,10 @@ class TestMetaDataPorting(QuantizationTestCase):
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_0",
"BackendA_linear_0",
}
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
@ -142,6 +159,7 @@ class TestMetaDataPorting(QuantizationTestCase):
}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@ -180,10 +198,12 @@ class TestMetaDataPorting(QuantizationTestCase):
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@ -236,6 +256,8 @@ class TestMetaDataPorting(QuantizationTestCase):
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
# TODO: add get_attr_tags when the test is re-enabled
get_attr_tags = {}
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
@ -252,6 +274,7 @@ class TestMetaDataPorting(QuantizationTestCase):
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
@ -295,6 +318,10 @@ class TestMetaDataPorting(QuantizationTestCase):
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
choose_qparams_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
@ -312,6 +339,7 @@ class TestMetaDataPorting(QuantizationTestCase):
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
@ -349,11 +377,13 @@ class TestMetaDataPorting(QuantizationTestCase):
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_linear_dynamic_0"}
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,

View File

@ -91,6 +91,31 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
}
def _quantize(self, m, quantizer, example_inputs):
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)
return m
def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
return self._quantize(m, quantizer, example_inputs)
def _test_quantizer(
self,
@ -118,7 +143,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
@ -209,7 +234,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# two for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
}
node_list = [
@ -277,18 +302,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = torch.nn.Conv2d(2, 2, 1)
x = torch.rand(1, 2, 14, 14)
example_inputs = (x,)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
# Ensure the conv has no observer inserted at output
node_occurrence = {
# two for input of conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
}
node_list = [
@ -369,19 +387,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvMaxPool2d()
x = torch.rand(1, 2, 14, 14)
example_inputs = (x,)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
node_occurrence = {
# two for input of conv
# one for input of maxpool
# one for output of maxpool
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4,
}
node_list = [
@ -469,19 +480,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
node_occurrence = {
# input, weight, bias, output for the conv
# note: quantize op for weight and bias are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 4,
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 4,
@ -575,15 +580,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
node_occurrence = {
# input, output for the conv
@ -594,9 +591,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
# weight and bias for conv
# note: quantize op for weight and bias are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 2,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
@ -653,14 +651,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)
m = convert_pt2e(m)
m = self._quantize(m, BackendAQuantizer(), example_inputs)
fixed_scale = 1.0 / 256.0
fixed_zero_point = 0
for n in m.graph.nodes:
@ -805,13 +796,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
assert conv_output_obs[0] == conv_output_obs[1]
m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
node_occurrence = {
# two for input of the first conv, one for output for the first conv
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 7,
): 5,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 7,
@ -873,8 +864,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = Int16ActQuantizer()
node_occurrence = {
# two for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
# one for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
}
node_list = [
@ -892,6 +883,127 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_list,
)
def test_fold_quantize(self):
"""Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)
"""
m = self._get_pt2e_quantized_linear()
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
def test_fold_quantize_per_channel(self):
"""Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)
"""
m = self._get_pt2e_quantized_linear(is_per_channel=True)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
def test_dont_fold_other_constant(self):
"""Make sure the constant propagation does not apply to things unrelated to
quantization
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2))
def forward(self, x):
t = self.dont_fold_me.t()
return self.linear(x) + t
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
# only quantize linear, so add is not quantized and the constant Tensor
# should not be folded
quantizer.set_module_type(torch.nn.Linear, operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = self._quantize(m, quantizer, example_inputs)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
# transpose op not folded
ns.call_function(torch.ops.aten.t.default): 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
def test_fold_all_ops_before_quantize(self):
"""Test folding all ops that's before quantized operator:
Before:
get_attr(weight) -> transpose -> quantize -> dequantize
After:
get_attr(folded_weight) -> dequantize
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(2, 2)
def forward(self, x):
t = self.weight.t()
return torch.nn.functional.linear(x, t)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = self._quantize(m, quantizer, example_inputs)
node_occurrence = {
# quantize op for weight node is folded
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
def test_constant_prop_preserve_metadata(self):
"""Test to make sure the get_attr node for const propagated weight Tensor gets the correct
metadata (from original get_attr node from weight)
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = capture_pre_autograd_graph(
m,
example_inputs,
)
weight_meta = None
for n in m.graph.nodes:
if n.op == "get_attr" and list(n.users)[0].target == torch.ops.aten.linear.default:
weight_meta = n.meta
break
assert weight_meta is not None, "Expect to find metadata for weight node"
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)
for n in m.graph.nodes:
if n.op == "get_attr" and "frozen_param" in n.target:
self.assertIn("stack_trace", n.meta)
for key in n.meta:
self.assertEqual(n.meta[key], weight_meta[key])
def test_add_and_inplace_add(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
@ -922,27 +1034,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def test_save_load(self):
"""Test save/load a quantized model
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
m = self._get_pt2e_quantized_linear()
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
ref_res = m(*example_inputs)
with TemporaryFileName() as fname:
@ -991,7 +1084,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -1021,7 +1115,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@ -1045,7 +1140,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@ -1072,7 +1168,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@ -1099,7 +1196,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
@ -1125,7 +1223,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -1252,7 +1351,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
self.assertEqual(
id(m.activation_post_process_3), id(m.activation_post_process_4)
)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
ns.call_function(
@ -1261,9 +1360,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 5,
# note: quantize op for weights are const propagated
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 2,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
@ -1282,7 +1382,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@ -1324,7 +1425,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@ -1372,9 +1474,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
@ -1473,11 +1577,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = EmbeddingQuantizer()
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.embedding.default,
]
@ -1577,7 +1681,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
self._test_quantizer(
@ -1644,7 +1749,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m.train()
# After convert: still not OK
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
@ -1705,7 +1810,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
model_graph = convert_pt2e(model_graph)
model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))
@ -1767,7 +1872,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
model_graph = convert_pt2e(model_graph)
model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))
@ -1798,7 +1903,7 @@ class TestQuantizePT2EModels(PT2EQuantizationTestCase):
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
after_quant_result = m(*example_inputs)

View File

@ -29,7 +29,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
ref_node_occurrence: Dict[ns, int],
non_ref_node_occurrence: Dict[ns, int],
fixed_output_tol: float = None,
output_scale_idx: int = 3,
output_scale_idx: int = 2,
) -> torch.nn.Module:
# resetting dynamo cache
torch._dynamo.reset()
@ -42,7 +42,9 @@ class TestPT2ERepresentation(QuantizationTestCase):
model = prepare_pt2e(model, quantizer)
# Calibrate
model(*example_inputs)
model = convert_pt2e(model, use_reference_representation=True)
model = convert_pt2e(
model, use_reference_representation=True, fold_quantize=True
)
self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
# make sure it runs
pt2e_quant_output = model(*example_inputs)
@ -52,7 +54,9 @@ class TestPT2ERepresentation(QuantizationTestCase):
model_copy = prepare_pt2e(model_copy, quantizer)
# Calibrate
model_copy(*example_inputs)
model_copy = convert_pt2e(model_copy, use_reference_representation=False)
model_copy = convert_pt2e(
model_copy, use_reference_representation=False, fold_quantize=True
)
self.checkGraphModuleNodes(
model_copy, expected_node_occurrence=non_ref_node_occurrence
)
@ -253,9 +257,10 @@ class TestPT2ERepresentation(QuantizationTestCase):
): 0,
}
non_ref_node_occurrence = {
# quantize_per_channel is folded
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 1,
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 1,

View File

@ -253,7 +253,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
m = convert_pt2e(m)
m = convert_pt2e(m, fold_quantize=True)
convert_model = copy.deepcopy(m)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
@ -284,7 +284,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -321,7 +322,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -361,19 +363,21 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# one for input and weight of the conv
# one for input and weight of another conv
# one for input of the conv
# one for input of another conv
# one for output for the add
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
@ -410,24 +414,26 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
).eval()
if conv2d_type != Conv2DType.both:
node_occurrence = {
# one for input and weight of the conv
# one for input for conv
# one for output for the relu
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# one for input and weight of the conv
# one for input and weight of another conv
# one for input of the conv
# one for input of another conv
# one for output for the relu
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
@ -458,7 +464,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
@ -498,7 +505,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, two for input/output for the maxpool2d
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -558,7 +566,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
node_list = [
@ -632,7 +641,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -688,7 +698,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -739,7 +750,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -799,7 +811,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight, one for output
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
@ -841,7 +854,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [

View File

@ -83,6 +83,8 @@ except ImportError as e:
try:
# To be moved to compiler side later
from quantization.pt2e.test_graph_utils import TestGraphUtils # noqa: F401
from quantization.pt2e.test_duplicate_dq import TestDuplicateDQPass # noqa: F401
from quantization.pt2e.test_metadata_porting import TestMetaDataPorting # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps # noqa: F401
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels # noqa: F401

View File

@ -1,5 +1,5 @@
import collections
from typing import Any, Dict
from typing import Any, Callable, Dict, Optional
import torch
import torch.utils._pytree as pytree
@ -51,11 +51,14 @@ class ConstantFolder(torch.fx.Interpreter):
self.user_to_last_uses = self.node_to_last_non_output_use()
def is_impure(self, node: torch.fx.node.Node):
if node.target in [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For the pattern fp32_weight -> q -> dq we only fold fp32_weight -> q into an
# int8 weight and leave dq in the graph to be fused
return True
return False
@ -163,11 +166,13 @@ class ConstantFolder(torch.fx.Interpreter):
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm):
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
cf = ConstantFolder(gm, skip_constructors=True)
cf.run()
for node, constant in cf.node_replacements.items():
if constraint_fn is not None and not constraint_fn(node):
continue
replace_node_with_constant(gm, node, constant)
erased_params = []

View File

@ -98,6 +98,19 @@ def _port_metadata_for_input_quant_nodes(
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
if q_node is None or dq_node is None:
return
# add metadata for all the nodes between q_node and the get_attr node
# if the q_node can be traced back to a get_attr node
q_to_get_attr_nodes = [q_node]
q_node_input = q_node.args[0]
while isinstance(q_node_input, torch.fx.Node) and q_node_input.op not in [
"placeholder",
"get_attr",
]:
q_to_get_attr_nodes.append(q_node_input)
q_node_input = q_node_input.args[0]
if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
for n in q_to_get_attr_nodes:
_add_metadata(n, q_node_input)
_add_metadata(dq_node, node)

View File

@ -56,7 +56,7 @@ def _find_q_dq_node_for_user(
produer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
"""
Find d, dq pair corresponding to [producer ... -> q -> dq -> user]
Find q, dq pair corresponding to [producer -> q -> dq -> user]
The util works by finding the dq arg of user and ensuring it is connected to
producer
"""

View File

@ -1,4 +1,6 @@
import torch
from torch.fx import GraphModule
from torch.fx import Node
from .pt2e.prepare import prepare
from .pt2e.qat_utils import (
@ -24,6 +26,7 @@ from torch.ao.quantization.quantizer import ( # noqa: F401
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch._inductor.constant_folding import constant_fold
__all__ = [
"prepare_pt2e",
@ -66,10 +69,40 @@ def prepare_qat_pt2e(
model = _disallow_eval_train(model)
return model
_QUANT_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
def _quant_node_constraint(n: Node) -> bool:
"""If there is any pure ops between get_attr and quantize op they will be const propagated
e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
(Note: dequantize op is not going to be constant propagated)
This filter is added because we don't want to constant fold the things that are not
related to quantization
"""
return n.op == "call_function" and n.target in _QUANT_OPS
def convert_pt2e(
model: GraphModule,
use_reference_representation: bool = False,
fold_quantize: bool = False,
) -> GraphModule:
"""Convert a calibrated/trained model to a quantized model
Args:
model: calibrated/trained model
use_reference_representation: boolean flag to indicate whether to produce a reference representation or not
fold_quantize: boolean flag to indicate whether to fold the quantize op or not
Note: please set `fold_quantize` to True whenever you can. We'll deprecate this flag and
make True the default option in the future; to make sure the change doesn't break BC for you, it's
better to set the flag to True now.
Returns:
quantized model, either in q/dq representation or reference representation
"""
original_graph_meta = model.meta
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
@ -78,6 +111,10 @@ def convert_pt2e(
pm = PassManager([PortNodeMetaForQDQ()])
model = pm(model).graph_module
if fold_quantize:
constant_fold(model, _quant_node_constraint)
if use_reference_representation:
model = reference_representation_rewrite(model)
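
As a rough sketch of the expected effect on a single quantized linear (mirroring test_fold_quantize above; exact counts depend on the quantizer configuration):

    m = convert_pt2e(m, fold_quantize=True)
    # with the XNNPACKQuantizer per-tensor config only activation quantize ops remain:
    #   quantize_per_tensor.default   x 2  (input and output activation)
    #   dequantize_per_tensor.default x 3  (input, weight, output)
    # the weight itself lives in an int8 get_attr (a "frozen_param" attribute)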

View File

@ -201,7 +201,8 @@ def _get_module_name_filter(module_name: str):
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
# get_attr nodes don't have nn_module_stack?
nn_module_stack = n.meta.get("nn_module_stack", {})
names = [
n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
]
@ -228,7 +229,7 @@ def _get_module_type_filter(tp: Callable):
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta["nn_module_stack"]
nn_module_stack = n.meta.get("nn_module_stack", {})
types = [t for _, t in nn_module_stack.values()]
return tp in types