# mypy: allow-untyped-defs
from collections.abc import Sequence
from typing import Any, Optional, Union

import sympy

import torch
from torch._prims_common import make_channels_last_strides_for, StrideType
from torch.utils._ordered_set import OrderedSet

from .ir import (
    ExternKernelAlloc,
    FixedLayout,
    FlexibleLayout,
    get_device_type,
    ir_node_to_tensor,
    IRNode,
    is_contiguous_storage_and_layout,
    Layout,
    may_convert_to_optional,
    MultiOutput,
    MultiOutputLayout,
    MutationOutput,
    NoneLayout,
    ShapeAsConstantBuffer,
    TensorBox,
)
from .utils import convert_shape_to_inductor, pad_listlike, SUPPORTED_MKLDNN_DEVICES
from .virtualized import V


def _prepare_convolution_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
    padding: Sequence[int],
    stride: Sequence[int],
    dilation: Sequence[int],
    groups: int,
    transposed: bool = False,
    output_padding: Optional[Sequence[int]] = None,
    quantize_args: Optional[list["TensorBox"]] = None,
    other: Optional["TensorBox"] = None,
):
    """
    Helper that prepares the inputs, layout and constant args for a convolution
    post-op fusion's create function: it decides the output layout (channels
    first or channels last) and realizes the inputs in the required stride
    order. Only CPU and XPU devices are supported, since the conv post-op
    fusion kernels are currently implemented for CPU and XPU only.
    """

    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
    def _conv_input_size(
        output_size, weight_size, padding, output_padding, stride, dilation, groups
    ):
        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
        dim = len(output_size)
        assert dim > 2, "Expect input dim > 2"

        BATCH_DIM = 0
        WEIGHT_INPUT_CHANNELS_DIM = 1
        input_size = []
        input_size.append(output_size[BATCH_DIM])
        input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
        for d in range(2, dim):
            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
            input_size_d = (
                (output_size[d] - 1) * stride[d - 2]
                - (padding[d - 2] * 2)
                + kernel
                + output_padding[d - 2]
            )
            input_size.append(input_size_d)
        return list(map(int, input_size))
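
    # A worked example (illustrative values, not from the original file): a
    # stride-2 deconv with a 3x3 kernel, padding 1 and output_padding 1 maps
    # a (1, 16, 16, 16) output shape back to a (1, 3, 32, 32) input shape:
    #   >>> _conv_input_size([1, 16, 16, 16], [16, 3, 3, 3], [1, 1], [1, 1],
    #   ...                  [2, 2], [1, 1], 1)
    #   [1, 3, 32, 32]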

    # Port from aten/src/ATen/native/ConvUtils.h: _conv_output_size
    def _conv_output_size(input_size, weight_size, padding, stride, dilation=None):
        has_dilation = dilation is not None
        dim = len(input_size)
        output_size = []
        output_size.append(input_size[0])
        output_size.append(weight_size[0])
        for d in range(2, dim):
            dilation_ = dilation[d - 2] if has_dilation else 1
            kernel = dilation_ * (weight_size[d] - 1) + 1
            output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[
                d - 2
            ] + 1
            output_size.append(output_size_d)
        return output_size
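
    # A worked example (illustrative values): a 3x3 conv with padding 1 and
    # stride 2 over a (1, 3, 32, 32) input and a (16, 3, 3, 3) weight:
    #   >>> _conv_output_size([1, 3, 32, 32], [16, 3, 3, 3], [1, 1], [2, 2])
    #   [1, 16, 16, 16]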

    # The size of prepacked_weight is the prepacked weight size of deconv:
    #   Groups > 1:  [g*o, i/g, ...]
    #   Groups == 1: [o, i, ...]
    # Returns the original weight size in [i, o, ...]
    def _original_deconv_weight_size(
        prepacked_weight,
        groups,
    ):
        prepacked_weight_size = prepacked_weight.size()
        dim = len(prepacked_weight_size)
        assert dim > 2, "Expect weight dim > 2"
        if groups > 1:
            weight_size = []
            weight_size.append(prepacked_weight_size[1] * groups)
            # Channel counts are integral, so use floor division to keep the
            # size an int.
            weight_size.append(prepacked_weight_size[0] // groups)
            weight_size.extend(prepacked_weight_size[d] for d in range(2, dim))
        else:
            weight_size = prepacked_weight.transpose(0, 1).size()
        return weight_size
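
    # A worked example (illustrative values): with groups=2, a prepacked
    # deconv weight of size [16, 4, 3, 3] (i.e. [g*o, i/g, ...] with o=8, i=8)
    # recovers the original [i, o, ...] size [8, 8, 3, 3]:
    #   weight_size[0] = 4 * 2  = 8  (input channels)
    #   weight_size[1] = 16 // 2 = 8 (output channels)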

    x.realize()
    weight.realize()
    if bias is not None:
        bias.realize()
    with V.graph.fake_mode:
        # TODO <Leslie>: clean up the fake_tensor trace as in the Linear implementation
        x_fake = ir_node_to_tensor(x, guard_shape=True)
        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
        dims = len(x_fake.size()) - 2
        assert 0 < len(padding) <= dims
        assert 0 < len(dilation) <= dims
        assert 0 < len(stride) <= dims
        padding = pad_listlike(padding, dims)
        dilation = pad_listlike(dilation, dims)
        stride = pad_listlike(stride, dims)
        if output_padding is None:
            output_padding = pad_listlike([0], dims)
        else:
            assert 0 < len(output_padding) <= dims
            output_padding = pad_listlike(output_padding, dims)
        assert isinstance(groups, (int, sympy.core.numbers.Integer))
        if transposed:
            # When transposed, the size of the prepacked oneDNN weight differs
            # from the PyTorch weight, so we cannot run aten conv with such a
            # size. We infer the output size from the input params here:
            weight_size = _original_deconv_weight_size(weight_fake, groups)
            input_size = x_fake.size()
            output_size = _conv_input_size(
                input_size,
                weight_size,
                padding,
                output_padding,
                stride,
                dilation,
                groups,
            )
        else:
            x_shape = list(x_fake.shape)
            weight_shape = list(weight_fake.shape)
            if len(x_shape) != len(weight_shape):
                assert len(x_shape) == 3 and len(weight_shape) == 4
                weight_shape.pop(2)
            output_size = _conv_output_size(
                x_shape,
                weight_shape,
                padding,
                stride,
                dilation,
            )

    req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
    req_stride_order = [len(req_stride_order)] + req_stride_order
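    # For a 4-D conv (len(stride) == 2) the two lines above build
    # req_stride_order = [0, 2, 1] and then [3, 0, 2, 1], i.e. the
    # channels-last (NHWC) stride order: dim 1 (channels) gets the smallest
    # stride and dim 0 (batch) the largest.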

    x = cls.require_stride_order(x, req_stride_order)

    # We don't do weight prepack for Conv when shapes are dynamic or the device is XPU.
    # In static shape cases, since the weight is prepacked, we always force the output to be channels last in the Conv kernel.
    # In dynamic shape cases, for an input with channels = 1, like a tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1),
    # x = cls.require_stride_order(x, req_stride_order) with req_stride_order in the channels-last order
    # won't change the stride of this tensor, since strides of size-1 dimensions are ignored. The Conv kernel then
    # treats this tensor as channels first and produces output in contiguous format.
    # To align with the Conv kernel's behavior, we set the output_stride in such cases to contiguous instead of channels last.
    dynamic_shapes = not all(isinstance(i, int) for i in output_size)
    if (
        dynamic_shapes or get_device_type(x) == "xpu"
    ) and is_contiguous_storage_and_layout(x):
        output_stride: StrideType = FlexibleLayout.contiguous_strides(output_size)
    # Currently we don't support channels last when the stride of the input's batch dim is 0,
    # e.g. input_size = (1, 1280, 64, 64) but input_stride = (0, 1, 81920, 1280).
    # So we use NCHW here instead.
    # Unlike CPU, where conv always uses channels_last when the weight is prepacked,
    # XPU does not do the prepack, so the problem exposed here is XPU-only.
    # TODO: support channels_last for such zero-stride input.
    elif get_device_type(x) == "xpu" and x.get_stride()[0] == 0:
        output_stride = FlexibleLayout.contiguous_strides(output_size)
    else:
        output_stride = make_channels_last_strides_for(output_size)

    assert get_device_type(x) == get_device_type(weight)
    assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES
    inputs = [x]

    if quantize_args is not None:
        x_scale, x_zero_point, w_scale, w_zero_point = quantize_args
        x_scale.realize()
        x_zero_point.realize()
        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point]
    else:
        inputs += [weight]

    if other is not None:
        other = cls.require_stride_order(other, req_stride_order)
        assert isinstance(other, TensorBox)
        inputs += [other]

    kernel_layout = FixedLayout(
        x.get_device_or_error(),
        x.get_dtype(),
        convert_shape_to_inductor(output_size),
        convert_shape_to_inductor(output_stride),
    )
    constant_args = [padding, stride, dilation, groups]
    if transposed:
        constant_args.insert(1, output_padding)

    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout, req_stride_order, other


def _prepare_linear_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
    quantize_args: Optional[list["TensorBox"]] = None,
    other: Optional["TensorBox"] = None,
    binary_sum: bool = False,
):
    """
    Helper that prepares the inputs, layout and constant args for a linear
    post-op fusion's create function. Only the CPU device is supported, since
    the linear post-op fusion kernel is currently implemented for CPU only.
    """
    x.realize()
    weight.realize()
    if bias is not None:
        bias.realize()

    *m, _ = x.get_size()
    # The weight has been transposed during the qlinear weight prepack process.
    # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
    # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
    _, oc = weight.get_size()
    output_size = list(m) + [oc]
    req_stride_order = list(reversed(range(len(x.get_size()))))
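    # For the linear case the required order is simply the contiguous one,
    # e.g. [1, 0] for a 2-D input: strides decrease from the first to the
    # last dimension.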

    x = cls.require_stride_order(x, req_stride_order)
    assert get_device_type(x) == get_device_type(weight)
    assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES
    inputs = [x]

    if quantize_args is not None:
        x_scale, x_zero_point, w_scale, w_zero_point = quantize_args
        x_scale.realize()
        x_zero_point.realize()
        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point]
    else:
        inputs += [weight]

    if other is not None:
        if binary_sum:
            other = cls.require_stride_order(other, req_stride_order)
        inputs = inputs + [other]

    output_stride = FlexibleLayout.contiguous_strides(output_size)
    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        output_size,
        output_stride,
    )
    constant_args: list[Any] = []

    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout, req_stride_order, other


def _create_output_node(packed):
    output_ir = MultiOutput(
        packed.get_layout(),
        packed,
        [],
    )
    packed.layout = MultiOutputLayout(device=packed.get_device())
    packed.outputs = [output_ir]
    return output_ir
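

# The ExternKernelAlloc subclasses below all follow the same pattern sketched
# above: _prepare_*_fusion_create builds inputs/constant_args/layout, the
# class wires up the op_overload and its AOTI C-shim kernel name, and
# _create_output_node (or a NoneLayout plus mutation marker for the in-place
# variants) produces the IR node handed back to the lowering.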


class ConvolutionUnary(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        self.device_type = get_device_type(inputs[0])
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.default,
            cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: list[int],
        stride_: list[int],
        dilation_: list[int],
        groups: int,
        attr,
        scalars: Optional[list[Any]],
        algorithm,
    ):
        (
            inputs,
            constant_args,
            kernel_layout,
            _,
            _,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        constant_args = constant_args + [
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]
        packed = ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
        return _create_output_node(packed)
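

# A minimal sketch of a call site (hypothetical tensors and post-op; the real
# invocation lives in the mkldnn lowering pass):
#   ConvolutionUnary.create(
#       x, weight, bias,
#       padding_=[1, 1], stride_=[1, 1], dilation_=[1, 1], groups=1,
#       attr="relu", scalars=None, algorithm="",
#   )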


class ConvolutionBinary(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        cpp_constant_args=(),
    ) -> None:
        self.device_type = get_device_type(inputs[0])
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.binary,
            cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary",
        )
        self.cpp_constant_args = cpp_constant_args

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: list[int],
        stride_: list[int],
        dilation_: list[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[list[Any]],
        unary_algorithm: Optional[str],
    ):
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
            _,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        packed = ConvolutionBinary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
        return _create_output_node(packed)


class ConvolutionBinaryInplace(ExternKernelAlloc):
    def __init__(
        self,
        kernel_layout,
        inputs,
        constant_args=(),
    ) -> None:
        # Due to a constraint of op.call, other (Tensor&) should be at inputs[0]
        self.device_type = get_device_type(inputs[0])
        reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]

        super().__init__(
            kernel_layout,
            reordered_inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise_.binary,
            cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary_",
        )

        self.mutation_outputs = [
            MutationOutput(NoneLayout(device=inputs[0].get_device()), inputs[0], self),
            MutationOutput(NoneLayout(device=inputs[1].get_device()), inputs[1], self),
        ]

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        return OrderedSet()

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: list[int],
        stride_: list[int],
        dilation_: list[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[list[Any]],
        unary_algorithm: Optional[str],
    ):
        (
            inputs,
            constant_args,
            _,
            req_stride_order,
            _,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        packed = ConvolutionBinaryInplace(
            kernel_layout=NoneLayout(device=inputs[1].get_device()),  # type: ignore[arg-type]
            inputs=inputs,
            constant_args=constant_args,
        )
        # This op mutates in place, which means that the result is not the
        # target but rather the input that is being mutated.
        # __init__ reorders the inputs, so inputs[1] becomes packed.inputs[0].
        return packed.inputs[0]


class ConvolutionTransposeUnary(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        self.device_type = get_device_type(inputs[0])
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default,
            cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_transpose_pointwise",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: list[int],
        output_padding_: list[int],
        stride_: list[int],
        dilation_: list[int],
        groups_: int,
        attr,
        scalars: Optional[list[Any]],
        algorithm,
    ):
        transposed = True
        (
            inputs,
            constant_args,
            kernel_layout,
            _,
            _,
        ) = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,
            bias,
            padding_,
            stride_,
            dilation_,
            groups_,
            transposed,
            output_padding_,
        )
        constant_args = constant_args + [
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]
        packed = ConvolutionTransposeUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
        return _create_output_node(packed)


class QConvPointWisePT2E(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp]
            - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp,
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp]
            - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp,
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
        """
        self.device_type = get_device_type(inputs[0])
        self.has_bias = len(inputs) == 5
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv_pointwise.default,
            cpp_kernel_name=f"aoti_torch_{self.device_type}__qconv_pointwise_tensor",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: Union["ShapeAsConstantBuffer", "TensorBox"],
        x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"],
        qw: "TensorBox",  # qw
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        bias: "TensorBox",
        stride: list[int],
        padding: list[int],
        dilation: list[int],
        groups: int,
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        attr,
        scalars,
        algorithm,
    ):
        transposed = False
        output_padding = None
        (
            inputs,
            constant_args,
            kernel_layout,
            _,
            _,
        ) = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            transposed,
            output_padding,
            [x_scale, x_zero_point, w_scale, w_zero_point],  # type: ignore[list-item]
        )
        # swap padding and stride to align with functional conv arg order
        if bias is None:
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
        else:
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
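        # Concretely: _prepare_convolution_fusion_create returns
        # constant_args = [padding, stride, dilation, groups] (with bias
        # prepended when it is None), so the swap above turns
        # [bias?, padding, stride, ...] into [bias?, stride, padding, ...],
        # matching the qconv op's (stride, padding, dilation, groups) order.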

        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]

        assert output_dtype is not None
        if output_dtype in [torch.float32, torch.bfloat16]:
            # In _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout.
            # When output_dtype is not None, the output buf should be output_dtype instead of uint8.
            kernel_layout.dtype = output_dtype

        return QConvPointWisePT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )


class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        Needs input/weight/output qparams
        if bias is not None
            - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum, b]
            - const_args = [stride, padding, dilation, groups, o_scale, o_zp,
              output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum]
            - const_args = [b, stride, padding, dilation, groups, o_scale, o_zp,
              output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
        """
        self.device_type = get_device_type(inputs[0])
        self.has_bias = len(inputs) == 8
        self.idx_for_inplace_sum = 6
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv2d_pointwise.binary,
            cpp_kernel_name=(
                f"aoti_torch_{self.device_type}__qconv2d_pointwise_binary_tensor"
            ),
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def get_mutation_names(self) -> Sequence[str]:
        return [self.input_name(self.idx_for_inplace_sum)]

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        return OrderedSet()

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: "TensorBox",
        x_zero_point: "TensorBox",
        qw: "TensorBox",  # packed_weight
        w_scale,
        w_zero_point,
        qaccum: "TensorBox",
        bias: "TensorBox",
        stride: list[int],
        padding: list[int],
        dilation: list[int],
        groups: int,
        output_scale: "TensorBox",
        output_zero_point: "TensorBox",
        output_dtype,
        accum_scale,
        accum_zero_point,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        transposed = False
        output_padding = None
        (
            inputs,
            constant_args,
            _kernel_layout,
            req_stride_order,
            qaccum,
        ) = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            transposed,
            output_padding,
            [x_scale, x_zero_point, w_scale, w_zero_point],
            qaccum,
        )

        # swap padding and stride to align with functional conv arg order
        if bias is None:
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
        else:
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]

        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            accum_scale,
            accum_zero_point,
            binary_attr,
            alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]

        assert binary_attr == "sum", (
            "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
        )

        V.graph.mark_buffer_mutated(qaccum.get_name())
        packed = QConvPointWiseBinaryPT2E(
            layout=NoneLayout(device=qaccum.get_device()),
            inputs=inputs,
            constant_args=constant_args,
        )

        # Return accum since it has been changed in place.
        return packed.inputs[packed.idx_for_inplace_sum]


class MKLPackedLinear(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkl._mkl_linear.default,
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
        super().codegen(wrapper)

    @classmethod
    def create(cls, x, packed_w, orig_w, B, batch_size):
        x = cls.require_stride1(cls.realize_input(x))
        orig_w = cls.require_stride1(cls.realize_input(orig_w))
        *m, _ = x.get_size()
        oc, _ = orig_w.get_size()
        output_size = list(m) + [oc]
        output_stride = FlexibleLayout.contiguous_strides(output_size)
        inputs = [x, packed_w, orig_w]
        constant_args = [batch_size]
        if B is not None:
            inputs += [B]
        else:
            constant_args.insert(0, None)

        device = x.get_device()
        assert device is not None
        return MKLPackedLinear(
            layout=FixedLayout(device, x.get_dtype(), output_size, output_stride),
            inputs=inputs,
            constant_args=constant_args,
        )


class LinearUnary(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        self.device_type = get_device_type(inputs[0])
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._linear_pointwise.default,
            cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    @classmethod
    def create(cls, x, w, B, attr, scalars, algorithm):
        x = cls.require_contiguous(cls.realize_input(x))
        w = cls.require_contiguous(cls.realize_input(w))

        *m, _ic = x.get_size()
        oc, _ic = w.get_size()
        output_size = list(m) + [oc]
        inputs = [x, w]
        constant_args = [attr, scalars if scalars else [-1], algorithm]
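        # Note: the op expects a non-empty scalars list, hence the [-1]
        # placeholder above when no scalars are given.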
        if B is not None:
            B = cls.require_contiguous(cls.realize_input(B))
            inputs.append(B)
        else:
            constant_args.insert(0, None)

        device = x.get_device()
        assert device is not None

        packed = LinearUnary(
            layout=FixedLayout(
                device=device,
                dtype=x.get_dtype(),
                size=output_size,
            ),
            inputs=inputs,
            constant_args=constant_args,
        )
        return _create_output_node(packed)

    def apply_constraint(self):
        pass


class LinearBinary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._linear_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        self.device_type = get_device_type(inputs[0])
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._linear_pointwise.binary,
            cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise_binary",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

    @classmethod
    def create(cls, x, y, w, B, attr):
        x = cls.require_contiguous(cls.realize_input(x))
        y = cls.require_contiguous(cls.realize_input(y))
        w = cls.require_contiguous(cls.realize_input(w))

        *m, _ic = x.get_size()
        oc, _ic = w.get_size()
        output_size = list(m) + [oc]
        inputs = [x, y, w]
        constant_args = [attr]
        if B is not None:
            B = cls.require_contiguous(cls.realize_input(B))
            inputs.append(B)
        else:
            constant_args.insert(0, B)

        device = x.get_device()
        assert device is not None
        packed = LinearBinary(
            layout=FixedLayout(
                device=device,
                dtype=x.get_dtype(),
                size=output_size,
            ),
            inputs=inputs,
            constant_args=constant_args,
        )
        return _create_output_node(packed)

    def apply_constraint(self):
        pass


class QLinearPointwisePT2E(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        has_bias=True,
    ) -> None:
        """
        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, o_scale, o_zp,
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp]
            - const_args is: [bias, x_scale, x_zp, o_scale, o_zp,
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
        """
        self.device_type = get_device_type(inputs[0])
        self.has_bias = has_bias
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qlinear_pointwise.tensor,
            cpp_kernel_name=f"aoti_torch_{self.device_type}__qlinear_pointwise_tensor",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)

        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: "TensorBox",
        x_zero_point: "TensorBox",
        qw: "TensorBox",  # packed_weight
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        bias: "TensorBox",
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        post_op_name,
        post_op_args,
        post_op_algorithm,
    ):
        (inputs, constant_args, kernel_layout, _, _) = _prepare_linear_fusion_create(
            cls,
            qx,
            qw,
            bias,
            [x_scale, x_zero_point, w_scale, w_zero_point],
        )

        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            post_op_name,
            may_convert_to_optional(post_op_args),
            post_op_algorithm,
        ]

        assert output_dtype is not None
        if output_dtype in [torch.float32, torch.bfloat16]:
            # In _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout.
            # When fp32_output is set, the output buf should be dtype float32 instead of uint8.
            kernel_layout.dtype = output_dtype

        return QLinearPointwisePT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            has_bias=(bias is not None),
        )


class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        has_bias=True,
    ) -> None:
        """
        if bias is not None
            - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2, bias]
            - const_args is: [o_scale, o_zp,
              fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2]
            - const_args is: [bias, o_scale, o_zp,
              fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
        """
        self.device_type = get_device_type(inputs[0])
        self.has_bias = has_bias
        self.idx_for_inplace_sum = 6
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor,
            cpp_kernel_name=f"aoti_torch_{self.device_type}__qlinear_pointwise_binary_tensor",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header(
            f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h"
        )
        super().codegen(wrapper)
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def get_mutation_names(self) -> Sequence[str]:
        binary_post_op = self.constant_args[-5]
        if binary_post_op == "sum":
            input = self.inputs[self.idx_for_inplace_sum]
            assert isinstance(input, IRNode)
            return [input.get_name()]
        else:
            return []

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: "TensorBox",
        x_zero_point: "TensorBox",
        qw: "TensorBox",  # packed_weight
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        other: "TensorBox",
        bias: "TensorBox",
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        other_scale,
        other_zp,
        binary_post_op,
        binary_alpha,
        unary_post_op,
        unary_post_op_args,
        unary_post_op_algorithm,
    ):
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
            other,
        ) = _prepare_linear_fusion_create(
            cls,
            qx,
            qw,
            bias,
            [x_scale, x_zero_point, w_scale, w_zero_point],
            other,
            binary_post_op == "sum",
        )

        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            other_scale,
            other_zp,
            binary_post_op,
            binary_alpha,
            unary_post_op,
            may_convert_to_optional(unary_post_op_args),
            unary_post_op_algorithm,
        ]

        if binary_post_op == "sum":
            V.graph.mark_buffer_mutated(other.get_name())
            packed = QLinearPointwiseBinaryPT2E(
                layout=NoneLayout(device=other.get_device()),
                inputs=inputs,
                constant_args=constant_args,
                has_bias=(bias is not None),
            )
            # Return other since it has been changed in place.
            return packed.inputs[packed.idx_for_inplace_sum]

        assert output_dtype is not None
        if output_dtype in [torch.float32, torch.bfloat16]:
            # In _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout.
            # When fp32_output is set, the output buf should be dtype float32 instead of uint8.
            kernel_layout.dtype = output_dtype

        return QLinearPointwiseBinaryPT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            has_bias=(bias is not None),
        )


class MkldnnRnnLayer(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.aten.mkldnn_rnn_layer.default,
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        w0: "TensorBox",
        w1: "TensorBox",
        w2: "TensorBox",
        w3: "TensorBox",
        hx: "TensorBox",
        cx: "TensorBox",
        reverse: bool,
        batch_sizes: list[int],
        mode: int,
        hidden_size: int,
        num_layers: int,
        has_biases: bool,
        bidirectional: bool,
        batch_first: bool,
        train: bool,
    ):
        x = cls.require_stride1(cls.realize_input(x))
        # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
        # Make sure x is contiguous in the batch_first case.
        x.freeze_layout()
        w0 = cls.require_stride1(cls.realize_input(w0))
        w1 = cls.require_stride1(cls.realize_input(w1))
        w2 = cls.require_stride1(cls.realize_input(w2))
        w3 = cls.require_stride1(cls.realize_input(w3))
        hx = cls.require_stride1(cls.realize_input(hx))
        hx.freeze_layout()
        cx = cls.require_stride1(cls.realize_input(cx))
        cx.freeze_layout()

        input_size = x.get_size()
        assert len(input_size) == 3, "Expect lstm input to be 3D"
        # batch_first is handled in the lstm OP. When entering
        # rnn_layer here, we'll always have batch_first = False.
        seq_length, mini_batch, input_size = input_size
        output_shape = [seq_length, mini_batch, hidden_size]

        hy_shape = hx.get_size()
        cy_shape = cx.get_size()

        inputs = [x, w0, w1, w2, w3, hx, cx]
        constant_args = [
            reverse,
            batch_sizes,
            mode,
            hidden_size,
            num_layers,
            has_biases,
            bidirectional,
            batch_first,
            train,
        ]

        device = x.get_device()
        assert device is not None
        packed = MkldnnRnnLayer(
            MultiOutputLayout(device=device),
            inputs=inputs,
            constant_args=constant_args,
        )

        def get_strides_of_lstm_output(output_shape, batch_first):
            assert len(output_shape) == 3, "Expect output_shape to be 3D"
            return FlexibleLayout.contiguous_strides(output_shape)

        # The C shim call requires all the outputs to be passed in, and thus
        # the last dummy return value is added.
        output_sizes = [output_shape, hy_shape, cy_shape, [1]]
        output_strides = [
            get_strides_of_lstm_output(output_shape, batch_first),
            FlexibleLayout.contiguous_strides(hy_shape),
            FlexibleLayout.contiguous_strides(cy_shape),
            [1],
        ]
        output_ir = [
            MultiOutput(
                FixedLayout(
                    x.get_device(),  # type: ignore[arg-type]
                    x.get_dtype(),
                    output_size,
                    output_stride,
                ),
                packed,
                [(tuple, i)],
            )
            for i, (output_size, output_stride) in enumerate(
                zip(output_sizes, output_strides)
            )
        ]
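        # Each MultiOutput above selects result i of the packed node via the
        # (tuple, i) indexing spec, so output_ir is [output, hy, cy, dummy],
        # matching output_sizes/output_strides element-wise.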
        packed.outputs = output_ir

        return output_ir

    def codegen(self, wrapper):
        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
        return super().codegen(wrapper)


# Add this IR so that we can include shim_cpu.h for cpp_wrapper
class WeightInt4PackMatmul(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        inputs = [x, w, qGroupSize, qScalesAndZeros]
        constant_args = ()
        """
        assert len(inputs) == 4
        assert len(constant_args) == 0
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.quantized.int4mm_packed_weight_cpu.default,
            cpp_kernel_name="aoti_torch_cpu__weight_int4pack_mm_cpu_tensor",
        )

    def codegen(self, wrapper):
        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
        super().codegen(wrapper)

        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        w: "TensorBox",
        qGroupSize: "TensorBox",
        qScalesAndZeros: "TensorBox",
    ):
        inputs = [x, w, qGroupSize, qScalesAndZeros]
        *m, _ = x.get_size()
        n, _ = w.get_size()
        output_size = list(m) + [n]
        output_stride = FlexibleLayout.contiguous_strides(output_size)
        kernel_layout = FixedLayout(
            x.get_device(),  # type: ignore[arg-type]
            x.get_dtype(),
            output_size,
            output_stride,
        )
        return WeightInt4PackMatmul(
            layout=kernel_layout,
            inputs=inputs,
        )
|