Revert "[Quant][PT2E] Enable linear-binary(-unary) post-op recipe for X86Inductor quantizer (#122387)"

This reverts commit 82e0153487c2cd1abc92598963be5b57ab1948d4.

Reverted https://github.com/pytorch/pytorch/pull/122387 on behalf of https://github.com/DanilBaibak because it broke an internal build ([comment](https://github.com/pytorch/pytorch/pull/122387#issuecomment-2048294643))
PyTorch MergeBot
2024-04-10 19:34:26 +00:00
parent 30c4efe6d2
commit 8d9af8b91c
2 changed files with 15 additions and 505 deletions


@@ -4,7 +4,6 @@ import torch
import torch.nn as nn
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
X86InductorQuantizer,
QUANT_ANNOTATION_KEY,
)
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
@@ -25,7 +24,7 @@ import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization import ObserverBase
from torch._export import capture_pre_autograd_graph
class NodePosType(Enum):
class Conv2DType(Enum):
left = 1
right = 2
both = 3
@@ -64,7 +63,7 @@ class TestHelperModules:
class Conv2dAddModule(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
conv2d_type: NodePosType = NodePosType.left,
conv2d_type: Conv2DType = Conv2DType.left,
use_bias: bool = False,
with_bn: bool = False,
) -> None:
@@ -82,7 +81,7 @@ class TestHelperModules:
self.with_bn = with_bn
def forward(self, x):
if self.conv2d_type == NodePosType.left:
if self.conv2d_type == Conv2DType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
@@ -94,14 +93,14 @@ class TestHelperModules:
if self.with_bn:
tmp = self.bn(tmp)
return tmp + self.relu(x)
elif self.conv2d_type == NodePosType.right:
elif self.conv2d_type == Conv2DType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.conv(x)
return tmp
else:
return self.relu(x) + self.conv(x)
elif self.conv2d_type == NodePosType.both:
elif self.conv2d_type == Conv2DType.both:
if self.inplace_add:
tmp = self.conv(x)
tmp += self.conv2(x)
@@ -112,7 +111,7 @@ class TestHelperModules:
class Conv2dAddReLUModule(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
conv2d_type: NodePosType = NodePosType.left,
conv2d_type: Conv2DType = Conv2DType.left,
inplace_relu: bool = False,
use_bias: bool = False,
with_bn: bool = False,
@@ -132,7 +131,7 @@ class TestHelperModules:
self.with_bn = with_bn
def forward(self, x):
if self.conv2d_type == NodePosType.left:
if self.conv2d_type == Conv2DType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
@@ -144,14 +143,14 @@ class TestHelperModules:
if self.with_bn:
tmp = self.bn(tmp)
return self.relu2(tmp + self.relu(x))
elif self.conv2d_type == NodePosType.right:
elif self.conv2d_type == Conv2DType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.conv(x)
return self.relu2(tmp)
else:
return self.relu2(self.relu(x) + self.conv(x))
elif self.conv2d_type == NodePosType.both:
elif self.conv2d_type == Conv2DType.both:
if self.inplace_add:
tmp = self.conv(x)
tmp += self.conv2(x)
@@ -265,138 +264,6 @@ class TestHelperModules:
def forward(self, x):
return self.postop(self.linear(x))
class LinearAddModule(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
linear_pos: NodePosType = NodePosType.left,
use_bias: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.relu = nn.ReLU()
self.inplace_add = inplace_add
self.linear_pos = linear_pos
def forward(self, x):
if self.linear_pos == NodePosType.left:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.relu(x)
return tmp
else:
tmp = self.linear(x)
return tmp + self.relu(x)
elif self.linear_pos == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.linear(x)
return tmp
else:
return self.relu(x) + self.linear(x)
elif self.linear_pos == NodePosType.both:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(x)
return tmp
else:
return self.linear(x) + self.linear2(x)
class LinearAddReLUModule(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
linear_pos: NodePosType = NodePosType.left,
inplace_relu: bool = False,
use_bias: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=use_bias
)
self.relu = nn.ReLU()
self.inplace_add = inplace_add
self.linear_pos = linear_pos
self.relu2 = nn.ReLU(inplace=inplace_relu)
def forward(self, x):
if self.linear_pos == NodePosType.left:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.relu(x)
return self.relu2(tmp)
else:
tmp = self.linear(x)
return self.relu2(tmp + self.relu(x))
elif self.linear_pos == NodePosType.right:
if self.inplace_add:
tmp = self.relu(x)
tmp += self.linear(x)
return self.relu2(tmp)
else:
return self.relu2(self.relu(x) + self.linear(x))
elif self.linear_pos == NodePosType.both:
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(x)
return self.relu2(tmp)
else:
return self.relu2(self.linear(x) + self.linear2(x))
class SerialsLinearAddReLUModule(torch.nn.Module):
""" Serials of 2 Linear -> Add -> ReLU Pattern.
"""
def __init__(self, ) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.linear3 = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.linear4 = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.relu = nn.ReLU()
self.relu2 = nn.ReLU()
def forward(self, x):
x1 = self.linear(x)
res1 = self.relu(self.linear2(x1) + self.linear3(x1))
res2 = self.relu2(self.linear4(res1) + res1)
return res2
class LinearAddModule2(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.linear2 = torch.nn.Linear(
in_features=16, out_features=16, bias=True
)
self.inplace_add = inplace_add
def forward(self, x):
if self.inplace_add:
tmp = self.linear(x)
tmp += self.linear2(tmp)
return tmp
else:
tmp = self.linear(x)
return tmp + self.linear2(tmp)
class Conv2dAddModule2(torch.nn.Module):
def __init__(self,
inplace_add: bool = False,
@@ -565,7 +432,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
"""
conv2d_type_list = [NodePosType.left, NodePosType.both]
conv2d_type_list = [Conv2DType.left, Conv2DType.both]
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
@@ -573,7 +440,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
with override_quantized_engine("x86"), torch.no_grad():
for conv2d_type in conv2d_type_list:
m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval()
if conv2d_type != NodePosType.both:
if conv2d_type != Conv2DType.both:
node_occurrence = {
# one for input and weight of the conv
# one for extra input node of add
@@ -655,7 +522,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
"""
conv2d_type_list = [NodePosType.left, NodePosType.both]
conv2d_type_list = [Conv2DType.left, Conv2DType.both]
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
@@ -665,7 +532,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
m = TestHelperModules.Conv2dAddReLUModule(
conv2d_type=conv2d_type,
).eval()
if conv2d_type != NodePosType.both:
if conv2d_type != Conv2DType.both:
node_occurrence = {
# one for input for conv
# one for extra input node of add
@@ -1188,276 +1055,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_list,
)
def _check_annotation_stat(self, gm, expected_stat_dict):
# Check expected annotation statistics to ensure the annotation is correct
def _check_annotation(node):
annot = node.meta.get(QUANT_ANNOTATION_KEY, None)
if annot is None:
return False, False
return annot._annotated, annot._is_output_of_quantized_pattern
for node in gm.graph.nodes:
if node.target in expected_stat_dict.keys():
annotated, is_quant_out = _check_annotation(node)
expected_stat_dict[node.target]['annotated'] -= annotated
expected_stat_dict[node.target]['is_quant_out'] -= is_quant_out
for op_stat in expected_stat_dict.values():
assert all(v == 0 for v in op_stat.values())
@skipIfNoX86
def test_linear_binary(self):
"""
Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
"""
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
inplace_add_list = [False, True]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
cases = itertools.product(linear_pos_list, inplace_add_list)
with override_quantized_engine("x86"), torch.no_grad():
for linear_pos, inplace_add in cases:
m = TestHelperModules.LinearAddModule(inplace_add=inplace_add, linear_pos=linear_pos).eval()
if linear_pos != NodePosType.both:
node_occurrence = {
# Only one 1 q-dq for input of the linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
expected_annotation_stat = {
aten.linear.default: {
'annotated': 2 if linear_pos == NodePosType.both else 1,
'is_quant_out': 1 if linear_pos == NodePosType.both else 0,
},
add_op: {
'annotated': 1,
'is_quant_out': 1
},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary2(self):
"""
Test Pattern:
tmp = linear_1(x)
tmp2 = linear_2(tmp)
return tmp + tmp2
Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
"""
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
inplace_add_list = [True, False]
with override_quantized_engine("x86"), torch.no_grad():
for inplace_add in inplace_add_list:
m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
# Two q-dq nodes for inputs of linear nodes
# No q-dq for extra input node of add
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
expected_annotation_stat = {
aten.linear.default: {
'annotated': 2,
'is_quant_out': 1,
},
add_op: {
'annotated': 1,
'is_quant_out': 1
},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary(self):
"""
Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
"""
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
inplace_add_list = [False, True]
inplace_relu_list = [False, True]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
with override_quantized_engine("x86"), torch.no_grad():
for linear_pos, inplace_add, inplace_relu in cases:
m = TestHelperModules.LinearAddReLUModule(
inplace_add=inplace_add,
linear_pos=linear_pos,
inplace_relu=inplace_relu,
).eval()
if linear_pos != NodePosType.both:
node_occurrence = {
# Only one q-dq node for input of the linear
# No q-dq node for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# linear, add, relu are fused
# The other linear is quantized alone if present
aten = torch.ops.aten
add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
relu_op = aten.relu_.default if inplace_relu else aten.relu.default
expected_annotation_stat = {
aten.linear.default: {
'annotated': 2 if linear_pos == NodePosType.both else 1,
'is_quant_out': 1 if linear_pos == NodePosType.both else 0,
},
add_op: {
'annotated': 1,
'is_quant_out': 0
},
relu_op: {
'annotated': 1,
'is_quant_out': 1
},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary_serials(self):
"""
Test pattern of 2 following up linear add relu with X86InductorQuantizer.
"""
with override_quantized_engine("x86"), torch.no_grad():
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
'annotated': 4,
'is_quant_out': 2,
},
aten.add.Tensor: {
'annotated': 2,
'is_quant_out': 0
},
aten.relu.default: {
'annotated': 2,
'is_quant_out': 2
},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfTorchDynamo("very slow")
@skipIfNoX86
def test_qat_conv2d(self):