[Quant] [PT2] Add SiLU into X86InductorQuantizer Conv2d Unary Annotation (#122267)
**Summary**
Add `SiLU` into X86InductorQuantizer Conv2d Unary Annotation.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_conv2d_unary
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_unary
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122267
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: #122266
Committed by: PyTorch MergeBot
Parent: b7089937dc
Commit: e0329cba8a
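With this change, a `Conv2d -> SiLU` block is annotated as a conv-unary pattern by `X86InductorQuantizer`, for both the PTQ and QAT recipes. A minimal sketch of the PT2E flow that exercises it (the capture step assumes `capture_pre_autograd_graph`, the PT2E capture API of this era; newer releases use `torch.export.export_for_training(...).module()` instead):

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class ConvSiLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return self.act(self.conv(x))

example_inputs = (torch.randn(1, 3, 16, 16),)
model = ConvSiLU().eval()

# Capture step: assumed capture API for this era of the PT2E flow.
from torch._export import capture_pre_autograd_graph
exported = capture_pre_autograd_graph(model, example_inputs)

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)          # calibration pass
quantized = convert_pt2e(prepared)

# Compile with Inductor; production use typically also enables freezing so the
# annotated conv2d + silu pattern is lowered into the fused qconv kernels.
with torch.no_grad():
    optimized = torch.compile(quantized)
    optimized(*example_inputs)
```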
test_x86inductor_quantizer.py

@@ -50,12 +50,14 @@ class TestHelperModules:
             self.post_op = post_op
             self.bn = torch.nn.BatchNorm2d(6)
             self.with_bn = with_bn
+            self.maxpool = torch.nn.MaxPool2d((3, 3))
 
         def forward(self, x):
             x = self.conv(x)
             if self.with_bn:
                 x = self.bn(x)
             x = self.post_op(x)
+            x = self.maxpool(x)
             return x
 
     class Conv2dAddModule(torch.nn.Module):
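The hunk above also appends a `MaxPool2d` to the helper, so the unary tests now exercise a `Conv2d -> post op -> MaxPool2d` chain. Roughly what the helper builds for the new swish case, as a standalone sketch (the class name and conv hyperparameters are illustrative; the channel count matches the `BatchNorm2d(6)` above):

```python
import torch

class ConvSiLUMaxPool(torch.nn.Module):
    """Mirrors the updated helper: Conv2d -> (optional BatchNorm2d) -> SiLU -> MaxPool2d."""

    def __init__(self, with_bn: bool = False):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, kernel_size=2, padding=1)
        self.bn = torch.nn.BatchNorm2d(6)
        self.with_bn = with_bn
        self.post_op = torch.nn.SiLU(inplace=False)
        self.maxpool = torch.nn.MaxPool2d((3, 3))

    def forward(self, x):
        x = self.conv(x)
        if self.with_bn:
            x = self.bn(x)
        x = self.post_op(x)
        x = self.maxpool(x)
        return x

m = ConvSiLUMaxPool().eval()
print(m(torch.randn(1, 3, 16, 16)).shape)
```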
@@ -390,7 +392,9 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
             "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default],
             "hardswish": [torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default],
-            "hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default]
+            "hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default],
+            "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
+            "swish_inplace": [torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default],
         }
         use_bias_list = [True, False]
         with override_quantized_engine("x86"), torch.no_grad():
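Each entry in the map above pairs an eager-mode post op with the aten op the test expects to find in the captured graph. A quick sanity check of the two new pairings (`SiLU(x) = x * sigmoid(x)`, with `aten.silu_` as the in-place variant):

```python
import torch

x = torch.randn(4)
ref = x * torch.sigmoid(x)

# Out-of-place op, corresponding to nn.SiLU(inplace=False) / the "swish" entry.
out = torch.ops.aten.silu.default(x)
torch.testing.assert_close(out, ref)

# In-place op, corresponding to nn.SiLU(inplace=True) / the "swish_inplace" entry.
x_clone = x.clone()
torch.ops.aten.silu_.default(x_clone)
torch.testing.assert_close(x_clone, ref)
```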
@@ -402,8 +406,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             )
             node_occurrence = {
                 # one for input and weight of the conv
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                 # note: quantize op for weights are const propagated
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
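The per-tensor counts rise from 1 to 3 because of the `MaxPool2d` added to the test helper in the first hunk. A plausible accounting, assuming the x86 recipe also quantizes the maxpool input and output (import path for the op registration is an assumption):

```python
import torch
# Importing the PT2E entry points is assumed to register the
# quantized_decomposed ops referenced below.
from torch.ao.quantization.quantize_pt2e import prepare_pt2e  # noqa: F401

expected_counts = {
    # 3 pairs: conv activation input, post-op output feeding the maxpool,
    # and the maxpool output.
    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
    # Weight quantize is const-propagated away; only its dequantize remains.
    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
```

The same accounting applies to the QAT hunk further down, where the counts move from 2 to 3.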
@@ -1104,7 +1108,9 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
             "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default],
             "hardswish": [torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default],
-            "hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default]
+            "hardswish_inplace": [torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default],
+            "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
+            "swish_inplace": [torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default],
         }
 
         with override_quantized_engine("x86"):
@@ -1116,8 +1122,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             )
             node_occurrence = {
                 # one for input and weight of the conv, one for output for the relu
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                 # note: quantize op for weights are const propagated
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
x86_inductor_quantizer.py

@@ -575,6 +575,7 @@ class X86InductorQuantizer(Quantizer):
             [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
             [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish],
             [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU],
         ]
         for unary_pattern in unary_patterns:
             partitions = find_sequential_partitions(gm, unary_pattern)
@@ -754,6 +755,7 @@ class X86InductorQuantizer(Quantizer):
             [torch.nn.Conv2d, torch.nn.Hardtanh],
             [torch.nn.Conv2d, torch.nn.Hardswish],
             [torch.nn.Conv2d, torch.nn.ReLU6],
+            [torch.nn.Conv2d, torch.nn.SiLU],
         ]
         for unary_pattern in unary_patterns:
             partitions = find_sequential_partitions(gm, unary_pattern)
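Both recipes match these op sequences with `find_sequential_partitions`. A hedged sketch that the new `Conv2d -> SiLU` sequence is found in a captured graph (the import paths for the capture API and the partition helper are assumptions based on the code above):

```python
import torch
from torch._export import capture_pre_autograd_graph  # assumed capture API of this era
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

class ConvSiLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return self.act(self.conv(x))

example_inputs = (torch.randn(1, 3, 16, 16),)
gm = capture_pre_autograd_graph(ConvSiLU().eval(), example_inputs)

# The quantizer now includes [Conv2d, SiLU] in its unary_patterns; here we just
# confirm the pattern matcher can find that sequence in the captured graph.
partitions = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.SiLU])
print(len(partitions))  # expected: 1
```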