# mypy: allow-untyped-defs
import functools
from typing import Optional, Union

import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args

from . import config, ir
from .codegen.cpp_gemm_template import CppGemmTemplate
from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
    add,
    add_needs_realized_inputs,
    aten,
    permute,
    register_lowering,
    to_dtype,
    view,
)
from .select_algorithm import (
    autotune_select_algorithm,
    ChoiceCaller,
    ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_gemm_template
from .virtualized import ops, OpsValue, V


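# Helper used by the quantized linear lowerings below: precomputes the weight
# compensation constant (torch.sum(W, dim=0)) and, when x_scale/x_zp/w_scale are
# all graph constants, a fused x_scale * w_scale constant so the epilogue can
# take a cheaper fast path.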
def create_int8_compensation(
    W_tensor: torch.Tensor,
    packed_weight: ir.TensorBox,
    x_scale: ir.TensorBox,
    x_zp: ir.TensorBox,
    w_scale: ir.TensorBox,
) -> tuple[
    bool,
    Union[ir.TensorBox, ir.ShapeAsConstantBuffer],
    Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]],
]:
    x_w_scale: Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]] = None
    use_int8_fast_compensation_path = all(
        isinstance(item, ir.TensorBox)
        and item.get_name() in V.graph.constants
        and hasattr(item.data, "data")
        and isinstance(item.data.data, ir.ConstantBuffer)
        for item in [x_scale, x_zp, w_scale]
    )
    if use_int8_fast_compensation_path:
        x_w_scale_tensor = (
            V.graph.constants[x_scale.get_name()]
            * V.graph.constants[w_scale.get_name()]
        )
        x_w_scale = V.graph.add_tensor_constant(
            x_w_scale_tensor,
            name=packed_weight.get_name() + "_x_w_compens",
        )
        weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
        x_zp_tensor = V.graph.constants[x_zp.get_name()]
        weight_compens_tensor = weight_compens_tensor * x_w_scale_tensor * x_zp_tensor
        weight_compens = V.graph.add_tensor_constant(
            weight_compens_tensor,
            name=packed_weight.get_name() + "_BMatrixCompens",
        )
    else:
        weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
        weight_compens = V.graph.add_tensor_constant(
            weight_compens_tensor,
            name=packed_weight.get_name() + "_BMatrixCompens",
        )
    return (  # type: ignore[return-type]
        use_int8_fast_compensation_path,
        weight_compens,
        x_w_scale,
    )


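# Converts the int32 GEMM accumulator to fp32 using the dequantization identity
#   out[m, n] = x_scale * w_scale[n] * (acc[m, n] - x_zp * sum_k W[k, n])
# On the fast path the constants x_scale * w_scale and
# x_scale * w_scale * x_zp * sum_k W[k, n] are folded at compile time
# (see create_int8_compensation above), so only one multiply and one subtract
# remain in the generated epilogue.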
def codegen_int8_gemm_template_compensation(
    use_int8_fast_compensation_path: bool,
    input: OpsValue,
    _weight_compo: OpsValue,
    _x_scale: Optional[OpsValue],
    _x_zp: Optional[OpsValue],
    _w_scale: Optional[OpsValue],
    _x_w_scale: Optional[OpsValue],
) -> OpsValue:
    if use_int8_fast_compensation_path:
        temp = ops.sub(
            ops.mul(
                input,
                _x_w_scale,
            ),
            _weight_compo,
        )
    else:
        temp = ops.mul(
            ops.mul(
                input,
                _x_scale,
            ),
            _w_scale,
        )
        # NOTE: We apply compensation even if x_zp is 0 for int8 quantization.
        # That's because when torch.compile is invoked for dynamic quantization,
        # x might coincidentally take values for which x_zp is zero despite
        # asymmetric quantization.
        # Besides, if x_zp is a dummy for int8 x, or if x is statically quantized,
        # we still perform this redundant compute to avoid making the code messy,
        # because we found that the redundant compensation computation did not
        # lead to performance degradation with the input shapes tested.
        temp = ops.sub(
            temp,
            ops.mul(
                ops.mul(
                    ops.mul(
                        _x_scale,
                        _w_scale,
                    ),
                    _x_zp,
                ),
                _weight_compo,
            ),
        )
    return temp


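# Lowers a grouped GEMM (one shared activation x against several weights w[i])
# to a single CppGroupedGemmTemplate kernel; the template buffer is split back
# into per-GEMM outputs via ir.MultiOutput below.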
def grouped_gemm_lowering(
    x: TensorBox,
    w: list[TensorBox],
    b: list[TensorBox],
    attr=None,
    scalars=None,
    algorithm=None,
    layout=None,
):
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    num_gemm = len(w)

    assert config.max_autotune or config.max_autotune_gemm
    # pyrefly: ignore [bad-assignment]
    b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]

    choices: list[ChoiceCaller] = []
    *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

    kwargs = {
        "has_bias": [bias is not None for bias in b],
        "trans_w": True,
        "epilogue_creator": None,
        "act_mapping": dict.fromkeys(range(num_gemm), x),
    }

    input_nodes = [x, *w]
    input_nodes.extend([bias for bias in b if bias is not None])

    CppGroupedGemmTemplate.add_choices(
        choices,
        layout,
        input_nodes,
        **kwargs,  # type: ignore[arg-type]
    )

    assert len(choices) != 0
    result = autotune_select_algorithm(
        "grouped_gemm",
        choices,
        input_nodes,
        layout,
    )
    template_buf = result.data.data
    return_bufs = [
        ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
        for gemm_idx in range(num_gemm)
    ]
    # pyrefly: ignore [bad-argument-type]
    template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
    template_buf.outputs = return_bufs
    return_tensors = [
        ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
    ]
    if len(x_size) > 2:
        for gemm_idx in range(num_gemm):
            return_tensors[gemm_idx] = view(
                return_tensors[gemm_idx],  # type: ignore[arg-type]
                (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
            )
    return return_tensors


grouped_gemm_lowering._inductor_lowering_function = True  # type: ignore[attr-defined]


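# Registers Inductor lowerings for the fused oneDNN (mkldnn) ops when PyTorch is
# built with mkldnn support. Each lowering either wraps the op in a dedicated
# mkldnn_ir node or, under max-autotune, also adds CppGemmTemplate choices and
# lets autotune_select_algorithm pick the winner.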
def register_onednn_fusion_ops():
    if torch._C._has_mkldnn:
        from . import mkldnn_ir

        aten_mkldnn_linear_unary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearUnary.create,
        )
        aten_mkldnn_linear_binary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise.binary,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearBinary.create,
        )
        aten_mkldnn_qlinear_unary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
        )
        aten_mkldnn_qlinear_binary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise.binary,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
        )
        cpu_needs_realized_inputs = [
            torch.ops.mkldnn._convolution_pointwise,
            torch.ops.mkldnn._convolution_pointwise_,
            torch.ops.mkldnn._convolution_transpose_pointwise,
            torch.ops.mkldnn._linear_pointwise,
            aten.mkldnn_rnn_layer.default,
            torch.ops.onednn.qconv_pointwise,
        ]

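        # The convolution lowerings below have no template path; they lower
        # directly to the corresponding mkldnn_ir nodes.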
        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
        def convolution_binary(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinary.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
        def convolution_binary_inplace(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinaryInplace.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

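        # linear_unary: under max-autotune, adds a CppGemmTemplate choice (with
        # the unary post-op fused as an epilogue) alongside the
        # mkldnn::_linear_pointwise extern kernel, then autotunes between them.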
        @register_lowering(torch.ops.mkldnn._linear_pointwise)
        def linear_unary(
            x: TensorBox,
            w: TensorBox,
            b: TensorBox,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)  # type: ignore[assignment]
            choices: list[ChoiceCaller] = []
            if config.max_autotune or config.max_autotune_gemm:
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
                if use_cpp_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(
                            buf, attr, scalars=scalars, algorithm=algorithm
                        )

                    kwargs = {
                        "has_bias": b is not None,
                        "trans_w": True,
                        "epilogue_creator": (
                            None if attr == "none" else epilogue_creator
                        ),
                    }
                    if b is not None:
                        kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, w] if b is None else [x, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_unary.bind(
                        [x, w] if b is None else [x, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                1: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_unary",
                choices,
                [x, w] if b is None else [x, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

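        # linear_binary: same structure as linear_unary, but the binary operand
        # y is fused into the template epilogue via create_epilogue_with_attr.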
        @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
        def linear_binary(
            x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            y_size = y.get_size()
            if len(y_size) > 2:
                y = view(y, [-1, y_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)  # type: ignore[assignment]
            choices: list[ChoiceCaller] = []
            if config.max_autotune or config.max_autotune_gemm:
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w, y = mm_args(
                    x, transposed_w, y, layout=layout
                )
                if use_cpp_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(buf, attr, other=y)

                    kwargs = {
                        "has_bias": b is not None,
                        "trans_w": True,
                        "epilogue_creator": epilogue_creator,
                    }

                    # pyrefly: ignore [unsupported-operation]
                    kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, y, w] if b is None else [x, y, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_binary.bind(
                        [x, y, w] if b is None else [x, y, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                2: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_binary",
                choices,
                [x, y, w] if b is None else [x, y, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
        def convolution_transpose_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            output_padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionTransposeUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    output_padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(aten.mkldnn_rnn_layer.default)
        def mkldnn_rnn_layer(
            x: TensorBox,
            w0: TensorBox,
            w1: TensorBox,
            w2: TensorBox,
            w3: TensorBox,
            hx: TensorBox,
            cx: TensorBox,
            reverse: bool,
            batch_sizes: list[int],
            mode: int,
            hidden_size: int,
            num_layers: int,
            has_biases: bool,
            bidirectional: bool,
            batch_first: bool,
            train: bool,
        ):
            return pytree.tree_map(
                TensorBox.create,
                mkldnn_ir.MkldnnRnnLayer.create(
                    x,
                    w0,
                    w1,
                    w2,
                    w3,
                    hx,
                    cx,
                    reverse,
                    batch_sizes,
                    mode,
                    hidden_size,
                    num_layers,
                    has_biases,
                    bidirectional,
                    batch_first,
                    train,
                ),
            )

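        # Quantized convolution lowerings: scalar x_scale/x_zp are promoted to
        # graph constants so the IR nodes see tensors, matching qlinear below.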
        @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None)
        def qconvolution_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
        ):
            # To align with qlinear, where x_scale and x_zp are converted to Tensor
            assert type(x_scale) is float
            x_scale = V.graph.add_tensor_constant(
                torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
            )
            assert type(x_zp) is int
            x_zp = V.graph.add_tensor_constant(
                torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
            )

            return TensorBox.create(
                mkldnn_ir.QConvPointWisePT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qconvolution_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            accum: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            accum_scale,
            accum_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
        ):
            # To align with qlinear, where x_scale and x_zp are converted to Tensor
            assert type(x_scale) is float
            x_scale = V.graph.add_tensor_constant(
                torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
            )
            assert type(x_zp) is int
            x_zp = V.graph.add_tensor_constant(
                torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
            )

            if (
                binary_attr == "sum"
                and output_dtype in [torch.float32, torch.bfloat16]
                and accum.get_dtype() in [torch.float32, torch.bfloat16]
                and accum.get_dtype() != output_dtype
            ):
                # For int8-mixed-bf16 quantization with an in-place add, there is
                # a case where the accum dtype is float32 but the output dtype is
                # bfloat16. Since the accum will be changed in place by the post
                # op sum, we convert the accum dtype here.
                accum = to_dtype(accum, output_dtype)
            return TensorBox.create(
                mkldnn_ir.QConvPointWiseBinaryPT2E.create(
                    x,
                    x_scale,  # type: ignore[arg-type]
                    x_zp,  # type: ignore[arg-type]
                    packed_weight,
                    w_scale,
                    w_zp,
                    accum,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    accum_scale,
                    accum_zp,
                    binary_attr,
                    alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithmm,
                )
            )

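        # qlinear_unary: under max-autotune it builds a CppGemmTemplate choice
        # whose epilogue does the int32->fp32 compensation, the optional bias
        # add, the unary post-op, and the final cast/requantization to
        # output_dtype; otherwise (or in addition) it falls back to the
        # onednn::qlinear_pointwise extern kernel.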
        @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
        def qlinear_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
            assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], (
                "Only int8 and e4m3fn weights are supported by oneDNN qlinear."
            )
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) is float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
                if all(dim == 1 for dim in x_scale.get_size()):
                    # Corner case discovered with the LLaMA series.
                    # If all outer dims of x_scale are 1, make it a 0D tensor.
                    # Otherwise, the epilogue creator will run into indexing issues.
                    x_scale = view(x_scale, [])
                assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D"

            if x_zp is None:
                # If x_zp is None, x is per-tensor int8 quantized and its scale is
                # not reshaped; the codegened code would segfault if we don't create
                # a tensor for x_zp. It's safe to do so since x is a symmetrically
                # quantized int8 tensor.
                # Moreover, the oneDNN qlinear API doesn't accept a None value for zp.
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="x_zp"
                )
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) is int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            assert x_zp.get_numel() == 1, "x_zp is incompatible with oneDNN qlinear"

            # When channels < 8, w_scale/w_zp is a Pointwise node instead of a ConstantBuffer.
            # Refer to
            # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577  # noqa: B950
            if w_zp is None:
                # If w_zp is None, then it's a dummy tensor created to denote the
                # absence of a zero point, and thus w is int8 symmetrically quantized.
                # Moreover, the oneDNN qlinear API doesn't accept a None value for zp.
                # pyrefly: ignore [bad-assignment]
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="w_zp"
                )
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                # w_zp might be a ConstantBuffer with int64; convert it to int32
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(  # type: ignore[assignment]
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )

            bias_dtype = None if bias is None else bias.get_dtype()
            choices: list[ChoiceCaller] = []

            if config.max_autotune or config.max_autotune_gemm:
                *_, layout, x, packed_weight = mm_args(
                    x, packed_weight, layout=layout, out_dtype=output_dtype
                )

                if (
                    # GEMM template currently only supports symmetrically quantized weights
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )
                ) and use_cpp_gemm_template(layout, x, packed_weight):
                    W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()

                    (
                        use_int8_fast_compensation_path,
                        weight_compens,
                        x_w_scale,
                    ) = create_int8_compensation(
                        W_tensor,
                        packed_weight,
                        # pyrefly: ignore [bad-argument-type]
                        x_scale,
                        # pyrefly: ignore [bad-argument-type]
                        x_zp,
                        w_scale,
                    )

                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                            torch.int8,
                        ]
                        input_loader = input_buffer.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_w_scale_loader = None
                        if use_int8_fast_compensation_path:
                            assert x_w_scale is not None
                            x_w_scale_loader = x_w_scale.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            # Micro-kernel output is int32;
                            # convert to FP32 before doing compensation
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)

                            _x_scale = None
                            _x_zp = None
                            _w_scale = None
                            if not use_int8_fast_compensation_path:
                                _x_scale = x_scale_loader(())
                                _x_zp = x_zp_loader(())
                                _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compo = weight_compens_loader(weight_compens_index)
                            _x_w_scale = None
                            if use_int8_fast_compensation_path:
                                assert x_w_scale_loader is not None
                                _x_w_scale = x_w_scale_loader(weight_compens_index)
                            # Step 1: Compute s8s8->s32 or u8s8->s32 GEMM & then apply compensation
                            temp = codegen_int8_gemm_template_compensation(
                                use_int8_fast_compensation_path,
                                input,
                                _weight_compo,
                                _x_scale,
                                _x_zp,
                                _w_scale,
                                _x_w_scale,
                            )
                            # Step 2: add bias if applicable
                            if bias is not None:
                                # pyrefly: ignore [not-callable]
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32 & s8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 3: do the unary post op fusion
                        if attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf, attr, scalars=scalars, algorithm=algorithm
                            )

                        # Step 4: cast output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype in [torch.uint8, torch.int8]:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                if output_dtype == torch.uint8:
                                    qmin, qmax = _create_constants(
                                        0, 255, dtype=torch.float32
                                    )
                                else:
                                    qmin, qmax = _create_constants(
                                        -128, 127, dtype=torch.float32
                                    )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    assert x.get_dtype() in [torch.uint8, torch.int8]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        input_indices=[0, 3, 1, 2, 4, 5]
                        if bias is None
                        else [6, 0, 3, 1, 2, 4, 5],
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    post_op_name=attr,
                    post_op_args=scalars,
                    post_op_algorithm=algorithm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_unary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],  # packed weight
                4: lambda x: V.graph.constants[x.get_name()],  # weight scale
                5: lambda x: V.graph.constants[x.get_name()],  # weight zp
                6: lambda x: V.graph.constants[x.get_name()],  # bias
            }
            if isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_scale),
                ir.ConstantBuffer,
            ):
                # x is statically quantized
                input_gen_fns[1] = lambda x: V.graph.constants[x.get_name()]
            if isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_zp),
                ir.ConstantBuffer,
            ):
                input_gen_fns[2] = lambda x: V.graph.constants[x.get_name()]

            result = autotune_select_algorithm(
                "qlinear_unary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

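        # qlinear_binary: like qlinear_unary, but additionally fuses the binary
        # operand x2 into the epilogue. The GEMM template path is only taken for
        # binary_attr == "add"; in-place "sum" fusion is left to the extern kernel.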
        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qlinear_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            x2: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
            layout=None,
        ):
            x_size = x.get_size()
            x2_size = x2.get_size()
            assert len(x_size) == len(x2_size)
            if len(x_size) > 2 and binary_attr == "add":
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
                x2 = view(x2, [-1, x2_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) is float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
                if all(dim == 1 for dim in x_scale.get_size()):
                    # Corner case discovered with the LLaMA series.
                    # If all outer dims of x_scale are 1, make it a 0D tensor.
                    # Otherwise, the epilogue creator will run into indexing issues.
                    x_scale = view(x_scale, [])
                assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D"

            if x_zp is None:
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="x_zp"
                )

            if w_zp is None:
                # pyrefly: ignore [bad-assignment]
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="w_zp"
                )

            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) is int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When channels < 8, w_scale/w_zp is a Pointwise node instead of a ConstantBuffer.
            # Refer to
            # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577  # noqa: B950
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(  # type: ignore[assignment]
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )
            if binary_attr == "sum":
                if output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                ] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
                    if x2.get_dtype() != output_dtype:
                        # For int8-mixed-bf16 quantization with an in-place add, there is
                        # a case where the accum dtype is float32 but the output dtype is
                        # bfloat16. Since the accum will be changed in place by the post
                        # op sum, we convert the accum dtype here.
                        x2 = to_dtype(x2, output_dtype)
                else:
                    assert x2.get_dtype() == output_dtype, (
                        "dtype of accum for qlinear post op sum should be the same as output"
                    )
            x2_dtype = x2.get_dtype()
            bias_dtype = bias.get_dtype() if bias is not None else None
            choices: list[ChoiceCaller] = []
            if (
                config.max_autotune or config.max_autotune_gemm
            ) and binary_attr == "add":  # <TODO> Support inplace sum fusion
                *_, layout, x, packed_weight, x2 = mm_args(
                    x, packed_weight, x2, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per-tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
                    and use_cpp_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()]
                    W_tensor = W_tensor.to_dense()
                    (
                        use_int8_fast_compensation_path,
                        weight_compens,
                        x_w_scale,
                    ) = create_int8_compensation(
                        W_tensor,
                        packed_weight,
                        # pyrefly: ignore [bad-argument-type]
                        x_scale,
                        # pyrefly: ignore [bad-argument-type]
                        x_zp,
                        w_scale,
                    )

|                     def epilogue_creator(input_buffer):
 | |
|                         # Epilogue to convert from s32 to f32 for u8s8f32
 | |
|                         assert output_dtype in [
 | |
|                             torch.float32,
 | |
|                             torch.bfloat16,
 | |
|                             torch.uint8,
 | |
|                             torch.int8,
 | |
|                         ]
 | |
| 
 | |
|                         input_loader = input_buffer.make_loader()
 | |
|                         x2_loader = x2.make_loader()
 | |
|                         weight_compens_loader = weight_compens.make_loader()
 | |
|                         x_w_scale_loader = None
 | |
|                         if use_int8_fast_compensation_path:
 | |
|                             assert x_w_scale is not None
 | |
|                             x_w_scale_loader = x_w_scale.make_loader()
 | |
|                         x_scale_loader = x_scale.make_loader()
 | |
|                         w_scale_loader = w_scale.make_loader()
 | |
|                         x_zp_loader = x_zp.make_loader()
 | |
|                         nonlocal bias
 | |
|                         bias_loader = None
 | |
|                         if bias is not None:
 | |
|                             bias_loader = bias.make_loader()
 | |
| 
 | |
|                         def inner_fn(index):
 | |
|                             nonlocal bias
 | |
|                             input = input_loader(index)
 | |
|                             _x2 = x2_loader(index)
 | |
|                             _x_scale = None
 | |
|                             _x_zp = None
 | |
|                             _w_scale = None
 | |
|                             weight_compens_index = (index[-1],)
 | |
|                             if not use_int8_fast_compensation_path:
 | |
|                                 _x_scale = x_scale_loader(())
 | |
|                                 _x_zp = x_zp_loader(())
 | |
|                                 _w_scale = w_scale_loader(weight_compens_index)
 | |
|                             # MicroKernel Output is with int32: cvt to FP32 before doing compensation
 | |
|                             input = ops.to_dtype(input, torch.float32)
 | |
|                             _weight_compo = weight_compens_loader(weight_compens_index)
 | |
|                             _x_w_scale = None
 | |
|                             if use_int8_fast_compensation_path:
 | |
|                                 assert x_w_scale_loader is not None
 | |
|                                 _x_w_scale = x_w_scale_loader(weight_compens_index)
 | |
|                             # Step 1: Doing compensation to cvt fp32
 | |
|                             temp = codegen_int8_gemm_template_compensation(
 | |
|                                 use_int8_fast_compensation_path,
 | |
|                                 input,
 | |
|                                 _weight_compo,
 | |
|                                 _x_scale,
 | |
|                                 _x_zp,
 | |
|                                 _w_scale,
 | |
|                                 _x_w_scale,
 | |
|                             )
 | |
|                             # Step 2: add Bias if applicable
 | |
|                             if bias is not None:
 | |
|                                 # pyrefly: ignore [not-callable]
 | |
|                                 _bias = bias_loader(weight_compens_index)
 | |
|                                 nonlocal bias_dtype
 | |
|                                 assert bias_dtype in [torch.float32, torch.bfloat16]
 | |
|                                 if bias_dtype == torch.bfloat16:
 | |
|                                     _bias = ops.to_dtype(_bias, torch.float32)
 | |
|                                 temp = ops.add(temp, _bias)
 | |
| 
 | |
|                             # Step 3: Binary add
 | |
|                             nonlocal x2_dtype
 | |
|                             assert x2_dtype in [torch.float32, torch.bfloat16]
 | |
|                             if x2_dtype == torch.bfloat16:
 | |
|                                 _x2 = ops.to_dtype(_x2, torch.float32)
 | |
|                             temp = ops.add(temp, _x2)
 | |
| 
 | |
|                             return temp
 | |
| 
 | |
|                         output_buf = ir.Pointwise(
 | |
|                             device=input_buffer.get_device(),
 | |
|                             dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
 | |
|                             inner_fn=inner_fn,
 | |
|                             ranges=input_buffer.get_size(),
 | |
|                         )
 | |
| 
 | |
|                         # Step 4: Unary post op if has
 | |
|                         if unary_attr != "none":
 | |
|                             output_buf = create_epilogue_with_attr(
 | |
|                                 output_buf,
 | |
|                                 unary_attr,
 | |
|                                 scalars=unary_scalars,
 | |
|                                 algorithm=unary_algorithmm,
 | |
|                             )
 | |
| 
 | |
|                         # Step 5: Cast output to Target Dtype
 | |
|                         if output_dtype == torch.bfloat16:
 | |
|                             output_cast_loader = output_buf.make_loader()
 | |
| 
                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype in [torch.uint8, torch.int8]:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

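                            # Standard affine requantization:
                            #   q = clamp(round(x / scale) + zero_point, qmin, qmax)
                            # where qmin/qmax come from the target integer dtype.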
                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                if output_dtype == torch.uint8:
                                    qmin, qmax = _create_constants(
                                        0, 255, dtype=torch.float32
                                    )
                                else:
                                    qmin, qmax = _create_constants(
                                        -128, 127, dtype=torch.float32
                                    )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                # Cast to the requested quantized dtype (uint8 or int8)
                                return ops.to_dtype(clamped, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

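                    # Register the C++ GEMM template choice with the epilogue fused in.
                    # input_indices maps the call arguments onto the template's expected
                    # order; bias, when present, is moved to the front.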
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        # Reorder bias and x2
                        input_indices=[0, 3, 1, 2, 4, 5, 6]
                        if bias is None
                        else [7, 0, 3, 1, 2, 4, 5, 6],
                    )

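            # Keep the ATen oneDNN kernel as a choice when no template choice was
            # added or when ATen GEMM kernels are enabled.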
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    other_scale=x2_scale,
                    other_zp=x2_zp,
                    binary_post_op=binary_attr,
                    binary_alpha=alpha,
                    unary_post_op=unary_attr,
                    unary_post_op_args=unary_scalars,
                    unary_post_op_algorithm=unary_algorithmm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_binary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
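            # During autotuning, feed the constant inputs (packed weight, weight scale
            # and weight zero-point at indices 3, 4 and 5, plus bias at 7 if present)
            # from the graph constants rather than from randomly generated tensors.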
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],
                4: lambda x: V.graph.constants[x.get_name()],
                5: lambda x: V.graph.constants[x.get_name()],
            }
            if bias is not None:
                input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()]  # For bias
            result = autotune_select_algorithm(
                "qlinear_binary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2 and binary_attr == "add":
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        if torch._C.has_mkl:
            aten_mkl_linear = ExternKernelChoice(
                torch.ops.mkl._mkl_linear,
                "mkl::_mkl_linear",
                has_out_variant=False,
                kernel_creator=mkldnn_ir.MKLPackedLinear.create,
            )
            cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)

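            # Lowering for MKL packed linear: autotune between the extern MKL kernel
            # and the C++ GEMM template, which operates on the original (unpacked) weight.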
            @register_lowering(torch.ops.mkl._mkl_linear)
            def mkl_packed_linear(
                x: TensorBox,
                packed_w: TensorBox,
                orig_w: TensorBox,
                b: Optional[TensorBox],
                batch_size,
                *,
                layout=None,
            ):
                choices: list[ChoiceCaller] = []
                if config.max_autotune or config.max_autotune_gemm:
                    transposed_w = permute(orig_w, [1, 0])
                    *_, layout, x, transposed_w = mm_args(
                        x, transposed_w, layout=layout
                    )
                    if use_cpp_gemm_template(layout, x, transposed_w):
                        CppGemmTemplate.add_choices(
                            choices,
                            layout,
                            [x, packed_w, orig_w],
                            trans_w=True,
                            input_indices=[0, 2],
                        )

                if len(choices) == 0 or use_aten_gemm_kernels():
                    choices.append(
                        aten_mkl_linear.bind(
                            (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
                        )
                    )

                assert packed_w.get_name() in V.graph.constants
                assert orig_w.get_name() in V.graph.constants
                # packed_w is an mkldnn tensor which we can't generate directly,
                # so we use the original weight tensor during autotuning.
                input_gen_fns = {
                    1: lambda x: V.graph.constants[x.get_name()],
                    2: lambda x: V.graph.constants[x.get_name()],
                }
                result: TensorBox = autotune_select_algorithm(
                    "packed_linear",
                    choices,
                    [x, packed_w, orig_w],
                    layout,
                    input_gen_fns=input_gen_fns,
                )
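                # Bias is not part of the autotuned kernel; apply it as a separate
                # elementwise add on the selected result.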
                if b is not None:
                    result = add(result, b)
                return result

        add_needs_realized_inputs(cpu_needs_realized_inputs)