[Quant][PT2E] Enable linear-binary(-unary) post-op recipe for X86Inductor quantizer (#122387)
As the title.

**Test plan**
python test/test_quantization.py -k test_linear_binary

Differential Revision: [D56288440](https://our.internmc.facebook.com/intern/diff/D56288440)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122387
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: #123240
Commit 35b332882b (parent dc4c75ba72), committed by PyTorch MergeBot
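For orientation, below is a minimal sketch (not part of the diff) of how the new linear + add (+ relu) recipe is exercised end to end. It assumes the PT2E flow the updated tests rely on at the time of this PR, i.e. `capture_pre_autograd_graph` followed by `prepare_pt2e`, calibration, and `convert_pt2e`; the `LinearAddReLU` module, tensor shapes, and variable names are illustrative only.

```python
# Minimal sketch (not part of this diff) of quantizing a linear + add + relu
# pattern with X86InductorQuantizer. Module, names, and shapes are illustrative.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class LinearAddReLU(torch.nn.Module):
    """linear -> add -> relu: the binary + unary post-op pattern enabled here."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x) + x)


example_inputs = (torch.randn(2, 16),)
m = LinearAddReLU().eval()

with torch.no_grad():
    # Graph capture API used by the tests at the time of this PR.
    exported = capture_pre_autograd_graph(m, example_inputs)
    quantizer = X86InductorQuantizer().set_global(
        xiq.get_default_x86_inductor_quantization_config()
    )
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration run
    converted = convert_pt2e(prepared)
```

With this recipe, the quantizer annotates the whole linear -> add -> relu chain as one quantized pattern; the new tests in the diff below verify that via q/dq node counts and an annotation-statistics helper.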
@@ -13,7 +13,10 @@ from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     prepare_qat_pt2e,
 )
-from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
+    QUANT_ANNOTATION_KEY,
+    X86InductorQuantizer,
+)
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -24,7 +27,7 @@ from torch.testing._internal.common_quantized import override_quantized_engine
 from torch.testing._internal.common_utils import skipIfTorchDynamo


-class Conv2DType(Enum):
+class NodePosType(Enum):
     left = 1
     right = 2
     both = 3
@@ -67,7 +70,7 @@ class TestHelperModules:
         def __init__(
             self,
             inplace_add: bool = False,
-            conv2d_type: Conv2DType = Conv2DType.left,
+            conv2d_type: NodePosType = NodePosType.left,
             use_bias: bool = False,
             with_bn: bool = False,
         ) -> None:
@@ -95,7 +98,7 @@ class TestHelperModules:
             self.with_bn = with_bn

         def forward(self, x):
-            if self.conv2d_type == Conv2DType.left:
+            if self.conv2d_type == NodePosType.left:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     if self.with_bn:
@@ -107,14 +110,14 @@ class TestHelperModules:
                     if self.with_bn:
                         tmp = self.bn(tmp)
                     return tmp + self.relu(x)
-            elif self.conv2d_type == Conv2DType.right:
+            elif self.conv2d_type == NodePosType.right:
                 if self.inplace_add:
                     tmp = self.relu(x)
                     tmp += self.conv(x)
                     return tmp
                 else:
                     return self.relu(x) + self.conv(x)
-            elif self.conv2d_type == Conv2DType.both:
+            elif self.conv2d_type == NodePosType.both:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     tmp += self.conv2(x)
@@ -126,7 +129,7 @@ class TestHelperModules:
         def __init__(
             self,
             inplace_add: bool = False,
-            conv2d_type: Conv2DType = Conv2DType.left,
+            conv2d_type: NodePosType = NodePosType.left,
             inplace_relu: bool = False,
             use_bias: bool = False,
             with_bn: bool = False,
@@ -156,7 +159,7 @@ class TestHelperModules:
             self.with_bn = with_bn

         def forward(self, x):
-            if self.conv2d_type == Conv2DType.left:
+            if self.conv2d_type == NodePosType.left:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     if self.with_bn:
@@ -168,14 +171,14 @@ class TestHelperModules:
                     if self.with_bn:
                         tmp = self.bn(tmp)
                     return self.relu2(tmp + self.relu(x))
-            elif self.conv2d_type == Conv2DType.right:
+            elif self.conv2d_type == NodePosType.right:
                 if self.inplace_add:
                     tmp = self.relu(x)
                     tmp += self.conv(x)
                     return self.relu2(tmp)
                 else:
                     return self.relu2(self.relu(x) + self.conv(x))
-            elif self.conv2d_type == Conv2DType.both:
+            elif self.conv2d_type == NodePosType.both:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     tmp += self.conv2(x)
@@ -333,6 +336,131 @@ class TestHelperModules:
         def forward(self, x):
             return self.postop(self.linear(x))

+    class LinearAddModule(torch.nn.Module):
+        def __init__(
+            self,
+            inplace_add: bool = False,
+            linear_pos: NodePosType = NodePosType.left,
+            use_bias: bool = False,
+        ) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                in_features=16, out_features=16, bias=use_bias
+            )
+            self.linear2 = torch.nn.Linear(
+                in_features=16, out_features=16, bias=use_bias
+            )
+            self.relu = nn.ReLU()
+            self.inplace_add = inplace_add
+            self.linear_pos = linear_pos
+
+        def forward(self, x):
+            if self.linear_pos == NodePosType.left:
+                if self.inplace_add:
+                    tmp = self.linear(x)
+                    tmp += self.relu(x)
+                    return tmp
+                else:
+                    tmp = self.linear(x)
+                    return tmp + self.relu(x)
+            elif self.linear_pos == NodePosType.right:
+                if self.inplace_add:
+                    tmp = self.relu(x)
+                    tmp += self.linear(x)
+                    return tmp
+                else:
+                    return self.relu(x) + self.linear(x)
+            elif self.linear_pos == NodePosType.both:
+                if self.inplace_add:
+                    tmp = self.linear(x)
+                    tmp += self.linear2(x)
+                    return tmp
+                else:
+                    return self.linear(x) + self.linear2(x)
+
+    class LinearAddReLUModule(torch.nn.Module):
+        def __init__(
+            self,
+            inplace_add: bool = False,
+            linear_pos: NodePosType = NodePosType.left,
+            inplace_relu: bool = False,
+            use_bias: bool = False,
+        ) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                in_features=16, out_features=16, bias=use_bias
+            )
+            self.linear2 = torch.nn.Linear(
+                in_features=16, out_features=16, bias=use_bias
+            )
+            self.relu = nn.ReLU()
+            self.inplace_add = inplace_add
+            self.linear_pos = linear_pos
+            self.relu2 = nn.ReLU(inplace=inplace_relu)
+
+        def forward(self, x):
+            if self.linear_pos == NodePosType.left:
+                if self.inplace_add:
+                    tmp = self.linear(x)
+                    tmp += self.relu(x)
+                    return self.relu2(tmp)
+                else:
+                    tmp = self.linear(x)
+                    return self.relu2(tmp + self.relu(x))
+            elif self.linear_pos == NodePosType.right:
+                if self.inplace_add:
+                    tmp = self.relu(x)
+                    tmp += self.linear(x)
+                    return self.relu2(tmp)
+                else:
+                    return self.relu2(self.relu(x) + self.linear(x))
+            elif self.linear_pos == NodePosType.both:
+                if self.inplace_add:
+                    tmp = self.linear(x)
+                    tmp += self.linear2(x)
+                    return self.relu2(tmp)
+                else:
+                    return self.relu2(self.linear(x) + self.linear2(x))
+
+    class SerialsLinearAddReLUModule(torch.nn.Module):
+        """Serials of 2 Linear -> Add -> ReLU Pattern."""
+
+        def __init__(
+            self,
+        ) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.relu = nn.ReLU()
+            self.relu2 = nn.ReLU()
+
+        def forward(self, x):
+            x1 = self.linear(x)
+            res1 = self.relu(self.linear2(x1) + self.linear3(x1))
+            res2 = self.relu2(self.linear4(res1) + res1)
+            return res2
+
+    class LinearAddModule2(torch.nn.Module):
+        def __init__(
+            self,
+            inplace_add: bool = False,
+        ) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
+            self.inplace_add = inplace_add
+
+        def forward(self, x):
+            if self.inplace_add:
+                tmp = self.linear(x)
+                tmp += self.linear2(tmp)
+                return tmp
+            else:
+                tmp = self.linear(x)
+                return tmp + self.linear2(tmp)
+
     class Conv2dAddModule2(torch.nn.Module):
         def __init__(
             self,
@@ -550,7 +678,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
         Currently, only add as binary post op is supported.
         """
-        conv2d_type_list = [Conv2DType.left, Conv2DType.both]
+        conv2d_type_list = [NodePosType.left, NodePosType.both]
         example_inputs = (torch.randn(2, 3, 6, 6),)
         quantizer = X86InductorQuantizer().set_global(
             xiq.get_default_x86_inductor_quantization_config()
@@ -558,7 +686,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         with override_quantized_engine("x86"), torch.no_grad():
             for conv2d_type in conv2d_type_list:
                 m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval()
-                if conv2d_type != Conv2DType.both:
+                if conv2d_type != NodePosType.both:
                     node_occurrence = {
                         # one for input and weight of the conv
                         # one for extra input node of add
@@ -641,7 +769,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
         Currently, only add as binary post op and relu as unary post op are supported.
         """
-        conv2d_type_list = [Conv2DType.left, Conv2DType.both]
+        conv2d_type_list = [NodePosType.left, NodePosType.both]
         example_inputs = (torch.randn(2, 3, 6, 6),)
         quantizer = X86InductorQuantizer().set_global(
             xiq.get_default_x86_inductor_quantization_config()
@@ -651,7 +779,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 m = TestHelperModules.Conv2dAddReLUModule(
                     conv2d_type=conv2d_type,
                 ).eval()
-                if conv2d_type != Conv2DType.both:
+                if conv2d_type != NodePosType.both:
                     node_occurrence = {
                         # one for input for conv
                         # one for extra input node of add
@@ -1157,6 +1285,271 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 node_list,
             )

+    def _check_annotation_stat(self, gm, expected_stat_dict):
+        # Check expected annotation statistics to ensure the annotation is correct
+
+        def _check_annotation(node):
+            annot = node.meta.get(QUANT_ANNOTATION_KEY, None)
+            if annot is None:
+                return False, False
+            return annot._annotated, annot._is_output_of_quantized_pattern
+
+        for node in gm.graph.nodes:
+            if node.target in expected_stat_dict.keys():
+                annotated, is_quant_out = _check_annotation(node)
+                expected_stat_dict[node.target]["annotated"] -= annotated
+                expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out
+        for op_stat in expected_stat_dict.values():
+            assert all(v == 0 for v in op_stat.values())
+
+    @skipIfNoX86
+    def test_linear_binary(self):
+        """
+        Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
+        Currently, only add as binary post op is supported.
+        """
+        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
+        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
+        inplace_add_list = [False]
+        example_inputs = (torch.randn(2, 16),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        cases = itertools.product(linear_pos_list, inplace_add_list)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for linear_pos, inplace_add in cases:
+                m = TestHelperModules.LinearAddModule(
+                    inplace_add=inplace_add, linear_pos=linear_pos
+                ).eval()
+                if linear_pos != NodePosType.both:
+                    node_occurrence = {
+                        # Only one 1 q-dq for input of the linear
+                        # No q-dq for extra input node of add
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                        # quantize_per_channel for weights are const propagated
+                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                    }
+                else:
+                    node_occurrence = {
+                        # One quantize_per_tensor for both linear nodes (shared)
+                        # Two dequantize_per_tensor for two linear nodes
+                        # No q-dq for extra input node of add
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                        # quantize_per_channel for weights are const propagated
+                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                    }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.add_.Tensor
+                    if inplace_add
+                    else torch.ops.aten.add.Tensor,
+                ]
+                fq_m = self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )[-1]
+                # One linear and add are fused. The other linear is quantized alone if present
+                aten = torch.ops.aten
+                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
+                expected_annotation_stat = {
+                    aten.linear.default: {
+                        "annotated": 2 if linear_pos == NodePosType.both else 1,
+                        "is_quant_out": 1 if linear_pos == NodePosType.both else 0,
+                    },
+                    add_op: {"annotated": 1, "is_quant_out": 1},
+                }
+                self._check_annotation_stat(fq_m, expected_annotation_stat)
+
+    @skipIfNoX86
+    def test_linear_binary2(self):
+        """
+        Test Pattern:
+            tmp = linear_1(x)
+            tmp2 = linear_2(tmp)
+            return tmp + tmp2
+        Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
+        """
+        example_inputs = (torch.randn(2, 16),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
+        inplace_add_list = [False]
+        with override_quantized_engine("x86"), torch.no_grad():
+            for inplace_add in inplace_add_list:
+                m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
+                # Two q-dq nodes for inputs of linear nodes
+                # No q-dq for extra input node of add
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.aten.add_.Tensor
+                    if inplace_add
+                    else torch.ops.aten.add.Tensor,
+                ]
+                fq_m = self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )[-1]
+                # One linear and add are fused. The other linear is quantized alone if present
+                aten = torch.ops.aten
+                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
+                expected_annotation_stat = {
+                    aten.linear.default: {
+                        "annotated": 2,
+                        "is_quant_out": 1,
+                    },
+                    add_op: {"annotated": 1, "is_quant_out": 1},
+                }
+                self._check_annotation_stat(fq_m, expected_annotation_stat)
+
+    @skipIfNoX86
+    def test_linear_binary_unary(self):
+        """
+        Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
+        Currently, only add as binary post op and relu as unary post op are supported.
+        """
+        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
+        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
+        inplace_add_list = [False]
+        # TODO test for inplace relu after refactoring of capture_pre_autograd_graph
+        inplace_relu_list = [False]
+        example_inputs = (torch.randn(2, 16),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for linear_pos, inplace_add, inplace_relu in cases:
+                m = TestHelperModules.LinearAddReLUModule(
+                    inplace_add=inplace_add,
+                    linear_pos=linear_pos,
+                    inplace_relu=inplace_relu,
+                ).eval()
+                if linear_pos != NodePosType.both:
+                    node_occurrence = {
+                        # Only one q-dq node for input of the linear
+                        # No q-dq node for extra input node of add
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                        # note: quantize op for weights are const propagated
+                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                    }
+                else:
+                    node_occurrence = {
+                        # One quantize_per_tensor for both linear nodes (shared)
+                        # Two dequantize_per_tensor for two linear nodes
+                        # No q-dq for extra input node of add
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                        # note: quantize op for weights are const propagated
+                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                    }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.add_.Tensor
+                    if inplace_add
+                    else torch.ops.aten.add.Tensor,
+                ]
+                fq_m = self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )[-1]
+                # linear, add, relu are fused
+                # The other linear is quantized alone if present
+                aten = torch.ops.aten
+                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
+                relu_op = aten.relu_.default if inplace_relu else aten.relu.default
+                expected_annotation_stat = {
+                    aten.linear.default: {
+                        "annotated": 2 if linear_pos == NodePosType.both else 1,
+                        "is_quant_out": 1 if linear_pos == NodePosType.both else 0,
+                    },
+                    add_op: {"annotated": 1, "is_quant_out": 0},
+                    relu_op: {"annotated": 1, "is_quant_out": 1},
+                }
+                self._check_annotation_stat(fq_m, expected_annotation_stat)
+
+    @skipIfNoX86
+    def test_linear_binary_unary_serials(self):
+        """
+        Test pattern of 2 following up linear add relu with X86InductorQuantizer.
+        """
+        with override_quantized_engine("x86"), torch.no_grad():
+            m = TestHelperModules.SerialsLinearAddReLUModule().eval()
+            example_inputs = (torch.randn(2, 16),)
+            quantizer = X86InductorQuantizer().set_global(
+                xiq.get_default_x86_inductor_quantization_config()
+            )
+            node_occurrence = {
+                # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
+                # dequantize_per_tensor: 1 for each linear
+                # No q-dq for extra input node of add
+                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+                # quantize_per_channel for weights are const propagated
+                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
+            }
+            node_list = [
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.aten.linear.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.aten.linear.default,
+                torch.ops.aten.linear.default,
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.relu.default,
+            ]
+            fq_m = self._test_quantizer(
+                m,
+                example_inputs,
+                quantizer,
+                node_occurrence,
+                node_list,
+            )[-1]
+            # Two linear nodes are quantized alone
+            # The other two are fused with add and relu
+            aten = torch.ops.aten
+            expected_annotation_stat = {
+                aten.linear.default: {
+                    "annotated": 4,
+                    "is_quant_out": 2,
+                },
+                aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
+                aten.relu.default: {"annotated": 2, "is_quant_out": 2},
+            }
+            self._check_annotation_stat(fq_m, expected_annotation_stat)
+
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):
@@ -766,6 +766,7 @@ class X86InductorQuantizer(Quantizer):
         if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default):
             if config.input_activation and not config.input_activation.is_dynamic:
                 # <TODO> Weiwen: Dynamic Quant of linear unary will be supported in next step
+                self._annotate_linear_binary_unary(model, config)
                 self._annotate_linear_unary(model, config)
             self._annotate_linear(model, config)
@@ -1141,6 +1142,87 @@ class X86InductorQuantizer(Quantizer):
                 _is_output_of_quantized_pattern=True,
             )

+    def _annotate_linear_binary_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: QuantizationConfig,
+    ) -> None:
+        # linear + binary_op + (optional) unary op
+        binary_op_list = [operator.add]
+        unary_op_list = [torch.nn.ReLU, None]
+        combinations = itertools.product(binary_op_list, unary_op_list)
+        for binary_op, unary_op in combinations:
+            has_unary = unary_op is not None
+            seq_partition = [torch.nn.Linear, binary_op]
+            if has_unary:
+                seq_partition.append(unary_op)
+            fused_partitions = find_sequential_partitions(gm, seq_partition)
+            for fused_partition in fused_partitions:
+                unary_partition, unary_node = None, None
+                if has_unary:
+                    (
+                        linear_partition,
+                        binary_partition,
+                        unary_partition,
+                    ) = fused_partition
+                    (
+                        linear_node,
+                        binary_node,
+                        unary_node,
+                    ) = self._get_output_nodes_of_partitions(
+                        [linear_partition, binary_partition, unary_partition]
+                    )
+                else:
+                    linear_partition, binary_partition = fused_partition
+                    linear_node, binary_node = self._get_output_nodes_of_partitions(
+                        [linear_partition, binary_partition]
+                    )
+                if len(linear_node.users) != 1:
+                    # Linear Node should only has 1 user node
+                    continue
+                (
+                    linear_node_idx,
+                    extra_input_node_idx,
+                ) = self._get_input_idx_for_binary_node(linear_node, binary_node)
+                if (linear_node_idx is None) or (extra_input_node_idx is None):
+                    continue
+                if linear_node != binary_node.args[linear_node_idx]:
+                    raise ValueError(
+                        f"{linear_node} doesn't match input of binary node"
+                    )
+                assert isinstance(linear_node, Node)
+                if (
+                    linear_node.op != "call_function"
+                    or linear_node.target != torch.ops.aten.linear.default
+                ):
+                    # No linear node found to be fused with add
+                    continue
+                node_list = (
+                    [binary_node, linear_node]
+                    if unary_node is None
+                    else [unary_node, binary_node, linear_node]
+                )
+                if _is_annotated(node_list):
+                    continue
+                self._annotate_linear_node_helper(
+                    linear_node, False, quantization_config
+                )
+                # We don't insert q-dq before the binary input node due to accuracy issues
+                binary_node.meta[
+                    QUANT_ANNOTATION_KEY
+                ] = _X86InductorQuantizationAnnotation(
+                    input_qspec_map={},
+                    _annotated=True,
+                    _is_output_of_quantized_pattern=(not has_unary),
+                )
+                if unary_node is not None:
+                    unary_node.meta[
+                        QUANT_ANNOTATION_KEY
+                    ] = _X86InductorQuantizationAnnotation(
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass