pytorch/torch/_inductor/mkldnn_lowerings.py
Wu, Chunyuan d7411c0cc1 [AOTI] add C shim for QConvPointWise (#138540)
This PR adds a C shim for `QConvPointWisePT2E` and `QConvPointWiseBinaryPT2E`, similar to https://github.com/pytorch/pytorch/pull/138439. Besides that, we aligned the implementation of `qconv_pointwise` with `qlinear_pointwise` in the following aspects:
1. The parameter orders of `qconv_pointwise` and `qlinear_pointwise` were quite different, so we aligned the schema of `qconv_pointwise` with that of `qlinear_pointwise` to make the two more consistent.
2. We now always convert `x_scale` and `x_zero_point` to Tensors, just as in the lowering of `qlinear_pointwise`. This avoids the need for two separate C APIs (one taking `double x_scale` and `int64_t x_zero_point`, another taking `Tensor` versions); a single API with `Tensor`-based `x_scale` and `x_zero_point` suffices. If we later add dynamic quantization for qconv (which will use `Tensor` for `x_scale` and `x_zero_point`), we can reuse the code from this PR without changing the C shim layer API. See the sketch below for the idea behind this normalization.
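A minimal sketch of the scalar-to-Tensor normalization described in point 2; `normalize_quant_params` is a hypothetical helper for illustration only, not part of the C shim or the Inductor lowering:

```python
# Illustrative sketch only: `normalize_quant_params` is a hypothetical name.
import torch

def normalize_quant_params(x_scale, x_zero_point):
    # Wrap scalar quantization params into 0-dim tensors so that a single
    # Tensor-based signature covers both static quantization (Python scalars)
    # and dynamic quantization (params already computed as Tensors).
    if not isinstance(x_scale, torch.Tensor):
        x_scale = torch.tensor(x_scale, dtype=torch.float32)
    if not isinstance(x_zero_point, torch.Tensor):
        x_zero_point = torch.tensor(x_zero_point, dtype=torch.int32)
    return x_scale, x_zero_point

# Static quantization: scalars are wrapped.
scale_t, zp_t = normalize_quant_params(0.02, 128)
# Dynamic quantization: Tensors pass through unchanged.
scale_t, zp_t = normalize_quant_params(torch.tensor(0.02), torch.tensor(128, dtype=torch.int32))
```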

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138540
Approved by: https://github.com/jgong5, https://github.com/desertfire
ghstack dependencies: #138691, #138806
2024-10-31 02:03:01 +00:00


# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import List, Optional
import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args
from . import ir
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
add,
add_needs_realized_inputs,
aten,
permute,
register_lowering,
to_dtype,
view,
)
from .select_algorithm import (
autotune_select_algorithm,
ChoiceCaller,
ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
from .virtualized import ops, V
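
# Lowerings for the oneDNN (mkldnn/onednn) fusion ops: fused convolution,
# linear and RNN layers, and their quantized PT2E variants. They are only
# registered when PyTorch is built with mkldnn support.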
def register_onednn_fusion_ops():
if torch._C._has_mkldnn:
from . import mkldnn_ir
aten_mkldnn_linear_unary = ExternKernelChoice(
torch.ops.mkldnn._linear_pointwise,
"mkldnn::_linear_pointwise",
has_out_variant=False,
kernel_creator=mkldnn_ir.LinearUnary.create,
)
aten_mkldnn_linear_binary = ExternKernelChoice(
torch.ops.mkldnn._linear_pointwise.binary,
"mkldnn::_linear_pointwise",
has_out_variant=False,
kernel_creator=mkldnn_ir.LinearBinary.create,
)
aten_mkldnn_qlinear_unary = ExternKernelChoice(
torch.ops.onednn.qlinear_pointwise,
"onednn::qlinear_pointwise",
has_out_variant=False,
kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
)
aten_mkldnn_qlinear_binary = ExternKernelChoice(
torch.ops.onednn.qlinear_pointwise.binary,
"onednn::qlinear_pointwise",
has_out_variant=False,
kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
)
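        # Inputs to these fused oneDNN ops must be realized buffers rather than
        # fused pointwise expressions; they are registered via
        # add_needs_realized_inputs() at the end of this block.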
cpu_needs_realized_inputs = [
torch.ops.mkldnn._convolution_pointwise,
torch.ops.mkldnn._convolution_pointwise_,
torch.ops.mkldnn._convolution_transpose_pointwise,
torch.ops.mkldnn._linear_pointwise,
aten.mkldnn_rnn_layer.default,
torch.ops.onednn.qconv2d_pointwise,
]
@register_lowering(torch.ops.mkldnn._convolution_pointwise)
def convolution_unary(
x: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
):
return TensorBox.create(
mkldnn_ir.ConvolutionUnary.create(
x,
weight,
bias,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
)
)
@register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
def convolution_binary(
x: TensorBox,
other: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
stride,
dilation,
groups,
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
):
return TensorBox.create(
mkldnn_ir.ConvolutionBinary.create(
x,
other,
weight,
bias,
padding,
stride,
dilation,
groups,
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
)
)
@register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
def convolution_binary_inplace(
x: TensorBox,
other: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
stride,
dilation,
groups,
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
):
return TensorBox.create(
mkldnn_ir.ConvolutionBinaryInplace.create(
x,
other,
weight,
bias,
padding,
stride,
dilation,
groups,
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
)
)
@register_lowering(torch.ops.mkldnn._linear_pointwise)
def linear_unary(
x: TensorBox,
w: TensorBox,
b: TensorBox,
attr,
scalars,
algorithm,
layout=None,
):
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
if use_cpp_packed_gemm_template(layout, x, transposed_w):
def epilogue_creator(buf):
return create_epilogue_with_attr(
buf, attr, scalars=scalars, algorithm=algorithm
)
kwargs = dict(
has_bias=b is not None,
trans_w=True,
epilogue_creator=None if attr == "none" else epilogue_creator,
)
if b is not None:
kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment]
CppPackedGemmTemplate.add_choices(
choices,
layout,
[x, w] if b is None else [x, w, b],
**kwargs, # type: ignore[arg-type]
)
if len(choices) == 0 or use_aten_gemm_kernels():
kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
if b is None:
kwargs["B"] = None
choices.append(
aten_mkldnn_linear_unary.bind(
[x, w] if b is None else [x, w, b],
layout,
**kwargs,
)
)
assert w.get_name() in V.graph.constants
input_gen_fns = {
1: lambda x: V.graph.constants[x.get_name()],
}
result = autotune_select_algorithm(
"linear_unary",
choices,
[x, w] if b is None else [x, w, b],
layout,
input_gen_fns=input_gen_fns,
)
if len(x_size) > 2:
result = view(result, (*x_size[:-1], result.get_size()[-1]))
return result
@register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
def linear_binary(
x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
):
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
y_size = y.get_size()
if len(y_size) > 2:
y = view(y, [-1, y_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w, y = mm_args(
x, transposed_w, y, layout=layout
)
if use_cpp_packed_gemm_template(layout, x, transposed_w):
def epilogue_creator(buf):
return create_epilogue_with_attr(buf, attr, other=y)
kwargs = dict(
has_bias=b is not None,
trans_w=True,
epilogue_creator=epilogue_creator,
)
kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
CppPackedGemmTemplate.add_choices(
choices,
layout,
[x, y, w] if b is None else [x, y, w, b],
**kwargs, # type: ignore[arg-type]
)
if len(choices) == 0 or use_aten_gemm_kernels():
kwargs = dict(attr=attr)
if b is None:
kwargs["B"] = None
choices.append(
aten_mkldnn_linear_binary.bind(
[x, y, w] if b is None else [x, y, w, b],
layout,
**kwargs,
)
)
assert w.get_name() in V.graph.constants
input_gen_fns = {
2: lambda x: V.graph.constants[x.get_name()],
}
result = autotune_select_algorithm(
"linear_binary",
choices,
[x, y, w] if b is None else [x, y, w, b],
layout,
input_gen_fns=input_gen_fns,
)
if len(x_size) > 2:
result = view(result, (*x_size[:-1], result.get_size()[-1]))
return result
@register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
def convolution_transpose_unary(
x: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
output_padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
):
return TensorBox.create(
mkldnn_ir.ConvolutionTransposeUnary.create(
x,
weight,
bias,
padding,
output_padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
)
)
@register_lowering(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
x: TensorBox,
w0: TensorBox,
w1: TensorBox,
w2: TensorBox,
w3: TensorBox,
hx: TensorBox,
cx: TensorBox,
reverse: bool,
batch_sizes: List[int],
mode: int,
hidden_size: int,
num_layers: int,
has_biases: bool,
bidirectional: bool,
batch_first: bool,
train: bool,
):
return pytree.tree_map(
TensorBox.create,
mkldnn_ir.MkldnnRnnLayer.create(
x,
w0,
w1,
w2,
w3,
hx,
cx,
reverse,
batch_sizes,
mode,
hidden_size,
num_layers,
has_biases,
bidirectional,
batch_first,
train,
),
)
@register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None)
def qconvolution_unary(
x: TensorBox,
x_scale,
x_zp,
packed_weight: TensorBox,
w_scale: TensorBox,
w_zp: TensorBox,
bias: TensorBox,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
output_dtype,
attr,
scalars,
algorithm,
):
# To align with qlinear where x_scale and x_zp are converted to Tensor
assert type(x_scale) == float
x_scale = V.graph.add_tensor_constant(
torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
)
assert type(x_zp) == int
x_zp = V.graph.add_tensor_constant(
torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
)
return TensorBox.create(
mkldnn_ir.QConvPointWisePT2E.create(
x,
x_scale,
x_zp,
packed_weight,
w_scale,
w_zp,
bias,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
output_dtype,
attr,
scalars,
algorithm,
)
)
@register_lowering(
torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
)
@register_lowering(
torch.ops.onednn.qconv2d_pointwise.binary_tensor, type_promotion_kind=None
)
def qconvolution_binary(
x: TensorBox,
x_scale,
x_zp,
packed_weight: TensorBox,
w_scale: TensorBox,
w_zp: TensorBox,
accum: TensorBox,
bias: TensorBox,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
output_dtype,
accum_scale,
accum_zp,
binary_attr,
alpha,
unary_attr,
unary_scalars,
            unary_algorithm,
):
# To align with qlinear where x_scale and x_zp are converted to Tensor
assert type(x_scale) == float
x_scale = V.graph.add_tensor_constant(
torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
)
assert type(x_zp) == int
x_zp = V.graph.add_tensor_constant(
torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
)
if (
binary_attr == "sum"
and output_dtype in [torch.float32, torch.bfloat16]
and accum.get_dtype() in [torch.float32, torch.bfloat16]
and accum.get_dtype() != output_dtype
):
                # For int8-mixed-bf16 quantization with an in-place add, there is a
                # case where the accum dtype is float32 but the output dtype is
                # bfloat16. Since the accum is modified in place by the post-op sum,
                # we convert the accum dtype here.
accum = to_dtype(accum, output_dtype)
return TensorBox.create(
mkldnn_ir.QConvPointWiseBinaryPT2E.create(
x,
x_scale,
x_zp,
packed_weight,
w_scale,
w_zp,
accum,
bias,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
output_dtype,
accum_scale,
accum_zp,
binary_attr,
alpha,
unary_attr,
unary_scalars,
                    unary_algorithm,
)
)
@register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
def qlinear_unary(
x: TensorBox,
x_scale,
x_zp,
packed_weight: TensorBox,
w_scale: TensorBox,
w_zp: TensorBox,
bias: TensorBox,
o_scale,
o_zero_point,
output_dtype,
attr,
scalars,
algorithm,
layout=None,
):
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
if not isinstance(x_scale, ir.TensorBox):
assert type(x_scale) == float
x_scale = V.graph.add_tensor_constant(
torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
)
else:
x_scale.realize()
if not isinstance(x_zp, ir.TensorBox):
assert type(x_zp) == int
x_zp = V.graph.add_tensor_constant(
torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
)
else:
x_zp.realize()
            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # rather than a ConstantBuffer.
# Refer to https://github.com/pytorch/pytorch/blob
# /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
w_scale.realize()
w_zp.realize()
if w_zp.get_dtype() != torch.int32 and isinstance(
ir.InputsKernel.unwrap_storage_for_input(w_zp),
ir.ConstantBuffer,
):
# W_zp might be a ConstantBuffer with int64, convert it to int32
w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
w_zp = V.graph.add_tensor_constant(
torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
)
bias_dtype = None if bias is None else bias.get_dtype()
choices: List[ChoiceCaller] = []
if use_max_autotune():
*_, layout, x, packed_weight = mm_args(
x, packed_weight, layout=layout, out_dtype=output_dtype
)
if (
isinstance(
ir.InputsKernel.unwrap_storage_for_input(x_zp),
ir.ConstantBuffer,
)
and len(x_zp.get_layout().size) == 0 # Per tensor quant of act
and isinstance(
ir.InputsKernel.unwrap_storage_for_input(w_zp),
ir.ConstantBuffer,
)
and torch.equal(
torch.zeros_like(V.graph.constants[w_zp.get_name()]),
V.graph.constants[w_zp.get_name()],
) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
and use_cpp_packed_gemm_template(layout, x, packed_weight)
):
W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()
weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
weight_compens = V.graph.add_tensor_constant(
weight_compens_tensor,
name=packed_weight.get_name() + "_BMatrixCompens",
)
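                    # Dequantization with compensation: for u8 activation A and s8 weight W,
                    #   acc_s32[m, n] = sum_k A[m, k] * W[k, n]
                    #   y_f32[m, n]  = x_scale * w_scale[n] * (acc_s32[m, n] - x_zp * sum_k W[k, n])
                    # weight_compens holds sum_k W[k, n], so the epilogue below only
                    # needs elementwise multiplies and a subtraction.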
def epilogue_creator(input_buffer):
# Epilogue to convert from s32 to f32 for u8s8f32
assert output_dtype in [
torch.float32,
torch.bfloat16,
torch.uint8,
]
input_loader = input_buffer.make_loader()
weight_compens_loader = weight_compens.make_loader()
x_scale_loader = x_scale.make_loader()
w_scale_loader = w_scale.make_loader()
x_zp_loader = x_zp.make_loader()
nonlocal bias
bias_loader = None
if bias is not None:
bias_loader = bias.make_loader()
def inner_fn(index):
nonlocal bias
input = input_loader(index)
                            # The micro-kernel output is int32; convert to FP32
                            # before applying compensation.
input = ops.to_dtype(input, torch.float32)
weight_compens_index = (index[-1],)
_x_scale = x_scale_loader(())
_x_zp = x_zp_loader(())
_w_scale = w_scale_loader(weight_compens_index)
_weight_compo = weight_compens_loader(weight_compens_index)
                            # Step 1: Apply compensation to convert to fp32
temp = ops.mul(
ops.mul(
input,
_x_scale,
),
_w_scale,
)
temp = ops.sub(
temp,
ops.mul(
ops.mul(
ops.mul(
_x_scale,
_w_scale,
),
_x_zp,
),
_weight_compo,
),
)
# Step 2: add Bias if applicable
if bias is not None:
_bias = bias_loader(weight_compens_index)
nonlocal bias_dtype
assert bias_dtype in [torch.float32, torch.bfloat16]
if bias_dtype == torch.bfloat16:
_bias = ops.to_dtype(_bias, torch.float32)
temp = ops.add(temp, _bias)
return temp
output_buf = ir.Pointwise(
device=input_buffer.get_device(),
dtype=torch.float32, # Hardcode to FP32 for u8s8f32
inner_fn=inner_fn,
ranges=input_buffer.get_size(),
)
                        # Step 3: Apply the unary post-op fusion
if attr != "none":
output_buf = create_epilogue_with_attr(
output_buf, attr, scalars=scalars, algorithm=algorithm
)
# Step 4: Cast output to Target Dtype
if output_dtype == torch.bfloat16:
output_cast_loader = output_buf.make_loader()
def inner_fn_cast_output_to_bf16(index):
input = output_cast_loader(index)
return ops.to_dtype(input, output_dtype)
output_buf = ir.Pointwise(
device=output_buf.get_device(),
dtype=output_dtype,
inner_fn=inner_fn_cast_output_to_bf16,
ranges=output_buf.get_size(),
)
elif output_dtype == torch.uint8:
from .lowering import _create_constants
requant_input_loader = output_buf.make_loader()
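                            # Requantize f32 -> u8: q = clamp(round(x / o_scale) + o_zp, 0, 255)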
def inner_fn_requant(index, scale, zero_point):
input = requant_input_loader(index)
inv_scale, zero_point = _create_constants(
1.0 / scale, zero_point, dtype=torch.float32
)
val = ops.round(input * inv_scale) + zero_point
qmin, qmax = _create_constants(
0, 255, dtype=torch.float32
)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, torch.uint8)
output_buf = ir.Pointwise(
device=output_buf.get_device(),
dtype=output_dtype,
inner_fn=functools.partial(
inner_fn_requant,
scale=float(o_scale),
zero_point=int(o_zero_point),
),
ranges=output_buf.get_size(),
)
return output_buf
assert x.get_dtype() == torch.uint8
CppPackedGemmTemplate.add_choices(
choices,
layout,
[x, x_scale, x_zp, packed_weight, w_scale, w_zp]
if bias is None
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
has_bias=bias is not None,
epilogue_creator=epilogue_creator,
input_indices=[0, 3, 1, 2, 4, 5]
if bias is None
else [6, 0, 3, 1, 2, 4, 5],
)
if len(choices) == 0 or use_aten_gemm_kernels():
kwargs = dict(
output_scale=o_scale,
output_zero_point=o_zero_point,
output_dtype=output_dtype,
post_op_name=attr,
post_op_args=scalars,
post_op_algorithm=algorithm,
)
if bias is None:
kwargs["bias"] = None
choices.append(
aten_mkldnn_qlinear_unary.bind(
(x, x_scale, x_zp, packed_weight, w_scale, w_zp)
if bias is None
else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
layout,
**kwargs,
)
)
assert packed_weight.get_name() in V.graph.constants
input_gen_fns = {
3: lambda x: V.graph.constants[x.get_name()],
4: lambda x: V.graph.constants[x.get_name()],
5: lambda x: V.graph.constants[x.get_name()],
6: lambda x: V.graph.constants[x.get_name()], # For bias
}
result = autotune_select_algorithm(
"qlinear_unary",
choices,
[x, x_scale, x_zp, packed_weight, w_scale, w_zp]
if bias is None
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
layout,
input_gen_fns=input_gen_fns,
)
if len(x_size) > 2:
result = view(result, (*x_size[:-1], result.get_size()[-1]))
return result
@register_lowering(
torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
)
@register_lowering(
torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
)
def qlinear_binary(
x: TensorBox,
x_scale,
x_zp,
packed_weight: TensorBox,
w_scale: TensorBox,
w_zp: TensorBox,
x2: TensorBox,
bias: TensorBox,
o_scale,
o_zero_point,
output_dtype,
x2_scale,
x2_zp,
binary_attr,
alpha,
unary_attr,
unary_scalars,
            unary_algorithm,
layout=None,
):
x_size = x.get_size()
x2_size = x2.get_size()
assert len(x_size) == len(x2_size)
if len(x_size) > 2 and binary_attr == "add":
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
x2 = view(x2, [-1, x2_size[-1]])
if not isinstance(x_scale, ir.TensorBox):
assert type(x_scale) == float
x_scale = V.graph.add_tensor_constant(
torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
)
else:
x_scale.realize()
if not isinstance(x_zp, ir.TensorBox):
assert type(x_zp) == int
x_zp = V.graph.add_tensor_constant(
torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
)
else:
x_zp.realize()
            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # rather than a ConstantBuffer.
# Refer to https://github.com/pytorch/pytorch/blob
# /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
w_scale.realize()
w_zp.realize()
if w_zp.get_dtype() != torch.int32 and isinstance(
ir.InputsKernel.unwrap_storage_for_input(w_zp),
ir.ConstantBuffer,
):
w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
w_zp = V.graph.add_tensor_constant(
torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
)
if binary_attr == "sum":
if output_dtype in [
torch.float32,
torch.bfloat16,
] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
if x2.get_dtype() != output_dtype:
                        # For int8-mixed-bf16 quantization with an in-place add, there
                        # is a case where the accum dtype is float32 but the output
                        # dtype is bfloat16. Since the accum is modified in place by
                        # the post-op sum, we convert the accum dtype here.
x2 = to_dtype(x2, output_dtype)
else:
assert (
x2.get_dtype() == output_dtype
), "dtype of accum for qlinear post op sum should be the same as output"
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: List[ChoiceCaller] = []
if (
use_max_autotune() and binary_attr == "add"
): # <TODO> Support inplace sum fusion
*_, layout, x, packed_weight, x2 = mm_args(
x, packed_weight, x2, layout=layout, out_dtype=output_dtype
)
if (
isinstance(
ir.InputsKernel.unwrap_storage_for_input(x_zp),
ir.ConstantBuffer,
)
and len(x_zp.get_layout().size) == 0 # Per tensor quant of act
and isinstance(
ir.InputsKernel.unwrap_storage_for_input(w_zp),
ir.ConstantBuffer,
)
and torch.equal(
torch.zeros_like(V.graph.constants[w_zp.get_name()]),
V.graph.constants[w_zp.get_name()],
) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
and use_cpp_packed_gemm_template(layout, x, packed_weight)
):
W_tensor = V.graph.constants[packed_weight.get_name()]
W_tensor = W_tensor.to_dense()
weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
weight_compens = V.graph.add_tensor_constant(
weight_compens_tensor,
name=packed_weight.get_name() + "_BMatrixCompens",
)
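                    # Same s32 -> f32 dequantization compensation as in qlinear_unary
                    # above, plus a fused binary add of x2 in the epilogue.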
def epilogue_creator(input_buffer):
# Epilogue to convert from s32 to f32 for u8s8f32
assert output_dtype in [
torch.float32,
torch.bfloat16,
torch.uint8,
]
input_loader = input_buffer.make_loader()
x2_loader = x2.make_loader()
weight_compens_loader = weight_compens.make_loader()
x_scale_loader = x_scale.make_loader()
w_scale_loader = w_scale.make_loader()
x_zp_loader = x_zp.make_loader()
nonlocal bias
bias_loader = None
if bias is not None:
bias_loader = bias.make_loader()
def inner_fn(index):
nonlocal bias
input = input_loader(index)
_x2 = x2_loader(index)
_x_scale = x_scale_loader(())
_x_zp = x_zp_loader(())
                            # The micro-kernel output is int32; convert to FP32
                            # before applying compensation.
input = ops.to_dtype(input, torch.float32)
weight_compens_index = (index[-1],)
_w_scale = w_scale_loader(weight_compens_index)
_weight_compens = weight_compens_loader(
weight_compens_index
)
                            # Step 1: Apply compensation to convert to fp32
temp = ops.mul(
ops.mul(
input,
_x_scale,
),
_w_scale,
)
temp = ops.sub(
temp,
ops.mul(
ops.mul(
ops.mul(
_x_scale,
_w_scale,
),
_x_zp,
),
_weight_compens,
),
)
# Step 2: add Bias if applicable
if bias is not None:
_bias = bias_loader(weight_compens_index)
nonlocal bias_dtype
assert bias_dtype in [torch.float32, torch.bfloat16]
if bias_dtype == torch.bfloat16:
_bias = ops.to_dtype(_bias, torch.float32)
temp = ops.add(temp, _bias)
# Step 3: Binary add
nonlocal x2_dtype
assert x2_dtype in [torch.float32, torch.bfloat16]
if x2_dtype == torch.bfloat16:
_x2 = ops.to_dtype(_x2, torch.float32)
temp = ops.add(temp, _x2)
return temp
output_buf = ir.Pointwise(
device=input_buffer.get_device(),
dtype=torch.float32, # Hardcode to FP32 for u8s8f32
inner_fn=inner_fn,
ranges=input_buffer.get_size(),
)
                        # Step 4: Apply the unary post-op if present
if unary_attr != "none":
output_buf = create_epilogue_with_attr(
output_buf,
unary_attr,
scalars=unary_scalars,
                                algorithm=unary_algorithm,
)
# Step 5: Cast output to Target Dtype
if output_dtype == torch.bfloat16:
output_cast_loader = output_buf.make_loader()
def inner_fn_cast_output_to_bf16(index):
input = output_cast_loader(index)
return ops.to_dtype(input, output_dtype)
output_buf = ir.Pointwise(
device=output_buf.get_device(),
dtype=output_dtype,
inner_fn=inner_fn_cast_output_to_bf16,
ranges=output_buf.get_size(),
)
elif output_dtype == torch.uint8:
from .lowering import _create_constants
requant_input_loader = output_buf.make_loader()
def inner_fn_requant(index, scale, zero_point):
input = requant_input_loader(index)
inv_scale, zero_point = _create_constants(
1.0 / scale, zero_point, dtype=torch.float32
)
val = ops.round(input * inv_scale) + zero_point
qmin, qmax = _create_constants(
0, 255, dtype=torch.float32
)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, torch.uint8)
output_buf = ir.Pointwise(
device=output_buf.get_device(),
dtype=torch.uint8,
inner_fn=functools.partial(
inner_fn_requant,
scale=float(o_scale),
zero_point=int(o_zero_point),
),
ranges=output_buf.get_size(),
)
return output_buf
CppPackedGemmTemplate.add_choices(
choices,
layout,
[x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
if bias is None
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
has_bias=bias is not None,
epilogue_creator=epilogue_creator,
# Reorder bias and x2
input_indices=[0, 3, 1, 2, 4, 5, 6]
if bias is None
else [7, 0, 3, 1, 2, 4, 5, 6],
)
if len(choices) == 0 or use_aten_gemm_kernels():
kwargs = dict(
output_scale=o_scale,
output_zero_point=o_zero_point,
output_dtype=output_dtype,
other_scale=x2_scale,
other_zp=x2_zp,
binary_post_op=binary_attr,
binary_alpha=alpha,
unary_post_op=unary_attr,
unary_post_op_args=unary_scalars,
                    unary_post_op_algorithm=unary_algorithm,
)
if bias is None:
kwargs["bias"] = None
choices.append(
aten_mkldnn_qlinear_binary.bind(
(x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
if bias is None
else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
layout,
**kwargs,
)
)
assert packed_weight.get_name() in V.graph.constants
input_gen_fns = {
3: lambda x: V.graph.constants[x.get_name()],
4: lambda x: V.graph.constants[x.get_name()],
5: lambda x: V.graph.constants[x.get_name()],
}
if bias is not None:
input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias
result = autotune_select_algorithm(
"qlinear_binary",
choices,
[x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
if bias is None
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
layout,
input_gen_fns=input_gen_fns,
)
if len(x_size) > 2 and binary_attr == "add":
result = view(result, (*x_size[:-1], result.get_size()[-1]))
return result
if torch._C.has_mkl:
aten_mkl_linear = ExternKernelChoice(
torch.ops.mkl._mkl_linear,
"mkl::_mkl_linear",
has_out_variant=False,
kernel_creator=mkldnn_ir.MKLPackedLinear.create,
)
cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)
@register_lowering(torch.ops.mkl._mkl_linear)
def mkl_packed_linear(
x: TensorBox,
packed_w: TensorBox,
orig_w: TensorBox,
b: Optional[TensorBox],
batch_size,
*,
layout=None,
):
choices: List[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(orig_w, [1, 0])
*_, layout, x, transposed_w = mm_args(
x, transposed_w, layout=layout
)
if use_cpp_packed_gemm_template(layout, x, transposed_w):
CppPackedGemmTemplate.add_choices(
choices,
layout,
[x, packed_w, orig_w],
trans_w=True,
input_indices=[0, 2],
)
if len(choices) == 0 or use_aten_gemm_kernels():
choices.append(
aten_mkl_linear.bind(
(x, packed_w, orig_w), layout, B=None, batch_size=batch_size
)
)
assert packed_w.get_name() in V.graph.constants
assert orig_w.get_name() in V.graph.constants
            # packed_w is an mkldnn tensor which we can't generate directly,
            # so we use the weights from the original tensor for autotuning.
input_gen_fns = {
1: lambda x: V.graph.constants[x.get_name()],
2: lambda x: V.graph.constants[x.get_name()],
}
result: TensorBox = autotune_select_algorithm(
"packed_linear",
choices,
[x, packed_w, orig_w],
layout,
input_gen_fns=input_gen_fns,
)
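            # None of the choices above takes the bias (B=None for the extern kernel),
            # so apply it as a separate add after autotuning.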
if b is not None:
result = add(result, b)
return result
add_needs_realized_inputs(cpu_needs_realized_inputs)
else:
pass