Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Quant][PT2E] enable qlinear post op fusion for dynamic quant & qat (#122667)
**Description**

Add a fusion path for dynamic quantization and for QAT. For static quantization with QAT, the following pattern can be matched: `qx -> qlinear -> add -> optional relu -> optional type convert -> optional quant`. For dynamic quantization, the following pattern can be matched: `qx -> qlinear -> add -> optional relu`.

**Test plan**

python test/inductor/test_mkldnn_pattern_matcher.py -k test_qlinear
python test/inductor/test_cpu_cpp_wrapper.py -k test_qlinear
python test/test_quantization.py -k test_linear_unary
python test/test_quantization.py -k test_linear_binary

Differential Revision: [D57655830](https://our.internmc.facebook.com/intern/diff/D57655830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122667
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
Committed by: PyTorch MergeBot
Parent: a8985a97f9
Commit: 36e2608783
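The description above corresponds to the end-to-end PT2E flow that the updated tests exercise. Below is a minimal, hedged sketch (not taken from this PR) of quantizing a linear + add + relu module with `X86InductorQuantizer`, using the `is_qat`/`is_dynamic` arguments of `get_default_x86_inductor_quantization_config`, then lowering through Inductor. The module and variable names are illustrative, and the import paths follow the PyTorch version this PR targets; they may differ in later releases.

```python
# Illustrative sketch only; assumes the PyTorch version targeted by this PR.
import torch
import torch.nn as nn
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class LinearAddReLU(nn.Module):
    """Produces the qx -> qlinear -> add -> relu pattern targeted by the fusion."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.relu = nn.ReLU()

    def forward(self, x, other):
        return self.relu(self.linear(x) + other)


m = LinearAddReLU().eval()
example_inputs = (torch.randn(2, 16), torch.randn(2, 16))

# Dynamic quantization; is_qat=True would select the QAT variant of the config instead.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=False, is_dynamic=True)
)

exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration run (a formality for dynamic quantization)
converted = convert_pt2e(prepared)

# Lowering through Inductor is where the qlinear + add (+ relu) fusion added by this PR applies.
with torch.no_grad():
    compiled = torch.compile(converted)
    compiled(*example_inputs)
```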
@@ -1778,10 +1778,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
self.assertEqual(
counters["inductor"]["qlinear_binary_matcher_nodes"],
5 + 2 * use_relu + to_bf16_after_binary,
(4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary,
)

for is_qat in [False, True]:
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(is_qat_list, is_dynamic_list)
for is_qat, is_dynamic in cases:
self._test_common(
mod,
(v,),
@@ -1789,6 +1792,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
is_qat=is_qat,
is_dynamic=is_dynamic,
)
if torch._inductor.config.cpp_wrapper:
# For CPP wrapper

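The hunk above adjusts the expected matcher node count so the dynamic-quant case expects 4 nodes instead of 5 and sweeps both `is_qat` and `is_dynamic`. As a rough illustration of what that assertion observes, the hedged snippet below reuses `converted` and `example_inputs` from the sketch near the top of this page and reads the same Inductor pattern-matcher counter; only the counter name is taken from the test, the rest is illustrative.

```python
# Illustrative only; reuses `converted` and `example_inputs` from the earlier sketch.
import torch
from torch._dynamo.utils import counters  # counter store asserted on by the Inductor tests

counters.clear()
with torch.no_grad():
    torch.compile(converted)(*example_inputs)

# Number of graph nodes absorbed by the qlinear + add (+ relu) fusion.
print(counters["inductor"]["qlinear_binary_matcher_nodes"])
```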
@@ -1203,44 +1203,64 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_list,
)

@skipIfNoX86
def test_linear_unary(self):
def _test_linear_unary_helper(
self,
post_op_module,
post_op_aten,
post_op_aten_inplace,
post_op_algo_list=None,
is_qat=False,
is_dynamic=False,
):
"""
Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer.
"""
use_bias_list = [True, False]
inplace_list = [True, False]
postop_list = [nn.ReLU, nn.LeakyReLU] # only test two to save time
cases = itertools.product(use_bias_list, inplace_list, postop_list)
post_op_map = {
nn.ReLU: [torch.ops.aten.relu_.default, torch.ops.aten.relu.default],
nn.LeakyReLU: [
torch.ops.aten.leaky_relu_.default,
torch.ops.aten.leaky_relu.default,
],
}
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_list = [False]
if post_op_algo_list is None:
post_op_algo_list = [None]
cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list)
with override_quantized_engine("x86"), torch.no_grad():
for use_bias, inplace, postop in cases:
for use_bias, inplace, post_op_algo in cases:
if inplace and post_op_aten_inplace is None:
continue
m = TestHelperModules.LinearUnaryModule(
use_bias=use_bias, postop=postop, inplace_postop=inplace
use_bias=use_bias,
postop=post_op_module,
inplace_postop=inplace,
post_op_algo=post_op_algo,
).eval()
example_inputs = (torch.randn(2, 4),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
node_occurrence = {
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# one for input of the linear
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
post_op_map[postop][0 if inplace else 1],
post_op_aten_inplace if inplace else post_op_aten,
]
self._test_quantizer(
m,
@@ -1248,47 +1268,70 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
quantizer,
node_occurrence,
node_list,
is_qat=is_qat,
)

@skipIfNoX86
def test_linear_unary_gelu(self):
"""
Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
"""
use_bias_list = [True, False]
postop = nn.GELU
post_op_algorithm = ["none", "tanh"]
cases = itertools.product(use_bias_list, post_op_algorithm)
with override_quantized_engine("x86"), torch.no_grad():
for use_bias, post_op_algo in cases:
m = TestHelperModules.LinearUnaryModule(
use_bias=use_bias, postop=postop, post_op_algo=post_op_algo
).eval()
example_inputs = (torch.randn(2, 4),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
node_occurrence = {
# one for input and weight of the conv, one for output for the gelu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.gelu.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
def test_linear_unary(self):
aten = torch.ops.aten
self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default)
self._test_linear_unary_helper(
nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"]
)

@skipIfNoX86
def test_linear_unary_qat(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True
)
self._test_linear_unary_helper(
nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True
)

@skipIfNoX86
def test_linear_unary_dynamic(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True
)
self._test_linear_unary_helper(
nn.LeakyReLU,
aten.leaky_relu.default,
aten.leaky_relu_.default,
is_dynamic=True,
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True
)

@skipIfNoX86
def test_linear_unary_dynamic_qat(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True
)
self._test_linear_unary_helper(
nn.LeakyReLU,
aten.leaky_relu.default,
aten.leaky_relu_.default,
is_qat=True,
is_dynamic=True,
)
self._test_linear_unary_helper(
nn.GELU,
aten.gelu.default,
None,
["none", "tanh"],
is_qat=True,
is_dynamic=True,
)

def _check_annotation_stat(self, gm, expected_stat_dict):
# Check expected annotation statistics to ensure the annotation is correct
@@ -1307,8 +1350,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
for op_stat in expected_stat_dict.values():
assert all(v == 0 for v in op_stat.values())

@skipIfNoX86
def test_linear_binary(self):
def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
"""
Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
@@ -1318,7 +1360,20 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
inplace_add_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
cases = itertools.product(linear_pos_list, inplace_add_list)
with override_quantized_engine("x86"), torch.no_grad():
@@ -1330,26 +1385,28 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
# Only one 1 q-dq for input of the linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 1 if is_dynamic else 2
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: num_dequant,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
(
torch.ops.aten.add_.Tensor
@@ -1363,6 +1420,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
quantizer,
node_occurrence,
node_list,
is_qat=is_qat,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
@@ -1376,6 +1434,22 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
}
self._check_annotation_stat(fq_m, expected_annotation_stat)

@skipIfNoX86
def test_linear_binary(self):
self._test_linear_binary_helper()

@skipIfNoX86
def test_linear_binary_qat(self):
self._test_linear_binary_helper(is_qat=True)

@skipIfNoX86
def test_linear_binary_dynamic(self):
self._test_linear_binary_helper(is_dynamic=True)

@skipIfNoX86
def test_linear_binary_dynamic_qat(self):
self._test_linear_binary_helper(is_qat=True, is_dynamic=True)

@skipIfNoX86
def test_linear_binary2(self):
"""
@@ -1386,28 +1460,43 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
"""
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_add_list = [False]
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list)
with override_quantized_engine("x86"), torch.no_grad():
for inplace_add in inplace_add_list:
for inplace_add, is_qat, is_dynamic in cases:
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat, is_dynamic=is_dynamic
)
)
m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
# Two q-dq nodes for inputs of linear nodes
# No q-dq for extra input node of add
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 2,
dequantize_per_tensor_op: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
(
torch.ops.aten.add_.Tensor
if inplace_add
@@ -1434,7 +1523,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
self._check_annotation_stat(fq_m, expected_annotation_stat)

@skipIfNoX86
def test_linear_binary_unary(self):
def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
"""
Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
@@ -1446,7 +1535,20 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
inplace_relu_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
with override_quantized_engine("x86"), torch.no_grad():
@@ -1460,26 +1562,28 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
# Only one q-dq node for input of the linear
# No q-dq node for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 1 if is_dynamic else 2
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: num_dequant,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
(
torch.ops.aten.add_.Tensor
@@ -1509,57 +1613,91 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
}
self._check_annotation_stat(fq_m, expected_annotation_stat)

@skipIfNoX86
def test_linear_binary_unary(self):
self._test_linear_binary_unary_helper()

@skipIfNoX86
def test_linear_binary_unary_qat(self):
self._test_linear_binary_unary_helper(is_qat=True)

@skipIfNoX86
def test_linear_binary_unary_dynamic(self):
self._test_linear_binary_unary_helper(is_dynamic=True)

@skipIfNoX86
def test_linear_binary_unary_dynamic_qat(self):
self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True)

@skipIfNoX86
def test_linear_binary_unary_serials(self):
"""
Test pattern of 2 following up linear add relu with X86InductorQuantizer.
"""
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(is_qat_list, is_dynamic_list)
with override_quantized_engine("x86"), torch.no_grad():
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
"annotated": 4,
"is_quant_out": 2,
},
aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
aten.relu.default: {"annotated": 2, "is_quant_out": 2},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
for is_qat, is_dynamic in cases:
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 3 if is_dynamic else 4
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
quantize_per_tensor_op: 3,
dequantize_per_tensor_op: num_dequant,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
"annotated": 4,
"is_quant_out": 2,
},
aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
aten.relu.default: {"annotated": 2, "is_quant_out": 2},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)

@skipIfTorchDynamo("very slow")
@skipIfNoX86

@@ -1068,13 +1068,8 @@ class X86InductorQuantizer(Quantizer):
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[FilterFn] = None,
):
if (quantization_config is None) or (
quantization_config.input_activation
and not quantization_config.input_activation.is_dynamic
):
# <TODO> Weiwen: Dynamic Quant of linear unary will be supported in next step
self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
self._annotate_linear_unary(model, quantization_config, filter_fn)
self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
self._annotate_linear_unary(model, quantization_config, filter_fn)
self._annotate_linear(model, quantization_config, filter_fn)

def _annotate_matmul(