Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[Quant][PT2E] Enable linear-binary(-unary) post-op recipe for X86Inductor quantizer (#122387)"
This reverts commit 82e0153487c2cd1abc92598963be5b57ab1948d4. Reverted https://github.com/pytorch/pytorch/pull/122387 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/122387#issuecomment-2048294643))
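For context, the reverted PR taught X86InductorQuantizer to annotate linear + add (+ relu) patterns, and the tests removed below drove that recipe through the PT2E flow. The sketch that follows is not part of the commit; it is a minimal, assumed reproduction of that flow using the APIs the test file imports (capture_pre_autograd_graph, convert_pt2e, X86InductorQuantizer, xiq.get_default_x86_inductor_quantization_config), plus prepare_pt2e from the same quantize_pt2e module. LinearAdd is a hypothetical stand-in for the removed TestHelperModules.LinearAddModule helper.

    # Hedged sketch (not from the commit): the PT2E quantize flow the reverted
    # linear-binary tests exercised. LinearAdd is a stand-in module, not the
    # removed TestHelperModules.LinearAddModule helper.
    import torch
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class LinearAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            # linear output + extra input: the "binary post-op" pattern the recipe targets
            return self.linear(x) + self.relu(x)

    example_inputs = (torch.randn(2, 16),)
    m = LinearAdd().eval()

    with torch.no_grad():
        # Export to an FX graph, annotate it with the X86 Inductor recipe,
        # run a calibration pass, then convert to insert q/dq nodes.
        exported = capture_pre_autograd_graph(m, example_inputs)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)
        converted = convert_pt2e(prepared)

With the recipe enabled, the removed tests asserted one quantize/dequantize pair on the linear input and none on the extra add input; this revert drops those assertions along with the helper modules.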
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
     X86InductorQuantizer,
-    QUANT_ANNOTATION_KEY,
 )
 from torch.ao.quantization.quantize_pt2e import (
     convert_pt2e,
@@ -25,7 +24,7 @@ import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.ao.quantization import ObserverBase
 from torch._export import capture_pre_autograd_graph
 
-class NodePosType(Enum):
+class Conv2DType(Enum):
     left = 1
     right = 2
     both = 3
@@ -64,7 +63,7 @@ class TestHelperModules:
     class Conv2dAddModule(torch.nn.Module):
         def __init__(self,
                      inplace_add: bool = False,
-                     conv2d_type: NodePosType = NodePosType.left,
+                     conv2d_type: Conv2DType = Conv2DType.left,
                      use_bias: bool = False,
                      with_bn: bool = False,
                      ) -> None:
@@ -82,7 +81,7 @@ class TestHelperModules:
             self.with_bn = with_bn
 
         def forward(self, x):
-            if self.conv2d_type == NodePosType.left:
+            if self.conv2d_type == Conv2DType.left:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     if self.with_bn:
@@ -94,14 +93,14 @@ class TestHelperModules:
                     if self.with_bn:
                         tmp = self.bn(tmp)
                     return tmp + self.relu(x)
-            elif self.conv2d_type == NodePosType.right:
+            elif self.conv2d_type == Conv2DType.right:
                 if self.inplace_add:
                     tmp = self.relu(x)
                     tmp += self.conv(x)
                     return tmp
                 else:
                     return self.relu(x) + self.conv(x)
-            elif self.conv2d_type == NodePosType.both:
+            elif self.conv2d_type == Conv2DType.both:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     tmp += self.conv2(x)
@@ -112,7 +111,7 @@ class TestHelperModules:
     class Conv2dAddReLUModule(torch.nn.Module):
         def __init__(self,
                      inplace_add: bool = False,
-                     conv2d_type: NodePosType = NodePosType.left,
+                     conv2d_type: Conv2DType = Conv2DType.left,
                      inplace_relu: bool = False,
                      use_bias: bool = False,
                      with_bn: bool = False,
@@ -132,7 +131,7 @@ class TestHelperModules:
             self.with_bn = with_bn
 
         def forward(self, x):
-            if self.conv2d_type == NodePosType.left:
+            if self.conv2d_type == Conv2DType.left:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     if self.with_bn:
@@ -144,14 +143,14 @@ class TestHelperModules:
                     if self.with_bn:
                         tmp = self.bn(tmp)
                     return self.relu2(tmp + self.relu(x))
-            elif self.conv2d_type == NodePosType.right:
+            elif self.conv2d_type == Conv2DType.right:
                 if self.inplace_add:
                     tmp = self.relu(x)
                     tmp += self.conv(x)
                     return self.relu2(tmp)
                 else:
                     return self.relu2(self.relu(x) + self.conv(x))
-            elif self.conv2d_type == NodePosType.both:
+            elif self.conv2d_type == Conv2DType.both:
                 if self.inplace_add:
                     tmp = self.conv(x)
                     tmp += self.conv2(x)
@@ -265,138 +264,6 @@ class TestHelperModules:
         def forward(self, x):
             return self.postop(self.linear(x))
 
-    class LinearAddModule(torch.nn.Module):
-        def __init__(self,
-                     inplace_add: bool = False,
-                     linear_pos: NodePosType = NodePosType.left,
-                     use_bias: bool = False,
-                     ) -> None:
-            super().__init__()
-            self.linear = torch.nn.Linear(
-                in_features=16, out_features=16, bias=use_bias
-            )
-            self.linear2 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=use_bias
-            )
-            self.relu = nn.ReLU()
-            self.inplace_add = inplace_add
-            self.linear_pos = linear_pos
-
-        def forward(self, x):
-            if self.linear_pos == NodePosType.left:
-                if self.inplace_add:
-                    tmp = self.linear(x)
-                    tmp += self.relu(x)
-                    return tmp
-                else:
-                    tmp = self.linear(x)
-                    return tmp + self.relu(x)
-            elif self.linear_pos == NodePosType.right:
-                if self.inplace_add:
-                    tmp = self.relu(x)
-                    tmp += self.linear(x)
-                    return tmp
-                else:
-                    return self.relu(x) + self.linear(x)
-            elif self.linear_pos == NodePosType.both:
-                if self.inplace_add:
-                    tmp = self.linear(x)
-                    tmp += self.linear2(x)
-                    return tmp
-                else:
-                    return self.linear(x) + self.linear2(x)
-
-    class LinearAddReLUModule(torch.nn.Module):
-        def __init__(self,
-                     inplace_add: bool = False,
-                     linear_pos: NodePosType = NodePosType.left,
-                     inplace_relu: bool = False,
-                     use_bias: bool = False,
-                     ) -> None:
-            super().__init__()
-            self.linear = torch.nn.Linear(
-                in_features=16, out_features=16, bias=use_bias
-            )
-            self.linear2 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=use_bias
-            )
-            self.relu = nn.ReLU()
-            self.inplace_add = inplace_add
-            self.linear_pos = linear_pos
-            self.relu2 = nn.ReLU(inplace=inplace_relu)
-
-        def forward(self, x):
-            if self.linear_pos == NodePosType.left:
-                if self.inplace_add:
-                    tmp = self.linear(x)
-                    tmp += self.relu(x)
-                    return self.relu2(tmp)
-                else:
-                    tmp = self.linear(x)
-                    return self.relu2(tmp + self.relu(x))
-            elif self.linear_pos == NodePosType.right:
-                if self.inplace_add:
-                    tmp = self.relu(x)
-                    tmp += self.linear(x)
-                    return self.relu2(tmp)
-                else:
-                    return self.relu2(self.relu(x) + self.linear(x))
-            elif self.linear_pos == NodePosType.both:
-                if self.inplace_add:
-                    tmp = self.linear(x)
-                    tmp += self.linear2(x)
-                    return self.relu2(tmp)
-                else:
-                    return self.relu2(self.linear(x) + self.linear2(x))
-
-    class SerialsLinearAddReLUModule(torch.nn.Module):
-        """ Serials of 2 Linear -> Add -> ReLU Pattern.
-        """
-        def __init__(self, ) -> None:
-            super().__init__()
-            self.linear = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.linear2 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.linear3 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.linear4 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.relu = nn.ReLU()
-            self.relu2 = nn.ReLU()
-
-        def forward(self, x):
-            x1 = self.linear(x)
-            res1 = self.relu(self.linear2(x1) + self.linear3(x1))
-            res2 = self.relu2(self.linear4(res1) + res1)
-            return res2
-
-    class LinearAddModule2(torch.nn.Module):
-        def __init__(self,
-                     inplace_add: bool = False,
-                     ) -> None:
-            super().__init__()
-            self.linear = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.linear2 = torch.nn.Linear(
-                in_features=16, out_features=16, bias=True
-            )
-            self.inplace_add = inplace_add
-
-        def forward(self, x):
-            if self.inplace_add:
-                tmp = self.linear(x)
-                tmp += self.linear2(tmp)
-                return tmp
-            else:
-                tmp = self.linear(x)
-                return tmp + self.linear2(tmp)
-
     class Conv2dAddModule2(torch.nn.Module):
         def __init__(self,
                      inplace_add: bool = False,
@@ -565,7 +432,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
         Currently, only add as binary post op is supported.
         """
-        conv2d_type_list = [NodePosType.left, NodePosType.both]
+        conv2d_type_list = [Conv2DType.left, Conv2DType.both]
         example_inputs = (torch.randn(2, 3, 6, 6),)
         quantizer = X86InductorQuantizer().set_global(
             xiq.get_default_x86_inductor_quantization_config()
@@ -573,7 +440,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         with override_quantized_engine("x86"), torch.no_grad():
             for conv2d_type in conv2d_type_list:
                 m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval()
-                if conv2d_type != NodePosType.both:
+                if conv2d_type != Conv2DType.both:
                     node_occurrence = {
                         # one for input and weight of the conv
                         # one for extra input node of add
@@ -655,7 +522,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
         Currently, only add as binary post op and relu as unary post op are supported.
         """
-        conv2d_type_list = [NodePosType.left, NodePosType.both]
+        conv2d_type_list = [Conv2DType.left, Conv2DType.both]
         example_inputs = (torch.randn(2, 3, 6, 6),)
         quantizer = X86InductorQuantizer().set_global(
             xiq.get_default_x86_inductor_quantization_config()
@@ -665,7 +532,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 m = TestHelperModules.Conv2dAddReLUModule(
                     conv2d_type=conv2d_type,
                 ).eval()
-                if conv2d_type != NodePosType.both:
+                if conv2d_type != Conv2DType.both:
                     node_occurrence = {
                         # one for input for conv
                         # one for extra input node of add
@@ -1188,276 +1055,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 node_list,
             )
 
-    def _check_annotation_stat(self, gm, expected_stat_dict):
-        # Check expected annotation statistics to ensure the annotation is correct
-
-        def _check_annotation(node):
-            annot = node.meta.get(QUANT_ANNOTATION_KEY, None)
-            if annot is None:
-                return False, False
-            return annot._annotated, annot._is_output_of_quantized_pattern
-
-        for node in gm.graph.nodes:
-            if node.target in expected_stat_dict.keys():
-                annotated, is_quant_out = _check_annotation(node)
-                expected_stat_dict[node.target]['annotated'] -= annotated
-                expected_stat_dict[node.target]['is_quant_out'] -= is_quant_out
-        for op_stat in expected_stat_dict.values():
-            assert all(v == 0 for v in op_stat.values())
-
-    @skipIfNoX86
-    def test_linear_binary(self):
-        """
-        Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
-        Currently, only add as binary post op is supported.
-        """
-        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
-        inplace_add_list = [False, True]
-        example_inputs = (torch.randn(2, 16),)
-        quantizer = X86InductorQuantizer().set_global(
-            xiq.get_default_x86_inductor_quantization_config()
-        )
-        cases = itertools.product(linear_pos_list, inplace_add_list)
-        with override_quantized_engine("x86"), torch.no_grad():
-            for linear_pos, inplace_add in cases:
-                m = TestHelperModules.LinearAddModule(inplace_add=inplace_add, linear_pos=linear_pos).eval()
-                if linear_pos != NodePosType.both:
-                    node_occurrence = {
-                        # Only one 1 q-dq for input of the linear
-                        # No q-dq for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
-                        # quantize_per_channel for weights are const propagated
-                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
-                    }
-                else:
-                    node_occurrence = {
-                        # One quantize_per_tensor for both linear nodes (shared)
-                        # Two dequantize_per_tensor for two linear nodes
-                        # No q-dq for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
-                        # quantize_per_channel for weights are const propagated
-                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
-                    }
-                node_list = [
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-                    torch.ops.aten.linear.default,
-                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
-                ]
-                fq_m = self._test_quantizer(
-                    m,
-                    example_inputs,
-                    quantizer,
-                    node_occurrence,
-                    node_list,
-                )[-1]
-                # One linear and add are fused. The other linear is quantized alone if present
-                aten = torch.ops.aten
-                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
-                expected_annotation_stat = {
-                    aten.linear.default: {
-                        'annotated': 2 if linear_pos == NodePosType.both else 1,
-                        'is_quant_out': 1 if linear_pos == NodePosType.both else 0,
-                    },
-                    add_op: {
-                        'annotated': 1,
-                        'is_quant_out': 1
-                    },
-                }
-                self._check_annotation_stat(fq_m, expected_annotation_stat)
-
-
-    @skipIfNoX86
-    def test_linear_binary2(self):
-        """
-        Test Pattern:
-            tmp = linear_1(x)
-            tmp2 = linear_2(tmp)
-            return tmp + tmp2
-        Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
-        """
-        example_inputs = (torch.randn(2, 16),)
-        quantizer = X86InductorQuantizer().set_global(
-            xiq.get_default_x86_inductor_quantization_config()
-        )
-        inplace_add_list = [True, False]
-        with override_quantized_engine("x86"), torch.no_grad():
-            for inplace_add in inplace_add_list:
-                m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
-                # Two q-dq nodes for inputs of linear nodes
-                # No q-dq for extra input node of add
-                node_occurrence = {
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
-                    # quantize_per_channel for weights are const propagated
-                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
-                }
-                node_list = [
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-                    torch.ops.aten.linear.default,
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
-                ]
-                fq_m = self._test_quantizer(
-                    m,
-                    example_inputs,
-                    quantizer,
-                    node_occurrence,
-                    node_list,
-                )[-1]
-                # One linear and add are fused. The other linear is quantized alone if present
-                aten = torch.ops.aten
-                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
-                expected_annotation_stat = {
-                    aten.linear.default: {
-                        'annotated': 2,
-                        'is_quant_out': 1,
-                    },
-                    add_op: {
-                        'annotated': 1,
-                        'is_quant_out': 1
-                    },
-                }
-                self._check_annotation_stat(fq_m, expected_annotation_stat)
-
-    @skipIfNoX86
-    def test_linear_binary_unary(self):
-        """
-        Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
-        Currently, only add as binary post op and relu as unary post op are supported.
-        """
-        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
-        inplace_add_list = [False, True]
-        inplace_relu_list = [False, True]
-        example_inputs = (torch.randn(2, 16),)
-        quantizer = X86InductorQuantizer().set_global(
-            xiq.get_default_x86_inductor_quantization_config()
-        )
-        cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
-        with override_quantized_engine("x86"), torch.no_grad():
-            for linear_pos, inplace_add, inplace_relu in cases:
-                m = TestHelperModules.LinearAddReLUModule(
-                    inplace_add=inplace_add,
-                    linear_pos=linear_pos,
-                    inplace_relu=inplace_relu,
-                ).eval()
-                if linear_pos != NodePosType.both:
-                    node_occurrence = {
-                        # Only one q-dq node for input of the linear
-                        # No q-dq node for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
-                        # note: quantize op for weights are const propagated
-                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
-                    }
-                else:
-                    node_occurrence = {
-                        # One quantize_per_tensor for both linear nodes (shared)
-                        # Two dequantize_per_tensor for two linear nodes
-                        # No q-dq for extra input node of add
-                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
-                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
-                        # note: quantize op for weights are const propagated
-                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
-                    }
-                node_list = [
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-                    torch.ops.aten.linear.default,
-                    torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
-                ]
-                fq_m = self._test_quantizer(
-                    m,
-                    example_inputs,
-                    quantizer,
-                    node_occurrence,
-                    node_list,
-                )[-1]
-                # linear, add, relu are fused
-                # The other linear is quantized alone if present
-                aten = torch.ops.aten
-                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
-                relu_op = aten.relu_.default if inplace_relu else aten.relu.default
-                expected_annotation_stat = {
-                    aten.linear.default: {
-                        'annotated': 2 if linear_pos == NodePosType.both else 1,
-                        'is_quant_out': 1 if linear_pos == NodePosType.both else 0,
-                    },
-                    add_op: {
-                        'annotated': 1,
-                        'is_quant_out': 0
-                    },
-                    relu_op: {
-                        'annotated': 1,
-                        'is_quant_out': 1
-                    },
-                }
-                self._check_annotation_stat(fq_m, expected_annotation_stat)
-
-    @skipIfNoX86
-    def test_linear_binary_unary_serials(self):
-        """
-        Test pattern of 2 following up linear add relu with X86InductorQuantizer.
-        """
-        with override_quantized_engine("x86"), torch.no_grad():
-            m = TestHelperModules.SerialsLinearAddReLUModule().eval()
-            example_inputs = (torch.randn(2, 16),)
-            quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
-            node_occurrence = {
-                # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
-                # dequantize_per_tensor: 1 for each linear
-                # No q-dq for extra input node of add
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-                # quantize_per_channel for weights are const propagated
-                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
-            }
-            node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-                torch.ops.aten.linear.default,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-                torch.ops.aten.linear.default,
-                torch.ops.aten.linear.default,
-                torch.ops.aten.add.Tensor,
-                torch.ops.aten.relu.default,
-            ]
-            fq_m = self._test_quantizer(
-                m,
-                example_inputs,
-                quantizer,
-                node_occurrence,
-                node_list,
-            )[-1]
-            # Two linear nodes are quantized alone
-            # The other two are fused with add and relu
-            aten = torch.ops.aten
-            expected_annotation_stat = {
-                aten.linear.default: {
-                    'annotated': 4,
-                    'is_quant_out': 2,
-                },
-                aten.add.Tensor: {
-                    'annotated': 2,
-                    'is_quant_out': 0
-                },
-                aten.relu.default: {
-                    'annotated': 2,
-                    'is_quant_out': 2
-                },
-            }
-            self._check_annotation_stat(fq_m, expected_annotation_stat)
-
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):