mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[Quant][Inductor][X86] Separate unary post op fusion and lowering for qlinear (#143903)
**Summary** The current implementation fuses quantized ops and their post ops and lowers the fused the op to cpp backend in the same pass. It is better to separate post op fusion and lowering because - it looks better in terms of design - we need the post op fusion pass for PT2E quantization eager mode This PR is the first of a series of PRs which separate post op fusion and lowering for quantized linear and convolution. It moves unary post op fusion of qlinear out of the lowering pass. This PR moves the fusion pass from the lowering pass to after the weight-prepack pass. The workflow is 1. Weight prepack for qlinear so that `dq - linear` patterns are replaced by `onednn.qlinear_pointwise` 2. Fuse `onednn.qlinear_pointwise` and post ops 3. Lower to cpp backend This PR adds additional `PatternMatcherPass`'s to handle the post op fusion. Pattern matchers used for fusion are reused. **Test plan** It is covered by existing UTs in `test_mkldnn_pattern_matcher.py` for post op fusion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143903 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							094ca3154d
						
					
				
				
					commit
					f8fcb9e7d3
				
			| @ -2167,6 +2167,10 @@ class TestPatternMatcher(TestPatternMatcherBase): | ||||
|                 # 2. QLinear Unary fusion in post-grad fusion pass | ||||
|                 self.assertEqual( | ||||
|                     counters["inductor"]["qlinear_unary_matcher_count"], | ||||
|                     2, | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     counters["inductor"]["qlinear_unary_lower_count"], | ||||
|                     0 if TEST_ACL else 2, | ||||
|                 ) | ||||
|  | ||||
| @ -2443,7 +2447,7 @@ class TestPatternMatcher(TestPatternMatcherBase): | ||||
|             # 3. QLinear Unary fusion in post-grad fusion pass * 1 | ||||
|             self.assertEqual( | ||||
|                 counters["inductor"]["qlinear_unary_matcher_count"], | ||||
|                 0 if TEST_ACL else 1, | ||||
|                 1, | ||||
|             ) | ||||
|  | ||||
|         self._test_common( | ||||
| @ -3706,7 +3710,7 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase): | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     counters["inductor"]["qlinear_unary_matcher_count"], | ||||
|                     3 if annotate_matmul and not TEST_ACL else 0, | ||||
|                     3 if annotate_matmul else 0, | ||||
|                 ) | ||||
|  | ||||
|             quantizer = X86InductorQuantizer() | ||||
|  | ||||
| @ -102,6 +102,8 @@ def lazy_init(): | ||||
|  | ||||
|  | ||||
| def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): | ||||
|     while pass_number > len(pass_patterns) - 1: | ||||
|         pass_patterns.append(PatternMatcherPass()) | ||||
|     return register_graph_pattern( | ||||
|         pattern, | ||||
|         extra_check=extra_check, | ||||
|  | ||||
| @ -401,7 +401,6 @@ def _register_quantized_linear_lowering( | ||||
|     pattern, | ||||
|     pass_number, | ||||
|     computation_op, | ||||
|     unary_attr, | ||||
| ): | ||||
|     @register_lowering_pattern( | ||||
|         pattern, | ||||
| @ -427,11 +426,13 @@ def _register_quantized_linear_lowering( | ||||
|         b = kwargs["b"] if "b" in kwargs else None | ||||
|  | ||||
|         # Output QParams | ||||
|         o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 | ||||
|         o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 | ||||
|         assert ( | ||||
|             kwargs["postop_name"] == "none" | ||||
|         )  # Expected no post op fused in weight prepack phase | ||||
|         o_inv_scale = kwargs["output_scale"] | ||||
|         o_zero_point = kwargs["output_zero_point"] | ||||
|  | ||||
|         # post op | ||||
|         postop_name = kwargs["postop_name"] | ||||
|         postop_args = kwargs["postop_args"] | ||||
|         postop_algorithm = kwargs["postop_algorithm"] | ||||
|  | ||||
|         computation_args = ( | ||||
|             x, | ||||
| @ -444,12 +445,12 @@ def _register_quantized_linear_lowering( | ||||
|             o_inv_scale, | ||||
|             o_zero_point, | ||||
|             output_dtype, | ||||
|             unary_attr.op_name, | ||||
|             unary_attr.scalars_attr, | ||||
|             unary_attr.algorithm_attr, | ||||
|             postop_name, | ||||
|             postop_args, | ||||
|             postop_algorithm, | ||||
|         ) | ||||
|         counters["inductor"]["qlinear_unary_matcher_count"] += 1 | ||||
|         counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) | ||||
|         counters["inductor"]["qlinear_unary_lower_count"] += 1 | ||||
|         counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes) | ||||
|         return L[computation_op](*computation_args) | ||||
|  | ||||
|     return qlinear | ||||
| @ -704,13 +705,7 @@ def _register_quantized_conv_binary_lowering( | ||||
|  | ||||
|  | ||||
| def _register_quantization_unary_fusion(): | ||||
|     from .mkldnn_fusion import ( | ||||
|         _gelu_fusion_1 as _gelu_fusion_erf, | ||||
|         _gelu_fusion_2 as _gelu_fusion_tanh, | ||||
|         _hardswish_fusion, | ||||
|         _hardtanh_fusion, | ||||
|         _silu_fusion, | ||||
|     ) | ||||
|     from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion | ||||
|  | ||||
|     class UnaryAttr: | ||||
|         def __init__( | ||||
| @ -720,8 +715,8 @@ def _register_quantization_unary_fusion(): | ||||
|             self.scalars_attr = scalars_attr if scalars_attr else [] | ||||
|             self.algorithm_attr = algorithm_attr if algorithm_attr else "" | ||||
|  | ||||
|     # QConv2d | ||||
|     for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: | ||||
|         # QConv2d | ||||
|         # Priority 1 to match: QConv2d Unary pattern with int8 output | ||||
|         # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. | ||||
|         # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant | ||||
| @ -819,87 +814,19 @@ def _register_quantization_unary_fusion(): | ||||
|                 unary_attr,  # unary_attr | ||||
|             ) | ||||
|  | ||||
|         # QLinear | ||||
|         for x_scale_zp_are_tensors in (False, True): | ||||
|             qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) | ||||
|             # Priority 1 to match: QLinear Unary pattern with int8 output | ||||
|             linear_unary_replace_patterns = { | ||||
|                 UnaryAttr("none", [], ""): generate_pattern_with_output_quant( | ||||
|                     qlinear_pattern, | ||||
|                 ), | ||||
|                 UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( | ||||
|                     generate_pattern_with_unary(qlinear_pattern, aten.relu.default), | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_erf, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 2 | ||||
|                         ), | ||||
|                         2, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     with_dtype_convert=is_bf16, | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_tanh, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 4 | ||||
|                         ), | ||||
|                         4, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     with_dtype_convert=is_bf16, | ||||
|                 ), | ||||
|             } | ||||
|  | ||||
|             for unary_attr, patterns in linear_unary_replace_patterns.items(): | ||||
|                 _register_quantized_linear_lowering( | ||||
|                     patterns, | ||||
|                     1,  # pass_number | ||||
|                     torch.ops.onednn.qlinear_pointwise,  # computation_op | ||||
|                     unary_attr,  # unary_attr | ||||
|                 ) | ||||
|  | ||||
|             # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output | ||||
|             linear_unary_replace_float_out_patterns = { | ||||
|                 UnaryAttr("relu", [], ""): generate_pattern_with_unary( | ||||
|                     qlinear_pattern, aten.relu.default | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_erf, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 2 | ||||
|                         ), | ||||
|                         2, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     Arg(), | ||||
|                     is_bf16, | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_tanh, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 4 | ||||
|                         ), | ||||
|                         4, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     Arg(), | ||||
|                     is_bf16, | ||||
|                 ), | ||||
|             } | ||||
|  | ||||
|             for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): | ||||
|                 _register_quantized_linear_lowering( | ||||
|                     patterns, | ||||
|                     2,  # pass_number | ||||
|                     torch.ops.onednn.qlinear_pointwise,  # computation_op | ||||
|                     unary_attr,  # unary_attr | ||||
|                 ) | ||||
|     # QLinear | ||||
|     for x_scale_zp_are_tensors in (False, True): | ||||
|         qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) | ||||
|         computation_op = ( | ||||
|             torch.ops.onednn.qlinear_pointwise.tensor | ||||
|             if x_scale_zp_are_tensors | ||||
|             else torch.ops.onednn.qlinear_pointwise.default | ||||
|         ) | ||||
|         _register_quantized_linear_lowering( | ||||
|             qlinear_pattern, | ||||
|             2,  # pass_number | ||||
|             computation_op, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def _register_quantization_binary_fusion(): | ||||
| @ -3059,6 +2986,177 @@ def _register_smooth_quant_int_mm_pattern(): | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| def _register_qlinear_post_op_fusion_pass( | ||||
|     pattern, | ||||
|     pass_number, | ||||
|     computation_op, | ||||
|     unary_attr, | ||||
| ): | ||||
|     @register_freezing_graph_pattern( | ||||
|         pattern, | ||||
|         extra_check=_is_valid_quantized_linear_optimization_pattern(), | ||||
|         pass_number=pass_number, | ||||
|     ) | ||||
|     def qlinear_post_op_fusion(match: Match, *args, **kwargs): | ||||
|         """ | ||||
|         Match the pattern: | ||||
|         qlinear - post op | ||||
|         """ | ||||
|         output_dtype = _get_pattern_output_dtype(match) | ||||
|         # Activation QParams | ||||
|         x, x_scale, x_zp = ( | ||||
|             kwargs["x"], | ||||
|             kwargs["x_scale"], | ||||
|             kwargs["x_zp"], | ||||
|         ) | ||||
|         # Weight QParams | ||||
|         packed_weight, w_scale, w_zp = ( | ||||
|             kwargs["packed_weight"], | ||||
|             kwargs["w_scale"], | ||||
|             kwargs["w_zp"], | ||||
|         ) | ||||
|  | ||||
|         # bias | ||||
|         b = kwargs["b"] if "b" in kwargs else None | ||||
|  | ||||
|         # Output QParams | ||||
|         o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 | ||||
|         o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 | ||||
|         assert ( | ||||
|             kwargs["postop_name"] == "none" | ||||
|         )  # Expected no post op fused in weight prepack phase | ||||
|  | ||||
|         out_node = match.output_node() | ||||
|         with match.graph.inserting_before(out_node): | ||||
|             computation_args = ( | ||||
|                 x, | ||||
|                 x_scale, | ||||
|                 x_zp, | ||||
|                 packed_weight, | ||||
|                 w_scale, | ||||
|                 w_zp, | ||||
|                 b, | ||||
|                 o_inv_scale, | ||||
|                 o_zero_point, | ||||
|                 output_dtype, | ||||
|                 unary_attr.op_name, | ||||
|                 unary_attr.scalars_attr, | ||||
|                 unary_attr.algorithm_attr, | ||||
|             ) | ||||
|             new_linear_node = match.graph.call_function( | ||||
|                 computation_op, args=computation_args | ||||
|             ) | ||||
|             out_node.replace_all_uses_with(new_linear_node) | ||||
|             new_linear_node.meta.update(out_node.meta) | ||||
|             for node in reversed(match.nodes): | ||||
|                 match.graph.erase_node(node) | ||||
|         counters["inductor"]["qlinear_unary_matcher_count"] += 1 | ||||
|         counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) | ||||
|  | ||||
|  | ||||
| def _register_qlinear_unary_fusion(): | ||||
|     from .mkldnn_fusion import ( | ||||
|         _gelu_fusion_1 as _gelu_fusion_erf, | ||||
|         _gelu_fusion_2 as _gelu_fusion_tanh, | ||||
|     ) | ||||
|  | ||||
|     class UnaryAttr: | ||||
|         def __init__( | ||||
|             self, op_name: str, scalars_attr=None, algorithm_attr=None | ||||
|         ) -> None: | ||||
|             self.op_name = op_name | ||||
|             self.scalars_attr = scalars_attr if scalars_attr else [] | ||||
|             self.algorithm_attr = algorithm_attr if algorithm_attr else "" | ||||
|  | ||||
|     for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: | ||||
|         is_bf16 = original_pattern_output_dtype == torch.bfloat16 | ||||
|         for x_scale_zp_are_tensors in (False, True): | ||||
|             qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) | ||||
|             computation_op = ( | ||||
|                 torch.ops.onednn.qlinear_pointwise.tensor | ||||
|                 if x_scale_zp_are_tensors | ||||
|                 else torch.ops.onednn.qlinear_pointwise.default | ||||
|             ) | ||||
|             # Priority 1 to match: QLinear Unary pattern with int8 output | ||||
|             linear_unary_replace_patterns = { | ||||
|                 UnaryAttr("none", [], ""): generate_pattern_with_output_quant( | ||||
|                     qlinear_pattern, | ||||
|                 ), | ||||
|                 UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( | ||||
|                     generate_pattern_with_unary(qlinear_pattern, aten.relu.default), | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_erf, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 2 | ||||
|                         ), | ||||
|                         2, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     with_dtype_convert=is_bf16, | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_tanh, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 4 | ||||
|                         ), | ||||
|                         4, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     with_dtype_convert=is_bf16, | ||||
|                 ), | ||||
|             } | ||||
|  | ||||
|             for unary_attr, patterns in linear_unary_replace_patterns.items(): | ||||
|                 _register_qlinear_post_op_fusion_pass( | ||||
|                     patterns, | ||||
|                     3,  # pass_number | ||||
|                     computation_op, | ||||
|                     unary_attr,  # unary_attr | ||||
|                 ) | ||||
|  | ||||
|             # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output | ||||
|             linear_unary_replace_float_out_patterns = { | ||||
|                 UnaryAttr("relu", [], ""): generate_pattern_with_unary( | ||||
|                     qlinear_pattern, aten.relu.default | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_erf, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 2 | ||||
|                         ), | ||||
|                         2, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     Arg(), | ||||
|                     is_bf16, | ||||
|                 ), | ||||
|                 UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert( | ||||
|                     _unary_fusion_pattern( | ||||
|                         _gelu_fusion_tanh, | ||||
|                         get_qlinear_pt2e_pattern( | ||||
|                             x_scale_zp_are_tensors, 1 if is_bf16 else 4 | ||||
|                         ), | ||||
|                         4, | ||||
|                         is_bf16, | ||||
|                     ), | ||||
|                     Arg(), | ||||
|                     is_bf16, | ||||
|                 ), | ||||
|             } | ||||
|  | ||||
|             for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): | ||||
|                 _register_qlinear_post_op_fusion_pass( | ||||
|                     patterns, | ||||
|                     4,  # pass_number | ||||
|                     computation_op, | ||||
|                     unary_attr,  # unary_attr | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| @functools.lru_cache(None) | ||||
| def _register_quantization_weight_pack_pass(): | ||||
|     # Step 1: Dequant promotion for int8-mixed-fp32/bf16 | ||||
| @ -3074,6 +3172,9 @@ def _register_quantization_weight_pack_pass(): | ||||
|     # Step 4: weight prepack for SmoothQuant from Torchao | ||||
|     _register_smooth_quant_int_mm_pattern() | ||||
|  | ||||
|     # Step 5: QLinear post op Fusion | ||||
|     _register_qlinear_unary_fusion() | ||||
|  | ||||
|  | ||||
| def quant_lift_up(graph_module: torch.fx.GraphModule): | ||||
|     """ | ||||
|  | ||||
| @ -2447,7 +2447,7 @@ if torch._C._has_mkldnn: | ||||
|         output_shape = list(x.shape) | ||||
|         # The weight has been transposed during the qlinear weight prepack process. | ||||
|         output_shape[-1] = w.shape[1] | ||||
|         assert output_dtype in [torch.float32, torch.bfloat16] | ||||
|         assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8] | ||||
|         out = x.new_empty(output_shape, dtype=output_dtype) | ||||
|         return out | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user