[Quant][PT2E] Enable linear-binary(-unary) post-op recipe for X86Inductor quantizer (#122387)

Enable the linear-binary(-unary) post-op recipe for the X86Inductor quantizer: `linear + add` and `linear + add + relu` patterns are now annotated for static quantization. Q-DQ pairs are inserted only for the linear's activation input and weight; the extra input of the add is left unquantized to avoid accuracy issues, and when a unary op follows, it is marked as the output of the quantized pattern.
**Test plan**
python test/test_quantization.py -k test_linear_binary
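For context, a minimal end-to-end sketch of the flow the new tests exercise. The toy `LinearAddReLU` module and the calibration step are illustrative assumptions, not part of this PR; the quantizer setup mirrors the tests below, and the export entry point is the `capture_pre_autograd_graph` API used at the time of this change (newer releases use `torch.export`).

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class LinearAddReLU(torch.nn.Module):
    """Toy module hitting the linear -> add -> relu pattern targeted here."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x) + x)


example_inputs = (torch.randn(2, 16),)
m = LinearAddReLU().eval()

with torch.no_grad():
    exported = capture_pre_autograd_graph(m, example_inputs)
    quantizer = X86InductorQuantizer().set_global(
        xiq.get_default_x86_inductor_quantization_config()
    )
    prepared = prepare_pt2e(exported, quantizer)  # observers inserted per the recipe
    prepared(*example_inputs)                     # calibrate
    converted = convert_pt2e(prepared)            # q-dq on the linear input/weight only
```

After `convert_pt2e`, the model can be lowered with `torch.compile` on the Inductor backend, which is where the fused quantized linear-add(-relu) kernels are expected to apply.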

Differential Revision: [D56288440](https://our.internmc.facebook.com/intern/diff/D56288440)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122387
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: #123240
Xia, Weiwen
2024-04-26 09:02:56 +08:00
committed by PyTorch MergeBot
parent dc4c75ba72
commit 35b332882b
2 changed files with 489 additions and 14 deletions


@@ -13,7 +13,10 @@ from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
QUANT_ANNOTATION_KEY,
X86InductorQuantizer,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@@ -24,7 +27,7 @@ from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import skipIfTorchDynamo
class Conv2DType(Enum):
class NodePosType(Enum):
left = 1
right = 2
both = 3
@@ -67,7 +70,7 @@ class TestHelperModules:
def __init__(
self,
inplace_add: bool = False,
conv2d_type: Conv2DType = Conv2DType.left,
conv2d_type: NodePosType = NodePosType.left,
use_bias: bool = False,
with_bn: bool = False,
) -> None:
@@ -95,7 +98,7 @@ class TestHelperModules:
self.with_bn = with_bn
def forward(self, x):
if self.conv2d_type == Conv2DType.left:
if self.conv2d_type == NodePosType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
@@ -107,14 +110,14 @@ class TestHelperModules:
if self.with_bn:
tmp = self.bn(tmp)
return tmp + self.relu(x)
elif self.conv2d_type == Conv2DType.right:
elif self.conv2d_type == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.conv(x)
return tmp
else:
return self.relu(x) + self.conv(x)
elif self.conv2d_type == Conv2DType.both:
elif self.conv2d_type == NodePosType.both:
if self.inplace_add:
tmp = self.conv(x)
tmp += self.conv2(x)
@@ -126,7 +129,7 @@ class TestHelperModules:
def __init__(
self,
inplace_add: bool = False,
conv2d_type: Conv2DType = Conv2DType.left,
conv2d_type: NodePosType = NodePosType.left,
inplace_relu: bool = False,
use_bias: bool = False,
with_bn: bool = False,
@@ -156,7 +159,7 @@ class TestHelperModules:
self.with_bn = with_bn
def forward(self, x):
if self.conv2d_type == Conv2DType.left:
if self.conv2d_type == NodePosType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
@@ -168,14 +171,14 @@ class TestHelperModules:
if self.with_bn:
tmp = self.bn(tmp)
return self.relu2(tmp + self.relu(x))
elif self.conv2d_type == Conv2DType.right:
elif self.conv2d_type == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.conv(x)
return self.relu2(tmp)
else:
return self.relu2(self.relu(x) + self.conv(x))
elif self.conv2d_type == Conv2DType.both:
elif self.conv2d_type == NodePosType.both:
if self.inplace_add:
tmp = self.conv(x)
tmp += self.conv2(x)
@@ -333,6 +336,131 @@ class TestHelperModules:
def forward(self, x):
return self.postop(self.linear(x))
class LinearAddModule(torch.nn.Module):
def __init__(
self,
inplace_add: bool = False,
linear_pos: NodePosType = NodePosType.left,
use_bias: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.relu = nn.ReLU()
self.inplace_add = inplace_add
self.linear_pos = linear_pos
def forward(self, x):
if self.linear_pos == NodePosType.left:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.relu(x)
return tmp
else:
tmp = self.linear(x)
return tmp + self.relu(x)
elif self.linear_pos == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.linear(x)
return tmp
else:
return self.relu(x) + self.linear(x)
elif self.linear_pos == NodePosType.both:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(x)
return tmp
else:
return self.linear(x) + self.linear2(x)
class LinearAddReLUModule(torch.nn.Module):
def __init__(
self,
inplace_add: bool = False,
linear_pos: NodePosType = NodePosType.left,
inplace_relu: bool = False,
use_bias: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.relu = nn.ReLU()
self.inplace_add = inplace_add
self.linear_pos = linear_pos
self.relu2 = nn.ReLU(inplace=inplace_relu)
def forward(self, x):
if self.linear_pos == NodePosType.left:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.relu(x)
return self.relu2(tmp)
else:
tmp = self.linear(x)
return self.relu2(tmp + self.relu(x))
elif self.linear_pos == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.linear(x)
return self.relu2(tmp)
else:
return self.relu2(self.relu(x) + self.linear(x))
elif self.linear_pos == NodePosType.both:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(x)
return self.relu2(tmp)
else:
return self.relu2(self.linear(x) + self.linear2(x))
class SerialsLinearAddReLUModule(torch.nn.Module):
"""Serials of 2 Linear -> Add -> ReLU Pattern."""
def __init__(
self,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.relu = nn.ReLU()
self.relu2 = nn.ReLU()
def forward(self, x):
x1 = self.linear(x)
res1 = self.relu(self.linear2(x1) + self.linear3(x1))
res2 = self.relu2(self.linear4(res1) + res1)
return res2
class LinearAddModule2(torch.nn.Module):
def __init__(
self,
inplace_add: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
self.inplace_add = inplace_add
def forward(self, x):
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(tmp)
return tmp
else:
tmp = self.linear(x)
return tmp + self.linear2(tmp)
class Conv2dAddModule2(torch.nn.Module):
def __init__(
self,
@@ -550,7 +678,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
"""
conv2d_type_list = [Conv2DType.left, Conv2DType.both]
conv2d_type_list = [NodePosType.left, NodePosType.both]
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
@@ -558,7 +686,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
with override_quantized_engine("x86"), torch.no_grad():
for conv2d_type in conv2d_type_list:
m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval()
if conv2d_type != Conv2DType.both:
if conv2d_type != NodePosType.both:
node_occurrence = {
# one for input and weight of the conv
# one for extra input node of add
@@ -641,7 +769,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
"""
conv2d_type_list = [Conv2DType.left, Conv2DType.both]
conv2d_type_list = [NodePosType.left, NodePosType.both]
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
@@ -651,7 +779,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
m = TestHelperModules.Conv2dAddReLUModule(
conv2d_type=conv2d_type,
).eval()
if conv2d_type != Conv2DType.both:
if conv2d_type != NodePosType.both:
node_occurrence = {
# one for input for conv
# one for extra input node of add
@@ -1157,6 +1285,271 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_list,
)
def _check_annotation_stat(self, gm, expected_stat_dict):
# Check expected annotation statistics to ensure the annotation is correct
def _check_annotation(node):
annot = node.meta.get(QUANT_ANNOTATION_KEY, None)
if annot is None:
return False, False
return annot._annotated, annot._is_output_of_quantized_pattern
for node in gm.graph.nodes:
if node.target in expected_stat_dict.keys():
annotated, is_quant_out = _check_annotation(node)
expected_stat_dict[node.target]["annotated"] -= annotated
expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out
for op_stat in expected_stat_dict.values():
assert all(v == 0 for v in op_stat.values())
@skipIfNoX86
def test_linear_binary(self):
"""
Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
"""
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_add_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
cases = itertools.product(linear_pos_list, inplace_add_list)
with override_quantized_engine("x86"), torch.no_grad():
for linear_pos, inplace_add in cases:
m = TestHelperModules.LinearAddModule(
inplace_add=inplace_add, linear_pos=linear_pos
).eval()
if linear_pos != NodePosType.both:
node_occurrence = {
# Only one q-dq pair for the input of the linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.add_.Tensor
if inplace_add
else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
expected_annotation_stat = {
aten.linear.default: {
"annotated": 2 if linear_pos == NodePosType.both else 1,
"is_quant_out": 1 if linear_pos == NodePosType.both else 0,
},
add_op: {"annotated": 1, "is_quant_out": 1},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary2(self):
"""
Test Pattern:
tmp = linear_1(x)
tmp2 = linear_2(tmp)
return tmp + tmp2
Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1.
"""
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_add_list = [False]
with override_quantized_engine("x86"), torch.no_grad():
for inplace_add in inplace_add_list:
m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
# Two q-dq nodes for inputs of linear nodes
# No q-dq for extra input node of add
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.aten.add_.Tensor
if inplace_add
else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
expected_annotation_stat = {
aten.linear.default: {
"annotated": 2,
"is_quant_out": 1,
},
add_op: {"annotated": 1, "is_quant_out": 1},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary(self):
"""
Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
"""
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_add_list = [False]
# TODO test for inplace relu after refactoring of capture_pre_autograd_graph
inplace_relu_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
with override_quantized_engine("x86"), torch.no_grad():
for linear_pos, inplace_add, inplace_relu in cases:
m = TestHelperModules.LinearAddReLUModule(
inplace_add=inplace_add,
linear_pos=linear_pos,
inplace_relu=inplace_relu,
).eval()
if linear_pos != NodePosType.both:
node_occurrence = {
# Only one q-dq node for input of the linear
# No q-dq node for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.add_.Tensor
if inplace_add
else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# linear, add, relu are fused
# The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
relu_op = aten.relu_.default if inplace_relu else aten.relu.default
expected_annotation_stat = {
aten.linear.default: {
"annotated": 2 if linear_pos == NodePosType.both else 1,
"is_quant_out": 1 if linear_pos == NodePosType.both else 0,
},
add_op: {"annotated": 1, "is_quant_out": 0},
relu_op: {"annotated": 1, "is_quant_out": 1},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary_serials(self):
"""
Test a pattern of two consecutive Linear -> Add -> ReLU blocks with X86InductorQuantizer.
"""
with override_quantized_engine("x86"), torch.no_grad():
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
"annotated": 4,
"is_quant_out": 2,
},
aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
aten.relu.default: {"annotated": 2, "is_quant_out": 2},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfTorchDynamo("very slow")
@skipIfNoX86
def test_qat_conv2d(self):


@@ -766,6 +766,7 @@ class X86InductorQuantizer(Quantizer):
if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default):
if config.input_activation and not config.input_activation.is_dynamic:
# <TODO> Weiwen: Dynamic Quant of linear unary will be supported in next step
self._annotate_linear_binary_unary(model, config)
self._annotate_linear_unary(model, config)
self._annotate_linear(model, config)
@@ -1141,6 +1142,87 @@ class X86InductorQuantizer(Quantizer):
_is_output_of_quantized_pattern=True,
)
def _annotate_linear_binary_unary(
self,
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
) -> None:
# linear + binary_op + (optional) unary op
binary_op_list = [operator.add]
unary_op_list = [torch.nn.ReLU, None]
combinations = itertools.product(binary_op_list, unary_op_list)
for binary_op, unary_op in combinations:
has_unary = unary_op is not None
seq_partition = [torch.nn.Linear, binary_op]
if has_unary:
seq_partition.append(unary_op)
fused_partitions = find_sequential_partitions(gm, seq_partition)
for fused_partition in fused_partitions:
unary_partition, unary_node = None, None
if has_unary:
(
linear_partition,
binary_partition,
unary_partition,
) = fused_partition
(
linear_node,
binary_node,
unary_node,
) = self._get_output_nodes_of_partitions(
[linear_partition, binary_partition, unary_partition]
)
else:
linear_partition, binary_partition = fused_partition
linear_node, binary_node = self._get_output_nodes_of_partitions(
[linear_partition, binary_partition]
)
if len(linear_node.users) != 1:
# Linear node should have only 1 user node
continue
(
linear_node_idx,
extra_input_node_idx,
) = self._get_input_idx_for_binary_node(linear_node, binary_node)
if (linear_node_idx is None) or (extra_input_node_idx is None):
continue
if linear_node != binary_node.args[linear_node_idx]:
raise ValueError(
f"{linear_node} doesn't match input of binary node"
)
assert isinstance(linear_node, Node)
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default
):
# No linear node found to be fused with add
continue
node_list = (
[binary_node, linear_node]
if unary_node is None
else [unary_node, binary_node, linear_node]
)
if _is_annotated(node_list):
continue
self._annotate_linear_node_helper(
linear_node, False, quantization_config
)
# We don't insert q-dq before the binary input node due to accuracy issues
binary_node.meta[
QUANT_ANNOTATION_KEY
] = _X86InductorQuantizationAnnotation(
input_qspec_map={},
_annotated=True,
_is_output_of_quantized_pattern=(not has_unary),
)
if unary_node is not None:
unary_node.meta[
QUANT_ANNOTATION_KEY
] = _X86InductorQuantizationAnnotation(
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass