[Quant] [PT2] Add SiLU into X86InductorQuantizer Conv2d Unary Annotation (#122267)

**Summary**
Add `SiLU` into X86InductorQuantizer Conv2d Unary Annotation

**TestPlan**
```
python -m pytest test_x86inductor_quantizer.py -k test_conv2d_unary
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_unary
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122267
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: #122266
This commit is contained in:
haozhe.zhu
2024-03-19 19:53:37 -07:00
committed by PyTorch MergeBot
parent b7089937dc
commit e0329cba8a
2 changed files with 14 additions and 6 deletions

View File

@ -50,12 +50,14 @@ class TestHelperModules:
self.post_op = post_op
self.bn = torch.nn.BatchNorm2d(6)
self.with_bn = with_bn
self.maxpool = torch.nn.MaxPool2d((3, 3))
def forward(self, x):
x = self.conv(x)
if self.with_bn:
x = self.bn(x)
x = self.post_op(x)
x = self.maxpool(x)
return x
class Conv2dAddModule(torch.nn.Module):
@ -390,7 +392,9 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
"relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
"relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default],
"hardswish": [torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default],
"hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default]
"hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default],
"swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
"swish_inplace": [torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default],
}
use_bias_list = [True, False]
with override_quantized_engine("x86"), torch.no_grad():
@ -402,8 +406,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
node_occurrence = {
# one for input and weight of the conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
@ -1104,7 +1108,9 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
"relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
"relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default],
"hardswish": [torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default],
"hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default]
"hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default],
"swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
"swish_inplace": [torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default],
}
with override_quantized_engine("x86"):
@ -1116,8 +1122,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
node_occurrence = {
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,

View File

@ -575,6 +575,7 @@ class X86InductorQuantizer(Quantizer):
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish],
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU],
]
for unary_pattern in unary_patterns:
partitions = find_sequential_partitions(gm, unary_pattern)
@ -754,6 +755,7 @@ class X86InductorQuantizer(Quantizer):
[torch.nn.Conv2d, torch.nn.Hardtanh],
[torch.nn.Conv2d, torch.nn.Hardswish],
[torch.nn.Conv2d, torch.nn.ReLU6],
[torch.nn.Conv2d, torch.nn.SiLU],
]
for unary_pattern in unary_patterns:
partitions = find_sequential_partitions(gm, unary_pattern)