[Quant][Inductor][X86] Separate unary post op fusion and lowering for qlinear (#143903)

**Summary**
The current implementation fuses quantized ops with their post ops and lowers the fused op to the cpp backend in the same pass. It is better to separate post op fusion from lowering because
- it gives a cleaner design, keeping the graph-level fusion separate from the backend-specific lowering
- the post op fusion pass is also needed for PT2E quantization eager mode

This PR is the first in a series that separates post op fusion and lowering for quantized linear and convolution. It moves the unary post op fusion of qlinear out of the lowering pass and runs it right after the weight-prepack pass. The workflow is now as follows (a sketch of the resulting graph rewrite is shown after the list):
1. Weight prepack for qlinear so that `dq - linear` patterns are replaced by `onednn.qlinear_pointwise`
2. Fuse `onednn.qlinear_pointwise` with its unary post ops
3. Lower the fused op to the cpp backend
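
For illustration, this is roughly how the graph changes in step 2 for a `qlinear -> relu -> quant` chain. The argument list mirrors the `computation_args` built in the diff below; the tensor names and qparams are placeholders, not values from this PR:

```python
# After step 1 (weight prepack): the unary post op and the output quant are
# still separate nodes, and the qlinear node carries "none" as its post op.
#   y  = onednn.qlinear_pointwise(x, x_scale, x_zp, packed_w, w_scale, w_zp,
#                                 bias, out_scale, out_zp, out_dtype,
#                                 "none", [], "")
#   y1 = aten.relu(y)
#   z  = quantize_per_tensor(y1, out_scale, out_zp, torch.uint8)
#
# After step 2 (post op fusion): relu (and, for the int8-output patterns, the
# output quant) are folded into the qlinear node's post op attributes.
#   z = onednn.qlinear_pointwise(x, x_scale, x_zp, packed_w, w_scale, w_zp,
#                                bias, out_scale, out_zp, torch.uint8,
#                                "relu", [], "")
```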

This PR adds additional `PatternMatcherPass`es to handle the post op fusion; the pattern matchers previously used for fusion inside the lowering pass are reused.
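
As a simplified sketch of where those extra passes come from: registering a freezing-graph pattern with a higher `pass_number` appends `PatternMatcherPass` instances on demand, so the new fusion patterns (registered at `pass_number` 3 and 4 in the diff below) run in their own passes after weight prepack. The `pass_dict` argument below is an assumed reconstruction of the existing helper's tail, which is cut off in the diff; `_return_true` and `pass_patterns` are module-level names that appear in the diff:

```python
# Sketch of the register_freezing_graph_pattern helper touched in this PR, with
# the while-loop this PR adds: pass_patterns grows lazily so that a pattern
# registered at pass_number=N gets its own PatternMatcherPass.
from torch._inductor.pattern_matcher import PatternMatcherPass, register_graph_pattern

def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    while pass_number > len(pass_patterns) - 1:
        pass_patterns.append(PatternMatcherPass())
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=pass_patterns[pass_number],  # assumed continuation of the truncated call
    )
```

With this, the int8-output fusion patterns are registered at `pass_number` 3 and the fp32/bf16-output ones at 4, so the larger (quantized-output) patterns are matched first.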

**Test plan**
Post op fusion is covered by the existing UTs in `test_mkldnn_pattern_matcher.py`.
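
Those UTs now track fusion and lowering with separate counters (see the test diff below). A hypothetical condensed helper, just to show the shape of the checks; `assert_qlinear_unary_counters` is not a function from this PR:

```python
from torch._dynamo.utils import counters

def assert_qlinear_unary_counters(test_case, fused, lowered):
    # Hits of the post-grad fusion pass (qlinear + unary post op matched)
    test_case.assertEqual(
        counters["inductor"]["qlinear_unary_matcher_count"], fused
    )
    # Hits of the lowering pattern; in the updated tests this is expected to be
    # 0 when TEST_ACL is set
    test_case.assertEqual(
        counters["inductor"]["qlinear_unary_lower_count"], lowered
    )
```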

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143903
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
Authored by Xia, Weiwen on 2025-01-07 01:03:40 -08:00; committed by PyTorch MergeBot
parent 094ca3154d
commit f8fcb9e7d3
4 changed files with 210 additions and 103 deletions


@@ -2167,6 +2167,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 2. QLinear Unary fusion in post-grad fusion pass
self.assertEqual(
counters["inductor"]["qlinear_unary_matcher_count"],
2,
)
self.assertEqual(
counters["inductor"]["qlinear_unary_lower_count"],
0 if TEST_ACL else 2,
)
@@ -2443,7 +2447,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 3. QLinear Unary fusion in post-grad fusion pass * 1
self.assertEqual(
counters["inductor"]["qlinear_unary_matcher_count"],
0 if TEST_ACL else 1,
1,
)
self._test_common(
@@ -3706,7 +3710,7 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
)
self.assertEqual(
counters["inductor"]["qlinear_unary_matcher_count"],
3 if annotate_matmul and not TEST_ACL else 0,
3 if annotate_matmul else 0,
)
quantizer = X86InductorQuantizer()


@@ -102,6 +102,8 @@ def lazy_init():
def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
while pass_number > len(pass_patterns) - 1:
pass_patterns.append(PatternMatcherPass())
return register_graph_pattern(
pattern,
extra_check=extra_check,


@@ -401,7 +401,6 @@ def _register_quantized_linear_lowering(
pattern,
pass_number,
computation_op,
unary_attr,
):
@register_lowering_pattern(
pattern,
@@ -427,11 +426,13 @@ def _register_quantized_linear_lowering(
b = kwargs["b"] if "b" in kwargs else None
# Output QParams
o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
assert (
kwargs["postop_name"] == "none"
) # Expected no post op fused in weight prepack phase
o_inv_scale = kwargs["output_scale"]
o_zero_point = kwargs["output_zero_point"]
# post op
postop_name = kwargs["postop_name"]
postop_args = kwargs["postop_args"]
postop_algorithm = kwargs["postop_algorithm"]
computation_args = (
x,
@@ -444,12 +445,12 @@ def _register_quantized_linear_lowering(
o_inv_scale,
o_zero_point,
output_dtype,
unary_attr.op_name,
unary_attr.scalars_attr,
unary_attr.algorithm_attr,
postop_name,
postop_args,
postop_algorithm,
)
counters["inductor"]["qlinear_unary_matcher_count"] += 1
counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
counters["inductor"]["qlinear_unary_lower_count"] += 1
counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes)
return L[computation_op](*computation_args)
return qlinear
@@ -704,13 +705,7 @@ def _register_quantized_conv_binary_lowering(
def _register_quantization_unary_fusion():
from .mkldnn_fusion import (
_gelu_fusion_1 as _gelu_fusion_erf,
_gelu_fusion_2 as _gelu_fusion_tanh,
_hardswish_fusion,
_hardtanh_fusion,
_silu_fusion,
)
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
class UnaryAttr:
def __init__(
@@ -720,8 +715,8 @@ def _register_quantization_unary_fusion():
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
# QConv2d
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
# QConv2d
# Priority 1 to match: QConv2d Unary pattern with int8 output
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
@@ -819,87 +814,19 @@ def _register_quantization_unary_fusion():
unary_attr, # unary_attr
)
# QLinear
for x_scale_zp_are_tensors in (False, True):
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
# Priority 1 to match: QLinear Unary pattern with int8 output
linear_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
qlinear_pattern,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
),
UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
with_dtype_convert=is_bf16,
),
}
for unary_attr, patterns in linear_unary_replace_patterns.items():
_register_quantized_linear_lowering(
patterns,
1, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
unary_attr, # unary_attr
)
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
linear_unary_replace_float_out_patterns = {
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
qlinear_pattern, aten.relu.default
),
UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
),
2,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
Arg(),
is_bf16,
),
}
for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
_register_quantized_linear_lowering(
patterns,
2, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
unary_attr, # unary_attr
)
# QLinear
for x_scale_zp_are_tensors in (False, True):
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
computation_op = (
torch.ops.onednn.qlinear_pointwise.tensor
if x_scale_zp_are_tensors
else torch.ops.onednn.qlinear_pointwise.default
)
_register_quantized_linear_lowering(
qlinear_pattern,
2, # pass_number
computation_op,
)
def _register_quantization_binary_fusion():
@@ -3059,6 +2986,177 @@ def _register_smooth_quant_int_mm_pattern():
)
def _register_qlinear_post_op_fusion_pass(
pattern,
pass_number,
computation_op,
unary_attr,
):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_quantized_linear_optimization_pattern(),
pass_number=pass_number,
)
def qlinear_post_op_fusion(match: Match, *args, **kwargs):
"""
Match the pattern:
qlinear - post op
"""
output_dtype = _get_pattern_output_dtype(match)
# Activation QParams
x, x_scale, x_zp = (
kwargs["x"],
kwargs["x_scale"],
kwargs["x_zp"],
)
# Weight QParams
packed_weight, w_scale, w_zp = (
kwargs["packed_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
)
# bias
b = kwargs["b"] if "b" in kwargs else None
# Output QParams
o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
assert (
kwargs["postop_name"] == "none"
) # Expected no post op fused in weight prepack phase
out_node = match.output_node()
with match.graph.inserting_before(out_node):
computation_args = (
x,
x_scale,
x_zp,
packed_weight,
w_scale,
w_zp,
b,
o_inv_scale,
o_zero_point,
output_dtype,
unary_attr.op_name,
unary_attr.scalars_attr,
unary_attr.algorithm_attr,
)
new_linear_node = match.graph.call_function(
computation_op, args=computation_args
)
out_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(out_node.meta)
for node in reversed(match.nodes):
match.graph.erase_node(node)
counters["inductor"]["qlinear_unary_matcher_count"] += 1
counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
def _register_qlinear_unary_fusion():
from .mkldnn_fusion import (
_gelu_fusion_1 as _gelu_fusion_erf,
_gelu_fusion_2 as _gelu_fusion_tanh,
)
class UnaryAttr:
def __init__(
self, op_name: str, scalars_attr=None, algorithm_attr=None
) -> None:
self.op_name = op_name
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
is_bf16 = original_pattern_output_dtype == torch.bfloat16
for x_scale_zp_are_tensors in (False, True):
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
computation_op = (
torch.ops.onednn.qlinear_pointwise.tensor
if x_scale_zp_are_tensors
else torch.ops.onednn.qlinear_pointwise.default
)
# Priority 1 to match: QLinear Unary pattern with int8 output
linear_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
qlinear_pattern,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
),
UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
),
2,
is_bf16,
),
with_dtype_convert=is_bf16,
),
UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
with_dtype_convert=is_bf16,
),
}
for unary_attr, patterns in linear_unary_replace_patterns.items():
_register_qlinear_post_op_fusion_pass(
patterns,
3, # pass_number
computation_op,
unary_attr, # unary_attr
)
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
linear_unary_replace_float_out_patterns = {
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
qlinear_pattern, aten.relu.default
),
UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
),
2,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
Arg(),
is_bf16,
),
}
for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
_register_qlinear_post_op_fusion_pass(
patterns,
4, # pass_number
computation_op,
unary_attr, # unary_attr
)
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -3074,6 +3172,9 @@ def _register_quantization_weight_pack_pass():
# Step 4: weight prepack for SmoothQuant from Torchao
_register_smooth_quant_int_mm_pattern()
# Step 5: QLinear post op Fusion
_register_qlinear_unary_fusion()
def quant_lift_up(graph_module: torch.fx.GraphModule):
"""


@@ -2447,7 +2447,7 @@ if torch._C._has_mkldnn:
output_shape = list(x.shape)
# The weight has been transposed during the qlinear weight prepack process.
output_shape[-1] = w.shape[1]
assert output_dtype in [torch.float32, torch.bfloat16]
assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8]
out = x.new_empty(output_shape, dtype=output_dtype)
return out