[Quant][PT2E] enable qlinear post op fusion for dynamic quant & qat (#122667)

**Description**
Add fusion paths for dynamic quant and for QAT.
The following patterns can be matched for static quant with QAT cases:
`qx -> qlinear -> add -> optional relu -> optional type convert -> optional quant`

The following patterns can be matched for dynamic quant cases:
`qx -> qlinear -> add -> optional relu`
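
For reference, a minimal sketch (not part of this PR) of the PT2E flow that produces the dynamic-quant pattern; the `LinearAddReLU` toy module and shapes below are illustrative assumptions:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class LinearAddReLU(torch.nn.Module):
    """Toy module (illustrative only) producing linear -> add -> relu."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x, other):
        return torch.relu(self.linear(x) + other)


m = LinearAddReLU().eval()
example_inputs = (torch.randn(2, 16), torch.randn(2, 16))

# Dynamic-quant config; pass is_qat=True (with prepare_qat_pt2e) for the QAT path instead.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
)

exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # for dynamic quant, scales are computed at runtime; this run just exercises the graph
converted = convert_pt2e(prepared)
```

Lowering `converted` through Inductor is where the `qlinear` post-op fusion added here applies.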

**Test plan**
python test/inductor/test_mkldnn_pattern_matcher.py -k test_qlinear
python test/inductor/test_cpu_cpp_wrapper.py -k test_qlinear
python test/test_quantization.py -k test_linear_unary
python test/test_quantization.py -k test_linear_binary
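
Outside the test harness, one way to confirm the binary fusion fired is to compile the `converted` model from the sketch above with freezing enabled and read the Inductor pattern-matcher counter that the tests assert on (a sketch, not an exact reproduction of `_test_common`):

```python
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters

inductor_config.freezing = True  # the pattern-matcher tests run with freezing enabled
counters.clear()

with torch.no_grad():
    compiled = torch.compile(converted)
    compiled(*example_inputs)

# Non-zero once the qlinear -> add (-> relu) pattern has been matched and fused.
print(counters["inductor"]["qlinear_binary_matcher_nodes"])
```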

Differential Revision: [D57655830](https://our.internmc.facebook.com/intern/diff/D57655830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122667
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
Xia, Weiwen
2024-06-28 09:18:42 +08:00
committed by PyTorch MergeBot
parent a8985a97f9
commit 36e2608783
3 changed files with 276 additions and 139 deletions


@@ -1778,10 +1778,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
self.assertEqual(
counters["inductor"]["qlinear_binary_matcher_nodes"],
5 + 2 * use_relu + to_bf16_after_binary,
(4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary,
)
for is_qat in [False, True]:
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(is_qat_list, is_dynamic_list)
for is_qat, is_dynamic in cases:
self._test_common(
mod,
(v,),
@@ -1789,6 +1792,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
is_qat=is_qat,
is_dynamic=is_dynamic,
)
if torch._inductor.config.cpp_wrapper:
# For CPP wrapper


@@ -1203,44 +1203,64 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_list,
)
@skipIfNoX86
def test_linear_unary(self):
def _test_linear_unary_helper(
self,
post_op_module,
post_op_aten,
post_op_aten_inplace,
post_op_algo_list=None,
is_qat=False,
is_dynamic=False,
):
"""
Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer.
"""
use_bias_list = [True, False]
inplace_list = [True, False]
postop_list = [nn.ReLU, nn.LeakyReLU] # only test two to save time
cases = itertools.product(use_bias_list, inplace_list, postop_list)
post_op_map = {
nn.ReLU: [torch.ops.aten.relu_.default, torch.ops.aten.relu.default],
nn.LeakyReLU: [
torch.ops.aten.leaky_relu_.default,
torch.ops.aten.leaky_relu.default,
],
}
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_list = [False]
if post_op_algo_list is None:
post_op_algo_list = [None]
cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list)
with override_quantized_engine("x86"), torch.no_grad():
for use_bias, inplace, postop in cases:
for use_bias, inplace, post_op_algo in cases:
if inplace and post_op_aten_inplace is None:
continue
m = TestHelperModules.LinearUnaryModule(
use_bias=use_bias, postop=postop, inplace_postop=inplace
use_bias=use_bias,
postop=post_op_module,
inplace_postop=inplace,
post_op_algo=post_op_algo,
).eval()
example_inputs = (torch.randn(2, 4),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
node_occurrence = {
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# one for input of the linear
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
post_op_map[postop][0 if inplace else 1],
post_op_aten_inplace if inplace else post_op_aten,
]
self._test_quantizer(
m,
@@ -1248,47 +1268,70 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
quantizer,
node_occurrence,
node_list,
is_qat=is_qat,
)
@skipIfNoX86
def test_linear_unary_gelu(self):
"""
Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
"""
use_bias_list = [True, False]
postop = nn.GELU
post_op_algorithm = ["none", "tanh"]
cases = itertools.product(use_bias_list, post_op_algorithm)
with override_quantized_engine("x86"), torch.no_grad():
for use_bias, post_op_algo in cases:
m = TestHelperModules.LinearUnaryModule(
use_bias=use_bias, postop=postop, post_op_algo=post_op_algo
).eval()
example_inputs = (torch.randn(2, 4),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
node_occurrence = {
# one for input and weight of the conv, one for output for the gelu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.gelu.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
def test_linear_unary(self):
aten = torch.ops.aten
self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default)
self._test_linear_unary_helper(
nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"]
)
@skipIfNoX86
def test_linear_unary_qat(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True
)
self._test_linear_unary_helper(
nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True
)
@skipIfNoX86
def test_linear_unary_dynamic(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True
)
self._test_linear_unary_helper(
nn.LeakyReLU,
aten.leaky_relu.default,
aten.leaky_relu_.default,
is_dynamic=True,
)
self._test_linear_unary_helper(
nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True
)
@skipIfNoX86
def test_linear_unary_dynamic_qat(self):
aten = torch.ops.aten
self._test_linear_unary_helper(
nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True
)
self._test_linear_unary_helper(
nn.LeakyReLU,
aten.leaky_relu.default,
aten.leaky_relu_.default,
is_qat=True,
is_dynamic=True,
)
self._test_linear_unary_helper(
nn.GELU,
aten.gelu.default,
None,
["none", "tanh"],
is_qat=True,
is_dynamic=True,
)
def _check_annotation_stat(self, gm, expected_stat_dict):
# Check expected annotation statistics to ensure the annotation is correct
@@ -1307,8 +1350,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
for op_stat in expected_stat_dict.values():
assert all(v == 0 for v in op_stat.values())
@skipIfNoX86
def test_linear_binary(self):
def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
"""
Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
@@ -1318,7 +1360,20 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
inplace_add_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
cases = itertools.product(linear_pos_list, inplace_add_list)
with override_quantized_engine("x86"), torch.no_grad():
@@ -1330,26 +1385,28 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
# Only one q-dq for input of the linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 1 if is_dynamic else 2
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: num_dequant,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
(
torch.ops.aten.add_.Tensor
@@ -1363,6 +1420,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
quantizer,
node_occurrence,
node_list,
is_qat=is_qat,
)[-1]
# One linear and add are fused. The other linear is quantized alone if present
aten = torch.ops.aten
@@ -1376,6 +1434,22 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary(self):
self._test_linear_binary_helper()
@skipIfNoX86
def test_linear_binary_qat(self):
self._test_linear_binary_helper(is_qat=True)
@skipIfNoX86
def test_linear_binary_dynamic(self):
self._test_linear_binary_helper(is_dynamic=True)
@skipIfNoX86
def test_linear_binary_dynamic_qat(self):
self._test_linear_binary_helper(is_qat=True, is_dynamic=True)
@skipIfNoX86
def test_linear_binary2(self):
"""
@@ -1386,28 +1460,43 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
"""
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
# TODO test for inplace add after refactoring of capture_pre_autograd_graph
inplace_add_list = [False]
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list)
with override_quantized_engine("x86"), torch.no_grad():
for inplace_add in inplace_add_list:
for inplace_add, is_qat, is_dynamic in cases:
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat, is_dynamic=is_dynamic
)
)
m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
# Two q-dq nodes for inputs of linear nodes
# No q-dq for extra input node of add
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 2,
dequantize_per_tensor_op: 2,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
(
torch.ops.aten.add_.Tensor
if inplace_add
@@ -1434,7 +1523,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary(self):
def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
"""
Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
@@ -1446,7 +1535,20 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
inplace_relu_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
with override_quantized_engine("x86"), torch.no_grad():
@@ -1460,26 +1562,28 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence = {
# Only one q-dq node for input of the linear
# No q-dq node for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: 1,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
else:
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 1 if is_dynamic else 2
node_occurrence = {
# One quantize_per_tensor for both linear nodes (shared)
# Two dequantize_per_tensor for two linear nodes
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
quantize_per_tensor_op: 1,
dequantize_per_tensor_op: num_dequant,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
(
torch.ops.aten.add_.Tensor
@@ -1509,57 +1613,91 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfNoX86
def test_linear_binary_unary(self):
self._test_linear_binary_unary_helper()
@skipIfNoX86
def test_linear_binary_unary_qat(self):
self._test_linear_binary_unary_helper(is_qat=True)
@skipIfNoX86
def test_linear_binary_unary_dynamic(self):
self._test_linear_binary_unary_helper(is_dynamic=True)
@skipIfNoX86
def test_linear_binary_unary_dynamic_qat(self):
self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True)
@skipIfNoX86
def test_linear_binary_unary_serials(self):
"""
Test pattern of 2 following up linear add relu with X86InductorQuantizer.
"""
is_qat_list = [False, True]
is_dynamic_list = [False, True]
cases = itertools.product(is_qat_list, is_dynamic_list)
with override_quantized_engine("x86"), torch.no_grad():
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
"annotated": 4,
"is_quant_out": 2,
},
aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
aten.relu.default: {"annotated": 2, "is_quant_out": 2},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
for is_qat, is_dynamic in cases:
m = TestHelperModules.SerialsLinearAddReLUModule().eval()
example_inputs = (torch.randn(2, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat,
is_dynamic=is_dynamic,
)
)
quantize_per_tensor_op = (
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.quantize_per_tensor.default
)
dequantize_per_tensor_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
if is_dynamic
else torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
# convert_pt2e disables duplicate dequant for dynamic quant
num_dequant = 3 if is_dynamic else 4
node_occurrence = {
# quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
# dequantize_per_tensor: 1 for each linear
# No q-dq for extra input node of add
quantize_per_tensor_op: 3,
dequantize_per_tensor_op: num_dequant,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
quantize_per_tensor_op,
dequantize_per_tensor_op,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.linear.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
]
fq_m = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)[-1]
# Two linear nodes are quantized alone
# The other two are fused with add and relu
aten = torch.ops.aten
expected_annotation_stat = {
aten.linear.default: {
"annotated": 4,
"is_quant_out": 2,
},
aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
aten.relu.default: {"annotated": 2, "is_quant_out": 2},
}
self._check_annotation_stat(fq_m, expected_annotation_stat)
@skipIfTorchDynamo("very slow")
@skipIfNoX86


@@ -1068,13 +1068,8 @@ class X86InductorQuantizer(Quantizer):
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[FilterFn] = None,
):
if (quantization_config is None) or (
quantization_config.input_activation
and not quantization_config.input_activation.is_dynamic
):
# <TODO> Weiwen: Dynamic Quant of linear unary will be supported in next step
self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
self._annotate_linear_unary(model, quantization_config, filter_fn)
self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
self._annotate_linear_unary(model, quantization_config, filter_fn)
self._annotate_linear(model, quantization_config, filter_fn)
def _annotate_matmul(