# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args

from . import ir
from .codegen.cpp_gemm_template import CppGemmTemplate
from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
    add,
    add_needs_realized_inputs,
    aten,
    permute,
    register_lowering,
    to_dtype,
    view,
)
from .select_algorithm import (
    autotune_select_algorithm,
    ChoiceCaller,
    ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotune
from .virtualized import ops, OpsValue, V


def create_int8_compensation(
    W_tensor: torch.Tensor,
    packed_weight: ir.TensorBox,
    x_scale: ir.TensorBox,
    x_zp: ir.TensorBox,
    w_scale: ir.TensorBox,
) -> tuple[bool, ir.TensorBox, Optional[ir.TensorBox]]:
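    """Precompute compensation constants for the int8 GEMM template epilogue.

    Summary of the code below: when x_scale, x_zp and w_scale are all
    compile-time graph constants (ConstantBuffers), a fused ``x_scale * w_scale``
    tensor and a combined compensation term ``sum(W, dim=0) * x_scale * w_scale * x_zp``
    are registered as graph constants so the epilogue can use a single
    multiply/subtract (the "fast" path). Otherwise only ``sum(W, dim=0)`` is
    precomputed and the scales/zero point are applied at codegen time.
    Returns ``(use_int8_fast_compensation_path, weight_compens, x_w_scale)``.
    """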
    use_int8_fast_compensation_path = False
    weight_compens = None
    x_w_scale = None
    if all(
        isinstance(item, ir.TensorBox)
        and item.get_name() in V.graph.constants
        and hasattr(item.data, "data")
        and isinstance(item.data.data, ir.ConstantBuffer)
        for item in [x_scale, x_zp, w_scale]
    ):
        use_int8_fast_compensation_path = True
        x_w_scale_tensor = (
            V.graph.constants[x_scale.get_name()]
            * V.graph.constants[w_scale.get_name()]
        )
        x_w_scale = V.graph.add_tensor_constant(
            x_w_scale_tensor,
            name=packed_weight.get_name() + "_x_w_compens",
        )
        weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
        x_zp_tensor = V.graph.constants[x_zp.get_name()]
        weight_compens_tensor = weight_compens_tensor * x_w_scale_tensor * x_zp_tensor
        weight_compens = V.graph.add_tensor_constant(
            weight_compens_tensor,
            name=packed_weight.get_name() + "_BMatrixCompens",
        )
    else:
        weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
        weight_compens = V.graph.add_tensor_constant(
            weight_compens_tensor,
            name=packed_weight.get_name() + "_BMatrixCompens",
        )
    return (
        use_int8_fast_compensation_path,
        weight_compens,
        x_w_scale,
    )


def codegen_int8_gemm_template_compensation(
    use_int8_fast_compensation_path: bool,
    input: OpsValue,
    _weight_compo: OpsValue,
    _x_scale: Optional[OpsValue],
    _x_zp: Optional[OpsValue],
    _w_scale: Optional[OpsValue],
    _x_w_scale: Optional[OpsValue],
) -> OpsValue:
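    """Convert the raw s32 GEMM accumulator (loaded as fp32) to a dequantized fp32 value.

    Up to rounding, the dequantized output of one GEMM column is roughly

        y_fp32 = x_scale * w_scale * (acc_s32 - x_zp * sum_k w_q[k])
               = acc_s32 * (x_scale * w_scale) - (x_scale * w_scale * x_zp) * sum_k w_q[k]

    On the fast path the two constant factors have already been folded into
    ``_x_w_scale`` and ``_weight_compo`` by ``create_int8_compensation``;
    otherwise the scales and zero point are multiplied in explicitly below.
    """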
    if use_int8_fast_compensation_path:
        temp = ops.sub(
            ops.mul(
                input,
                _x_w_scale,
            ),
            _weight_compo,
        )
    else:
        temp = ops.mul(
            ops.mul(
                input,
                _x_scale,
            ),
            _w_scale,
        )
        # NOTE: We apply the compensation even if x_zp is 0 for int8 quantization.
        # That's because when torch.compile is invoked for dynamic quantization,
        # x might coincidentally take values for which x_zp is zero despite
        # asymmetric quantization.
        # Besides, if x_zp is a dummy for int8 x, or if x is statically quantized,
        # we still perform this redundant compute to keep the code simple, since
        # the redundant compensation did not lead to any performance degradation
        # with the input shapes tested.
        temp = ops.sub(
            temp,
            ops.mul(
                ops.mul(
                    ops.mul(
                        _x_scale,
                        _w_scale,
                    ),
                    _x_zp,
                ),
                _weight_compo,
            ),
        )
    return temp


def grouped_gemm_lowering(
    x: TensorBox,
    w: list[TensorBox],
    b: list[TensorBox],
    attr=None,
    scalars=None,
    algorithm=None,
    layout=None,
):
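    """Lower a grouped GEMM: one shared activation against several weights/biases.

    Requires max-autotune. The activation is flattened to 2D if needed, a single
    CPP grouped-GEMM template choice is autotuned, and the multi-output template
    buffer is split back into one TensorBox per constituent GEMM, reshaped to the
    original leading dimensions.
    """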
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    num_gemm = len(w)

    assert use_max_autotune()
    b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]

    choices: list[ChoiceCaller] = []
    *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

    kwargs = dict(
        has_bias=[bias is not None for bias in b],
        trans_w=True,
        epilogue_creator=None,
        act_mapping=dict.fromkeys(range(num_gemm), x),
    )

    input_nodes = [x, *w]
    input_nodes.extend([bias for bias in b if bias is not None])

    CppGroupedGemmTemplate.add_choices(
        choices,
        layout,
        input_nodes,
        **kwargs,  # type: ignore[arg-type]
    )

    assert len(choices) != 0
    result = autotune_select_algorithm(
        "grouped_gemm",
        choices,
        input_nodes,
        layout,
    )
    template_buf = result.data.data
    return_bufs = [
        ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
        for gemm_idx in range(num_gemm)
    ]
    template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
    template_buf.outputs = return_bufs
    return_tensors = [
        ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
    ]
    if len(x_size) > 2:
        for gemm_idx in range(num_gemm):
            return_tensors[gemm_idx] = view(
                return_tensors[gemm_idx],
                (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
            )
    return return_tensors


grouped_gemm_lowering._inductor_lowering_function = True  # type: ignore[attr-defined]


def register_onednn_fusion_ops():
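    """Register Inductor lowerings for the oneDNN (mkldnn) fusion ops.

    Everything below is guarded by ``torch._C._has_mkldnn`` (and the MKL
    packed-linear path additionally by ``torch._C.has_mkl``), so on builds
    without oneDNN this function is a no-op.
    """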
    if torch._C._has_mkldnn:
        from . import mkldnn_ir

        aten_mkldnn_linear_unary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearUnary.create,
        )
        aten_mkldnn_linear_binary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise.binary,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearBinary.create,
        )
        aten_mkldnn_qlinear_unary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
        )
        aten_mkldnn_qlinear_binary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise.binary,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
        )
        cpu_needs_realized_inputs = [
            torch.ops.mkldnn._convolution_pointwise,
            torch.ops.mkldnn._convolution_pointwise_,
            torch.ops.mkldnn._convolution_transpose_pointwise,
            torch.ops.mkldnn._linear_pointwise,
            aten.mkldnn_rnn_layer.default,
            torch.ops.onednn.qconv_pointwise,
        ]

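        # The convolution lowerings below simply wrap the corresponding
        # mkldnn_ir node (built via its ``create`` classmethod) in a TensorBox;
        # the linear/qlinear lowerings additionally add CPP GEMM template
        # choices under max-autotune and pick the winner by autotuning.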
        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
        def convolution_binary(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinary.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
        def convolution_binary_inplace(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinaryInplace.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._linear_pointwise)
        def linear_unary(
            x: TensorBox,
            w: TensorBox,
            b: TensorBox,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
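            """Lower mkldnn._linear_pointwise (linear with a unary post-op).

            Under max-autotune a CPP GEMM template choice with the post-op fused
            as an epilogue competes against the ATen extern mkldnn kernel, and
            autotuning picks the winner.
            """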
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: list[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
                if use_cpp_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(
                            buf, attr, scalars=scalars, algorithm=algorithm
                        )

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=None if attr == "none" else epilogue_creator,
                    )
                    if b is not None:
                        kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, w] if b is None else [x, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_unary.bind(
                        [x, w] if b is None else [x, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                1: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_unary",
                choices,
                [x, w] if b is None else [x, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
        def linear_binary(
            x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            y_size = y.get_size()
            if len(y_size) > 2:
                y = view(y, [-1, y_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: list[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w, y = mm_args(
                    x, transposed_w, y, layout=layout
                )
                if use_cpp_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(buf, attr, other=y)

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=epilogue_creator,
                    )
                    kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, y, w] if b is None else [x, y, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_binary.bind(
                        [x, y, w] if b is None else [x, y, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                2: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_binary",
                choices,
                [x, y, w] if b is None else [x, y, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
        def convolution_transpose_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            output_padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionTransposeUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    output_padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(aten.mkldnn_rnn_layer.default)
        def mkldnn_rnn_layer(
            x: TensorBox,
            w0: TensorBox,
            w1: TensorBox,
            w2: TensorBox,
            w3: TensorBox,
            hx: TensorBox,
            cx: TensorBox,
            reverse: bool,
            batch_sizes: list[int],
            mode: int,
            hidden_size: int,
            num_layers: int,
            has_biases: bool,
            bidirectional: bool,
            batch_first: bool,
            train: bool,
        ):
            return pytree.tree_map(
                TensorBox.create,
                mkldnn_ir.MkldnnRnnLayer.create(
                    x,
                    w0,
                    w1,
                    w2,
                    w3,
                    hx,
                    cx,
                    reverse,
                    batch_sizes,
                    mode,
                    hidden_size,
                    num_layers,
                    has_biases,
                    bidirectional,
                    batch_first,
                    train,
                ),
            )

        @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None)
        def qconvolution_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
        ):
            # To align with qlinear, where x_scale and x_zp are converted to Tensors
            assert type(x_scale) == float
            x_scale = V.graph.add_tensor_constant(
                torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
            )
            assert type(x_zp) == int
            x_zp = V.graph.add_tensor_constant(
                torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
            )

            return TensorBox.create(
                mkldnn_ir.QConvPointWisePT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qconvolution_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            accum: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            accum_scale,
            accum_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            # To align with qlinear, where x_scale and x_zp are converted to Tensors
            assert type(x_scale) == float
            x_scale = V.graph.add_tensor_constant(
                torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
            )
            assert type(x_zp) == int
            x_zp = V.graph.add_tensor_constant(
                torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
            )

            if (
                binary_attr == "sum"
                and output_dtype in [torch.float32, torch.bfloat16]
                and accum.get_dtype() in [torch.float32, torch.bfloat16]
                and accum.get_dtype() != output_dtype
            ):
                # For int8-mixed-bf16 quantization with an in-place add, there is a
                # case where the accum dtype is float32 while the output dtype is
                # bfloat16. Since accum is changed in place by the post-op sum,
                # we do the accum dtype conversion here.
                accum = to_dtype(accum, output_dtype)
            return TensorBox.create(
                mkldnn_ir.QConvPointWiseBinaryPT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    accum,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    accum_scale,
                    accum_zp,
                    binary_attr,
                    alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
        def qlinear_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
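            """Lower onednn.qlinear_pointwise (quantized linear with a unary post-op).

            Scalar x_scale/x_zp are promoted to tensor constants and missing zero
            points are replaced with zero tensors. Under max-autotune, and only for
            symmetrically quantized weights, a CPP GEMM template choice is added
            whose epilogue fuses dequantization (compensation), bias, the unary
            post-op, and the cast/requantization to the output dtype.
            """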
            assert packed_weight.get_dtype() is torch.int8, (
                "Only int8 weights are supported by oneDNN qlinear."
            )
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
                if all(dim == 1 for dim in x_scale.get_size()):
                    # Corner-case discovered with the LLaMA series.
                    # If all outer dims of x_scale are 1, make it a 0D tensor.
                    # Otherwise, the epilogue creator will run into indexing issues.
                    x_scale = view(x_scale, [])
                assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D"

            if x_zp is None:
                # If x_zp is None while x is int8 quantized per-tensor and its scale
                # is not reshaped, the generated code would segfault unless we create
                # a tensor for x_zp. It's safe to do so since x is a symmetrically
                # quantized int8 tensor.
                # Moreover, the oneDNN qlinear API doesn't accept None for the zp.
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="x_zp"
                )
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            assert x_zp.get_numel() == 1, "x_zp is incompatible with oneDNN qlinear"

            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # node instead of a ConstantBuffer. Refer to
            # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950
            if w_zp is None:
                # If w_zp is None, then it's a dummy tensor created to denote the
                # absence of a zero point, and thus w is int8 symmetrically quantized.
                # Moreover, the oneDNN qlinear API doesn't accept None for the zp.
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="w_zp"
                )
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                # w_zp might be a ConstantBuffer with int64; convert it to int32
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )

            bias_dtype = None if bias is None else bias.get_dtype()
            choices: list[ChoiceCaller] = []

            if use_max_autotune():
                *_, layout, x, packed_weight = mm_args(
                    x, packed_weight, layout=layout, out_dtype=output_dtype
                )

                if (
                    # GEMM template currently only supports symmetrically quantized weights
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )
                ) and use_cpp_gemm_template(layout, x, packed_weight):
                    W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()

                    (
                        use_int8_fast_compensation_path,
                        weight_compens,
                        x_w_scale,
                    ) = create_int8_compensation(
                        W_tensor,
                        packed_weight,
                        x_scale,
                        x_zp,
                        w_scale,
                    )

                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                            torch.int8,
                        ]
                        input_loader = input_buffer.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_w_scale_loader = None
                        if use_int8_fast_compensation_path:
                            assert x_w_scale is not None
                            x_w_scale_loader = x_w_scale.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            # The micro-kernel output is int32;
                            # convert to FP32 before applying the compensation.
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)

                            _x_scale = None
                            _x_zp = None
                            _w_scale = None
                            if not use_int8_fast_compensation_path:
                                _x_scale = x_scale_loader(())
                                _x_zp = x_zp_loader(())
                                _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compo = weight_compens_loader(weight_compens_index)
                            _x_w_scale = None
                            if use_int8_fast_compensation_path:
                                assert x_w_scale_loader is not None
                                _x_w_scale = x_w_scale_loader(weight_compens_index)
                            # Step 1: Compute the s8s8->s32 or u8s8->s32 GEMM, then apply compensation
                            temp = codegen_int8_gemm_template_compensation(
                                use_int8_fast_compensation_path,
                                input,
                                _weight_compo,
                                _x_scale,
                                _x_zp,
                                _w_scale,
                                _x_w_scale,
                            )
                            # Step 2: add bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32 & s8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 3: Apply the unary post-op fusion
                        if attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf, attr, scalars=scalars, algorithm=algorithm
                            )

                        # Step 4: Cast the output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype in [torch.uint8, torch.int8]:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
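                                # Requantize: roughly q = clamp(round(y / scale) + zero_point, qmin, qmax),
                                # with (qmin, qmax) = (0, 255) for uint8 and (-128, 127) for int8.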
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                if output_dtype == torch.uint8:
                                    qmin, qmax = _create_constants(
                                        0, 255, dtype=torch.float32
                                    )
                                else:
                                    qmin, qmax = _create_constants(
                                        -128, 127, dtype=torch.float32
                                    )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    assert x.get_dtype() in [torch.uint8, torch.int8]
                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        input_indices=[0, 3, 1, 2, 4, 5]
                        if bias is None
                        else [6, 0, 3, 1, 2, 4, 5],
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    post_op_name=attr,
                    post_op_args=scalars,
                    post_op_algorithm=algorithm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_unary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],  # packed weight
                4: lambda x: V.graph.constants[x.get_name()],  # weight scale
                5: lambda x: V.graph.constants[x.get_name()],  # weight zp
                6: lambda x: V.graph.constants[x.get_name()],  # bias
            }
            if isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_scale),
                ir.ConstantBuffer,
            ):
                # x is statically quantized
                input_gen_fns[1] = lambda x: V.graph.constants[x.get_name()]
            if isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_zp),
                ir.ConstantBuffer,
            ):
                input_gen_fns[2] = lambda x: V.graph.constants[x.get_name()]

            result = autotune_select_algorithm(
                "qlinear_unary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qlinear_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            x2: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
            layout=None,
        ):
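            """Lower onednn.qlinear_pointwise.binary (quantized linear + binary add of x2 + optional unary post-op).

            The CPP GEMM template path is only attempted for the out-of-place
            "add" variant (see the TODO below about inplace "sum" fusion);
            otherwise the extern oneDNN kernel is used.
            """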
            x_size = x.get_size()
            x2_size = x2.get_size()
            assert len(x_size) == len(x2_size)
            if len(x_size) > 2 and binary_attr == "add":
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
                x2 = view(x2, [-1, x2_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
                if all(dim == 1 for dim in x_scale.get_size()):
                    # Corner-case discovered with the LLaMA series.
                    # If all outer dims of x_scale are 1, make it a 0D tensor.
                    # Otherwise, the epilogue creator will run into indexing issues.
                    x_scale = view(x_scale, [])
                assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D"

            if x_zp is None:
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="x_zp"
                )

            if w_zp is None:
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(0, dtype=torch.int32), name="w_zp"
                )

            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # node instead of a ConstantBuffer. Refer to
            # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )
            if binary_attr == "sum":
                if output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                ] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
                    if x2.get_dtype() != output_dtype:
                        # For int8-mixed-bf16 quantization with an in-place add, there
                        # is a case where the accum dtype is float32 while the output
                        # dtype is bfloat16. Since accum is changed in place by the
                        # post-op sum, we do the accum dtype conversion here.
                        x2 = to_dtype(x2, output_dtype)
                else:
                    assert x2.get_dtype() == output_dtype, (
                        "dtype of accum for qlinear post op sum should be the same as output"
                    )
            x2_dtype = x2.get_dtype()
            bias_dtype = bias.get_dtype() if bias is not None else None
            choices: list[ChoiceCaller] = []
            if (
                use_max_autotune() and binary_attr == "add"
            ):  # <TODO> Support inplace sum fusion
                *_, layout, x, packed_weight, x2 = mm_args(
                    x, packed_weight, x2, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per-tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid compensating MatrixA
                    and use_cpp_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()]
                    W_tensor = W_tensor.to_dense()
                    (
                        use_int8_fast_compensation_path,
                        weight_compens,
                        x_w_scale,
                    ) = create_int8_compensation(
                        W_tensor,
                        packed_weight,
                        x_scale,
                        x_zp,
                        w_scale,
                    )

                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                            torch.int8,
                        ]

                        input_loader = input_buffer.make_loader()
                        x2_loader = x2.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_w_scale_loader = None
                        if use_int8_fast_compensation_path:
                            assert x_w_scale is not None
                            x_w_scale_loader = x_w_scale.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            _x2 = x2_loader(index)
                            _x_scale = None
                            _x_zp = None
                            _w_scale = None
                            weight_compens_index = (index[-1],)
                            if not use_int8_fast_compensation_path:
                                _x_scale = x_scale_loader(())
                                _x_zp = x_zp_loader(())
                                _w_scale = w_scale_loader(weight_compens_index)
                            # The micro-kernel output is int32; convert to FP32 before compensation
                            input = ops.to_dtype(input, torch.float32)
                            _weight_compo = weight_compens_loader(weight_compens_index)
                            _x_w_scale = None
                            if use_int8_fast_compensation_path:
                                assert x_w_scale_loader is not None
                                _x_w_scale = x_w_scale_loader(weight_compens_index)
                            # Step 1: Apply the compensation to convert to fp32
                            temp = codegen_int8_gemm_template_compensation(
                                use_int8_fast_compensation_path,
                                input,
                                _weight_compo,
                                _x_scale,
                                _x_zp,
                                _w_scale,
                                _x_w_scale,
                            )
                            # Step 2: add bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            # Step 3: Binary add
                            nonlocal x2_dtype
                            assert x2_dtype in [torch.float32, torch.bfloat16]
                            if x2_dtype == torch.bfloat16:
                                _x2 = ops.to_dtype(_x2, torch.float32)
                            temp = ops.add(temp, _x2)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 4: Apply the unary post-op, if any
                        if unary_attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf,
                                unary_attr,
                                scalars=unary_scalars,
                                algorithm=unary_algorithm,
                            )

                        # Step 5: Cast the output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype in [torch.uint8, torch.int8]:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                if output_dtype == torch.uint8:
                                    qmin, qmax = _create_constants(
                                        0, 255, dtype=torch.float32
                                    )
                                else:
                                    qmin, qmax = _create_constants(
                                        -128, 127, dtype=torch.float32
                                    )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device_or_error(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    CppGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        # Reorder bias and x2
                        input_indices=[0, 3, 1, 2, 4, 5, 6]
                        if bias is None
                        else [7, 0, 3, 1, 2, 4, 5, 6],
                    )

            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    other_scale=x2_scale,
                    other_zp=x2_zp,
                    binary_post_op=binary_attr,
                    binary_alpha=alpha,
                    unary_post_op=unary_attr,
                    unary_post_op_args=unary_scalars,
                    unary_post_op_algorithm=unary_algorithm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_binary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],
                4: lambda x: V.graph.constants[x.get_name()],
                5: lambda x: V.graph.constants[x.get_name()],
            }
            if bias is not None:
                input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()]  # For bias
            result = autotune_select_algorithm(
                "qlinear_binary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2 and binary_attr == "add":
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        if torch._C.has_mkl:
            aten_mkl_linear = ExternKernelChoice(
                torch.ops.mkl._mkl_linear,
                "mkl::_mkl_linear",
                has_out_variant=False,
                kernel_creator=mkldnn_ir.MKLPackedLinear.create,
            )
            cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)

            @register_lowering(torch.ops.mkl._mkl_linear)
            def mkl_packed_linear(
                x: TensorBox,
                packed_w: TensorBox,
                orig_w: TensorBox,
                b: Optional[TensorBox],
                batch_size,
                *,
                layout=None,
            ):
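                """Lower mkl._mkl_linear (MKL packed GEMM).

                packed_w is an MKL-packed weight that cannot be materialized
                directly, so autotuning benchmarks against the original weight
                orig_w; the bias, if any, is added afterwards via the regular
                add lowering.
                """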
                choices: list[ChoiceCaller] = []
                if use_max_autotune():
                    transposed_w = permute(orig_w, [1, 0])
                    *_, layout, x, transposed_w = mm_args(
                        x, transposed_w, layout=layout
                    )
                    if use_cpp_gemm_template(layout, x, transposed_w):
                        CppGemmTemplate.add_choices(
                            choices,
                            layout,
                            [x, packed_w, orig_w],
                            trans_w=True,
                            input_indices=[0, 2],
                        )

                if len(choices) == 0 or use_aten_gemm_kernels():
                    choices.append(
                        aten_mkl_linear.bind(
                            (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
                        )
                    )

                assert packed_w.get_name() in V.graph.constants
                assert orig_w.get_name() in V.graph.constants
                # packed_w is an mkldnn tensor that we can't generate directly,
                # so we use the weights from the original tensor for autotuning.
                input_gen_fns = {
                    1: lambda x: V.graph.constants[x.get_name()],
                    2: lambda x: V.graph.constants[x.get_name()],
                }
                result: TensorBox = autotune_select_algorithm(
                    "packed_linear",
                    choices,
                    [x, packed_w, orig_w],
                    layout,
                    input_gen_fns=input_gen_fns,
                )
                if b is not None:
                    result = add(result, b)
                return result

        add_needs_realized_inputs(cpu_needs_realized_inputs)
    else:
        pass