# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
import functools
import itertools
import math
import operator
from typing import Any

import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg

from .. import config
from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import (
    Arg,
    CallFunction,
    filter_nodes,
    KeywordArg,
    ListOf,
    Match,
    stable_topological_sort,
)
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern


aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
quantized = torch.ops.quantized

# Only for per-tensor quant since permute may change the channel idx
_PER_TENSOR_QUANTIZE_OPS = [
    quantized_decomposed.quantize_per_tensor.default,
    quantized_decomposed.quantize_per_tensor.tensor,
]

_VIEW_OPS = [
    aten.transpose.int,
    aten.permute.default,
    aten.view.default,
]

"""
|
|
The quantization.py file primarily incorporates passes related to quantization fusion
|
|
in inductor, includes:
|
|
1. Dequant Promotion;
|
|
2. Conv/GEMM weight prepack with oneDNN Library;
|
|
3. Conv/GEMM quantization fusion with output quant node (if have);
|
|
4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
|
|
|
|
It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
|
|
of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
|
|
1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
|
|
2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
|
|
Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
|
|
quantization.
|
|
"""
|
|
|
|
|
|
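# An illustrative sketch (comments only, not executed) of the int8-mixed-bf16 linear
# subgraph that the passes below look for; the node names here are descriptive and
# are not identifiers used elsewhere in this file:
#
#   x_int8 -> dequantize_per_tensor -> to(bfloat16) --\
#                                                      mm/addmm -> to(float32) -> quantize_per_tensor
#   w_int8 -> dequantize_per_channel -> to(bfloat16) -> permute --/
#
# Weight prepack replaces the weight branch with onednn.qlinear_prepack, and the
# fusion/lowering passes replace the whole subgraph with onednn.qlinear_pointwise
# (convolution patterns are handled analogously with onednn.qconv_pointwise).
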
def _get_pattern_output_dtype(match: Match):
    """
    Get the pattern's output dtype from the output node's meta.
    Assume there is only 1 output node in the matched pattern.
    """
    pattern_output_nodes = match.output_nodes()
    assert len(pattern_output_nodes) == 1
    output_node = pattern_output_nodes[0]
    assert isinstance(output_node, torch.fx.Node)
    output_dtype = output_node.meta["val"].dtype
    assert output_dtype in [
        torch.int8,
        torch.uint8,
        torch.float32,
        torch.bfloat16,
        torch.float8_e4m3fn,
    ]
    return output_dtype

def _may_generate_pattern_with_dtype_convert(
    pattern, dtype=Arg(), with_dtype_convert=True, users=1
):
    if with_dtype_convert:
        return CallFunction(
            prims.convert_element_type.default,
            pattern,
            dtype,
            _users=users,
        )
    else:
        return pattern

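# Usage note (illustrative only): the helper above is how the bf16 variants of the
# patterns in this file are built, e.g.
#   _may_generate_pattern_with_dtype_convert(p, KeywordArg("autocast_wgt_dtype"), is_bf16)
# matches prims.convert_element_type(p, autocast_wgt_dtype) when is_bf16 is True and
# simply matches p otherwise.
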
def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
|
|
if with_reshape:
|
|
return CallFunction(
|
|
torch.ops.aten.reshape.default,
|
|
pattern,
|
|
reshape_size,
|
|
)
|
|
else:
|
|
return pattern
|
|
|
|
|
|
def _generate_linear_t_pattern(
|
|
_dequant_per_channel_pattern,
|
|
dtype,
|
|
):
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
t_pattern = CallFunction(
|
|
aten.permute.default,
|
|
_may_generate_pattern_with_dtype_convert(
|
|
_dequant_per_channel_pattern,
|
|
KeywordArg("autocast_wgt_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
KeywordArg("permute_axes"),
|
|
)
|
|
return t_pattern
|
|
|
|
|
|
def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
|
|
# only insert to_dtype if is_bf16 is True
|
|
computation_call = _may_generate_pattern_with_dtype_convert(
|
|
call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
|
|
)
|
|
return unary_fusion(computation_call)
|
|
|
|
|
|
def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
    dequantize_per_tensor_activation_pattern = CallFunction(
        quantized_decomposed.dequantize_per_tensor.tensor
        if is_tensor_overload
        else quantized_decomposed.dequantize_per_tensor.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("x_quant_min"),
        KeywordArg("x_quant_max"),
        KeywordArg("x_dq_dtype"),
    )
    return dequantize_per_tensor_activation_pattern

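# Note: the KeywordArg names used in the pattern above ("x", "x_scale", "x_zp", ...)
# become entries of match.kwargs when the pattern matches; that is how the lowering
# and prepack functions below retrieve them, e.g. kwargs["x_scale"] in qconv/qlinear.
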
dequantize_per_channel_weight_pattern = CallFunction(
|
|
quantized_decomposed.dequantize_per_channel.default,
|
|
KeywordArg("q_weight"),
|
|
KeywordArg("w_scale"),
|
|
KeywordArg("w_zp"),
|
|
KeywordArg("w_axis"),
|
|
KeywordArg("w_quant_min"),
|
|
KeywordArg("w_quant_max"),
|
|
KeywordArg("w_dtype"),
|
|
)
|
|
|
|
dequantize_per_channel_to_bf16_weight_pattern = (
|
|
_may_generate_pattern_with_dtype_convert(
|
|
dequantize_per_channel_weight_pattern,
|
|
KeywordArg("autocast_wgt_dtype"),
|
|
)
|
|
)
|
|
|
|
dequantize_per_channel_clone_weight_pattern = CallFunction(
|
|
aten.clone.default,
|
|
dequantize_per_channel_weight_pattern,
|
|
memory_format=KeywordArg("memory_format"),
|
|
)
|
|
|
|
dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
|
|
aten.clone.default,
|
|
dequantize_per_channel_to_bf16_weight_pattern,
|
|
memory_format=KeywordArg("memory_format"),
|
|
)
|
|
|
|
|
|
def get_qconv_pt2e_pattern(users=1):
|
|
return CallFunction(
|
|
torch.ops.onednn.qconv_pointwise.default,
|
|
KeywordArg("x"),
|
|
KeywordArg("x_scale"),
|
|
KeywordArg("x_zp"),
|
|
KeywordArg("packed_weight"),
|
|
KeywordArg("w_scale"),
|
|
KeywordArg("w_zp"),
|
|
KeywordArg("b"),
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
KeywordArg("dilation"),
|
|
KeywordArg("groups"),
|
|
KeywordArg("output_scale"),
|
|
KeywordArg("output_zero_point"),
|
|
KeywordArg("output_dtype"),
|
|
KeywordArg("postop_name"),
|
|
KeywordArg("postop_args"),
|
|
KeywordArg("postop_algorithm"),
|
|
_users=users,
|
|
)
|
|
|
|
|
|
def get_qconv2d_binary_pt2e_pattern(users=1):
|
|
return CallFunction(
|
|
torch.ops.onednn.qconv2d_pointwise.binary,
|
|
KeywordArg("x"),
|
|
KeywordArg("x_scale"),
|
|
KeywordArg("x_zp"),
|
|
KeywordArg("packed_weight"),
|
|
KeywordArg("w_scale"),
|
|
KeywordArg("w_zp"),
|
|
KeywordArg("accum"),
|
|
KeywordArg("b"),
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
KeywordArg("dilation"),
|
|
KeywordArg("groups"),
|
|
KeywordArg("output_scale"),
|
|
KeywordArg("output_zero_point"),
|
|
KeywordArg("output_dtype"),
|
|
KeywordArg("accum_scale"),
|
|
KeywordArg("accum_zero_point"),
|
|
KeywordArg("binary_op_name"),
|
|
KeywordArg("alpha"),
|
|
KeywordArg("unary_op_name"),
|
|
KeywordArg("unary_op_args"),
|
|
KeywordArg("unary_op_algorithm"),
|
|
_users=users,
|
|
)
|
|
|
|
|
|
def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
|
|
qlinear_op = (
|
|
torch.ops.onednn.qlinear_pointwise.tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.default
|
|
)
|
|
return CallFunction(
|
|
qlinear_op,
|
|
KeywordArg("x"),
|
|
KeywordArg("x_scale"),
|
|
KeywordArg("x_zp"),
|
|
KeywordArg("packed_weight"),
|
|
KeywordArg("w_scale"),
|
|
KeywordArg("w_zp"),
|
|
KeywordArg("b"),
|
|
KeywordArg("output_scale"),
|
|
KeywordArg("output_zero_point"),
|
|
KeywordArg("output_dtype"),
|
|
KeywordArg("postop_name"),
|
|
KeywordArg("postop_args"),
|
|
KeywordArg("postop_algorithm"),
|
|
_users=users,
|
|
)
|
|
|
|
|
|
def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1):
|
|
qlinear_op = (
|
|
torch.ops.onednn.qlinear_pointwise.binary_tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.binary
|
|
)
|
|
return CallFunction(
|
|
qlinear_op,
|
|
KeywordArg("x"),
|
|
KeywordArg("x_scale"),
|
|
KeywordArg("x_zp"),
|
|
KeywordArg("packed_weight"),
|
|
KeywordArg("w_scale"),
|
|
KeywordArg("w_zp"),
|
|
KeywordArg("x_2"),
|
|
KeywordArg("b"),
|
|
KeywordArg("output_scale"),
|
|
KeywordArg("output_zero_point"),
|
|
KeywordArg("output_dtype"),
|
|
KeywordArg("x2_scale"),
|
|
KeywordArg("x2_zp"),
|
|
KeywordArg("binary_op_name"),
|
|
KeywordArg("alpha"),
|
|
KeywordArg("unary_op_name"),
|
|
KeywordArg("unary_op_args"),
|
|
KeywordArg("unary_op_algorithm"),
|
|
_users=users,
|
|
)
|
|
|
|
|
|
dequantize_accum_pattern = CallFunction(
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
KeywordArg("accum"),
|
|
KeywordArg("accum_scale"),
|
|
KeywordArg("accum_zp"),
|
|
Arg(),
|
|
Arg(),
|
|
KeywordArg("accum_dq_dtype"),
|
|
)
|
|
|
|
|
|
def generate_pattern_with_binary(
|
|
binary_post_op,
|
|
computation_call,
|
|
extra_input_pattern,
|
|
dtype_convert=False,
|
|
swap_inputs=False,
|
|
):
|
|
binary_pattern = (
|
|
CallFunction(
|
|
binary_post_op,
|
|
extra_input_pattern,
|
|
computation_call,
|
|
)
|
|
if swap_inputs
|
|
else CallFunction(
|
|
binary_post_op,
|
|
computation_call,
|
|
extra_input_pattern,
|
|
)
|
|
)
|
|
return _may_generate_pattern_with_dtype_convert(
|
|
binary_pattern,
|
|
KeywordArg("convert_dtype_after_inplace_add"),
|
|
dtype_convert,
|
|
)
|
|
|
|
|
|
def generate_pattern_with_unary(computation_call, unary_post_op):
|
|
if unary_post_op is not None:
|
|
return CallFunction(
|
|
unary_post_op,
|
|
computation_call,
|
|
)
|
|
return computation_call
|
|
|
|
|
|
def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
|
|
quantized_op_output_pattern_pt2e = CallFunction(
|
|
quantized_decomposed.quantize_per_tensor.default,
|
|
_may_generate_pattern_with_dtype_convert(
|
|
computation_call,
|
|
Arg(),
|
|
with_dtype_convert,
|
|
),
|
|
KeywordArg("o_inv_scale"),
|
|
KeywordArg("o_zp"),
|
|
KeywordArg("o_qmin"),
|
|
KeywordArg("o_qmax"),
|
|
KeywordArg("o_dtype"),
|
|
)
|
|
return quantized_op_output_pattern_pt2e
|
|
|
|
|
|
def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
|
|
if kwarg_name in check_node.kwargs:
|
|
actual_value = check_node.kwargs[kwarg_name]
|
|
return actual_value == expected_value
|
|
else:
|
|
assert len(check_node.args) >= (args_index + 1)
|
|
actual_value = check_node.args[args_index]
|
|
return actual_value == expected_value
|
|
|
|
|
|
def _is_valid_quantized_conv_optimization_pattern():
|
|
def fn(match):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
if output_dtype in [torch.float32, torch.bfloat16]:
|
|
# Only keep matched pattern with same output_dtype
|
|
qconv_node_after_weight_prepack = filter_nodes(
|
|
match.nodes, torch.ops.onednn.qconv_pointwise
|
|
)[0]
|
|
return _check_node_kwarg_arg_value(
|
|
qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
|
|
)
|
|
return True
|
|
|
|
return fn
|
|
|
|
|
|
def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False):
|
|
return (
|
|
_is_valid_qconv_binary_optimization_pattern()
|
|
if has_binary_post_op
|
|
else _is_valid_quantized_conv_optimization_pattern()
|
|
)
|
|
|
|
|
|
def _is_valid_qconv_lowering_pattern():
|
|
def fn(match):
|
|
if len(match.nodes) != 1:
|
|
return False
|
|
return match.nodes[0].target in (
|
|
torch.ops.onednn.qconv_pointwise.default,
|
|
torch.ops.onednn.qconv_pointwise.tensor,
|
|
torch.ops.onednn.qconv2d_pointwise.binary,
|
|
torch.ops.onednn.qconv2d_pointwise.binary_tensor,
|
|
)
|
|
|
|
return fn
|
|
|
|
|
|
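# Note: the lowering-time checks above and below match exactly one node because, by
# the time these lowering patterns run, the dequant/conv/quant subgraph is expected
# to have already been rewritten (during weight prepack and post-op fusion) into a
# single onednn.qconv_pointwise / qconv2d_pointwise call; the lowering only maps that
# one fused node onto L[computation_op].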
def _register_quantized_conv_lowering(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qconv_lowering_pattern(),
|
|
pass_number=pass_number,
|
|
)
|
|
def qconv(match: Match, *args, **kwargs):
|
|
# Activation QParams
|
|
x, x_scale, x_zp = (
|
|
kwargs["x"],
|
|
kwargs["x_scale"],
|
|
kwargs["x_zp"],
|
|
)
|
|
# Weight QParams
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
# Conv Params
|
|
b, stride, padding, dilation, groups = (
|
|
kwargs["b"],
|
|
kwargs["stride"],
|
|
kwargs["padding"],
|
|
kwargs["dilation"],
|
|
kwargs["groups"],
|
|
)
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
|
|
# Output QParams
|
|
o_inv_scale = kwargs["output_scale"]
|
|
o_zero_point = kwargs["output_zero_point"]
|
|
output_dtype = kwargs["output_dtype"]
|
|
# post op
|
|
postop_name = kwargs["postop_name"]
|
|
postop_args = kwargs["postop_args"]
|
|
postop_algorithm = kwargs["postop_algorithm"]
|
|
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
b,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
postop_name,
|
|
postop_args,
|
|
postop_algorithm,
|
|
)
|
|
counters["inductor"]["qconv_unary_lower_count"] += 1
|
|
counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes)
|
|
return L[computation_op](*computation_args)
|
|
|
|
return qconv
|
|
|
|
|
|
def _is_valid_quantized_linear_optimization_pattern():
|
|
def fn(match):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
if output_dtype in [torch.float32, torch.bfloat16]:
|
|
# Only keep matched pattern with same output_dtype
|
|
qlinear_node_after_weight_prepack = filter_nodes(
|
|
match.nodes, torch.ops.onednn.qlinear_pointwise
|
|
)[0]
|
|
return _check_node_kwarg_arg_value(
|
|
qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
|
|
)
|
|
return True
|
|
|
|
return fn
|
|
|
|
|
|
def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False):
|
|
return (
|
|
_is_valid_qlinear_binary_optimization_pattern()
|
|
if has_binary_post_op
|
|
else _is_valid_quantized_linear_optimization_pattern()
|
|
)
|
|
|
|
|
|
def _is_valid_qlinear_lowering_pattern():
|
|
def fn(match):
|
|
if len(match.nodes) != 1:
|
|
return False
|
|
return match.nodes[0].target in (
|
|
torch.ops.onednn.qlinear_pointwise.default,
|
|
torch.ops.onednn.qlinear_pointwise.tensor,
|
|
torch.ops.onednn.qlinear_pointwise.binary,
|
|
torch.ops.onednn.qlinear_pointwise.binary_tensor,
|
|
)
|
|
|
|
return fn
|
|
|
|
|
|
def _register_quantized_linear_unary_lowering(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qlinear_lowering_pattern(),
|
|
pass_number=pass_number,
|
|
)
|
|
def qlinear(match: Match, *args, **kwargs):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
# Activation QParams
|
|
x, x_scale, x_zp = (
|
|
kwargs["x"],
|
|
kwargs["x_scale"],
|
|
kwargs["x_zp"],
|
|
)
|
|
# Weight QParams
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
|
|
# bias
|
|
b = kwargs["b"] if "b" in kwargs else None
|
|
|
|
# Output QParams
|
|
o_inv_scale = kwargs["output_scale"]
|
|
o_zero_point = kwargs["output_zero_point"]
|
|
|
|
# post op
|
|
postop_name = kwargs["postop_name"]
|
|
postop_args = kwargs["postop_args"]
|
|
postop_algorithm = kwargs["postop_algorithm"]
|
|
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
b,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
postop_name,
|
|
postop_args,
|
|
postop_algorithm,
|
|
)
|
|
counters["inductor"]["qlinear_unary_lower_count"] += 1
|
|
counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes)
|
|
return L[computation_op](*computation_args)
|
|
|
|
return qlinear
|
|
|
|
|
|
def _register_quantized_linear_binary_lowering(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qlinear_lowering_pattern(),
|
|
pass_number=pass_number,
|
|
)
|
|
def qlinear_binary(match: Match, *args, **kwargs):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
assert output_dtype is not None
|
|
# Activation QParams
|
|
x, x_scale, x_zp = (
|
|
kwargs["x"],
|
|
kwargs["x_scale"],
|
|
kwargs["x_zp"],
|
|
)
|
|
x2 = kwargs["x_2"]
|
|
x2_scale = kwargs["x2_scale"]
|
|
x2_zp = kwargs["x2_zp"]
|
|
# Weight QParams
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
# bias
|
|
b = kwargs["b"] if "b" in kwargs else None
|
|
# Output QParams
|
|
o_inv_scale = kwargs["output_scale"]
|
|
o_zero_point = kwargs["output_zero_point"]
|
|
|
|
x2.realize()
|
|
from .mkldnn_fusion import _can_be_inplace
|
|
|
|
binary_op_name = kwargs["binary_op_name"]
|
|
alpha = kwargs["alpha"]
|
|
unary_op_name = kwargs["unary_op_name"]
|
|
unary_op_args = kwargs["unary_op_args"]
|
|
unary_op_algorithm = kwargs["unary_op_algorithm"]
|
|
|
|
if binary_op_name == "sum" and not _can_be_inplace(x2):
|
|
# When we enable the GEMM Template, the output of QLinear
|
|
# will be reshaped from 2D back to 3D if the input is 3D.
|
|
# This causes _can_be_inplace(x2) to return False if x2 happens
|
|
# to be the output of QLinear in this scenario.
|
|
# Change the post op from sum to binary add for this case.
|
|
# Refer to test case:
|
|
# test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
|
|
binary_op_name = "add"
|
|
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
x2,
|
|
b,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
x2_scale,
|
|
x2_zp,
|
|
binary_op_name,
|
|
alpha,
|
|
unary_op_name,
|
|
unary_op_args,
|
|
unary_op_algorithm,
|
|
)
|
|
counters["inductor"]["qlinear_binary_lower_count"] += 1
|
|
counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes)
|
|
return L[computation_op](*computation_args)
|
|
|
|
return qlinear_binary
|
|
|
|
|
|
def _is_valid_qconv_binary_optimization_pattern():
|
|
return _is_valid_quantized_op_binary_optimization_pattern(
|
|
torch.ops.onednn.qconv_pointwise
|
|
)
|
|
|
|
|
|
def _is_valid_qlinear_binary_optimization_pattern():
|
|
return _is_valid_quantized_op_binary_optimization_pattern(
|
|
torch.ops.onednn.qlinear_pointwise,
|
|
# we don't insert q-dq for extra input due to accuracy issues
|
|
extra_input_from_dequant=False,
|
|
)
|
|
|
|
|
|
def _is_valid_quantized_op_binary_optimization_pattern(
    qop, extra_input_from_dequant=True
):
    # Check if it's a valid binary pattern for qconv2d and qlinear:
    # * qop_pointwise should only have one user
    # * If extra_input_from_dequant is True, the extra input of the binary node should come from a dequant pattern
    # * the two inputs of the binary node should have attribute "meta" and should be tensors
    # * the two inputs of the binary node should have the same shape
    # * All users of the extra input in this pattern should be
    #   ancestor nodes of the compute node, except for the binary node
    #   connected to the compute node.
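    #
    # A sketch of the graph shape being validated (illustrative only):
    #
    #   qconv/qlinear_pointwise (compute node) ---+
    #                                             +--> binary node (e.g. add) --> ...
    #   extra input (optionally coming from a  ---+
    #   dequantize_per_tensor pattern)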
def fn(match):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
compute_node = filter_nodes(match.nodes, qop)[0]
|
|
# qop_pointwise should only have one user
|
|
if len(compute_node.users) != 1:
|
|
return False
|
|
binary_node_inputs = next(iter(compute_node.users)).args
|
|
assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
|
|
if output_dtype in [torch.float32, torch.bfloat16]:
|
|
extra_input_of_binary_node = None
|
|
for arg in binary_node_inputs:
|
|
if arg != compute_node:
|
|
extra_input_of_binary_node = arg
|
|
break
|
|
assert extra_input_of_binary_node is not None
|
|
# Extra input of binary node comes from dequant pattern
|
|
if extra_input_from_dequant and (
|
|
(not isinstance(extra_input_of_binary_node, torch.fx.Node))
|
|
or (
|
|
extra_input_of_binary_node.target
|
|
!= quantized_decomposed.dequantize_per_tensor.default
|
|
)
|
|
):
|
|
return False
|
|
|
|
# the two inputs of binary node should have attribute "meta" and should be tensors
|
|
if not (
|
|
hasattr(binary_node_inputs[0], "meta")
|
|
and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
|
) or not (
|
|
hasattr(binary_node_inputs[1], "meta")
|
|
and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
|
|
):
|
|
return False
|
|
# the two inputs of binary node should have the same shape
|
|
if (
|
|
binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr]
|
|
!= binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr]
|
|
):
|
|
return False
|
|
|
|
# All users of the extra input in this pattern should be
|
|
# ancestor nodes of the compute node, except for the binary node
|
|
# connected to the compute node.
|
|
|
|
from .mkldnn_fusion import _get_remaining_users
|
|
|
|
extra_input_of_pattern = (
|
|
match.kwargs["other"]
|
|
if "other" in match.kwargs
|
|
else (
|
|
match.kwargs["accum"]
|
|
if (output_dtype in [torch.uint8, torch.int8])
|
|
or (not extra_input_from_dequant)
|
|
else match.kwargs["accum_after_dequant"]
|
|
)
|
|
)
|
|
if (
|
|
len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
|
|
or extra_input_of_pattern == compute_node.args[0]
|
|
):
|
|
return False
|
|
return True
|
|
|
|
return fn
|
|
|
|
|
|
def _register_quantized_conv_binary_lowering(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qconv_lowering_pattern(),
|
|
pass_number=pass_number,
|
|
)
|
|
def qconv_binary(match: Match, *args, **kwargs):
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
assert output_dtype is not None
|
|
x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
|
|
accum = kwargs["accum"]
|
|
accum_scale = kwargs["accum_scale"]
|
|
accum_zp = kwargs["accum_zero_point"]
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
b, stride, padding, dilation, groups = (
|
|
kwargs["b"],
|
|
kwargs["stride"],
|
|
kwargs["padding"],
|
|
kwargs["dilation"],
|
|
kwargs["groups"],
|
|
)
|
|
# Output QParams
|
|
output_scale = kwargs["output_scale"]
|
|
output_zero_point = kwargs["output_zero_point"]
|
|
|
|
# post ops
|
|
binary_op_name = kwargs["binary_op_name"]
|
|
alpha = kwargs["alpha"]
|
|
unary_op_name = kwargs["unary_op_name"]
|
|
unary_op_args = kwargs["unary_op_args"]
|
|
unary_op_algorithm = kwargs["unary_op_algorithm"]
|
|
|
|
accum.realize()
|
|
from .mkldnn_fusion import _can_be_inplace
|
|
|
|
assert _can_be_inplace(accum), (
|
|
"QConv Binary Inplace Fusion requires accum is not an alias or mutation."
|
|
)
|
|
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
accum,
|
|
b,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
output_scale,
|
|
output_zero_point,
|
|
output_dtype,
|
|
accum_scale,
|
|
accum_zp,
|
|
binary_op_name,
|
|
alpha,
|
|
unary_op_name,
|
|
unary_op_args,
|
|
unary_op_algorithm,
|
|
)
|
|
counters["inductor"]["qconv2d_binary_lower_count"] += 1
|
|
counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes)
|
|
return L[computation_op](*computation_args)
|
|
|
|
return qconv_binary
|
|
|
|
|
|
def _register_quantization_unary_lowering():
|
|
# QConv2d
|
|
for users in [1, 2]:
|
|
qconv_pattern = get_qconv_pt2e_pattern(users)
|
|
_register_quantized_conv_lowering(
|
|
qconv_pattern,
|
|
2, # pass_number
|
|
torch.ops.onednn.qconv_pointwise.default, # computation_op
|
|
)
|
|
|
|
# QLinear
|
|
for x_scale_zp_are_tensors in (False, True):
|
|
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
|
|
computation_op = (
|
|
torch.ops.onednn.qlinear_pointwise.tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.default
|
|
)
|
|
_register_quantized_linear_unary_lowering(
|
|
qlinear_pattern,
|
|
2, # pass_number
|
|
computation_op,
|
|
)
|
|
|
|
|
|
def _register_quantization_binary_lowering():
|
|
# QConv2d
|
|
for users in (1, 2):
|
|
qconv_pattern = get_qconv2d_binary_pt2e_pattern(users)
|
|
_register_quantized_conv_binary_lowering(
|
|
qconv_pattern,
|
|
2, # pass_number
|
|
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
|
)
|
|
|
|
# QLinear
|
|
for x_scale_zp_are_tensors in (False, True):
|
|
qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors)
|
|
computation_op = (
|
|
torch.ops.onednn.qlinear_pointwise.binary_tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.binary
|
|
)
|
|
_register_quantized_linear_binary_lowering(
|
|
qlinear_pattern,
|
|
2, # pass_number
|
|
computation_op,
|
|
)
|
|
|
|
|
|
def _is_valid_quantized_maxpool2d_optimization_pattern():
|
|
def fn(match):
|
|
        # Only match the pattern where max_pool2d_with_indices returns the value
        # instead of the indices.
get_item_node = filter_nodes(match.nodes, operator.getitem)[0]
|
|
return get_item_node.args[1] == 0
|
|
|
|
return fn
|
|
|
|
|
|
def _register_quantized_maxpool2d_lowering(
|
|
pattern,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
|
|
)
|
|
def qmaxpool2d(match: Match, *args, **kwargs):
|
|
x = kwargs["x"]
|
|
kernel_size = kwargs["kernel_size"]
|
|
stride = kwargs["stride"] if ("stride" in kwargs) else None
|
|
padding = kwargs["padding"] if ("padding" in kwargs) else 0
|
|
dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
|
|
ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
|
|
|
|
if padding == 0:
|
|
padding = [0, 0]
|
|
if dilation == 1:
|
|
dilation = [1, 1]
|
|
if not stride:
|
|
stride = kernel_size
|
|
kernel_size = pad_listlike(kernel_size, 2)
|
|
stride = pad_listlike(stride, 2)
|
|
padding = pad_listlike(padding, 2)
|
|
dilation = pad_listlike(dilation, 2)
|
|
|
|
assert len(kernel_size) == 2
|
|
assert len(stride) == 2
|
|
assert len(padding) == 2
|
|
assert len(dilation) == 2
|
|
|
|
computation_args = (
|
|
x,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
ceil_mode,
|
|
)
|
|
computation_args, _ = require_channels_last(computation_op, *computation_args)
|
|
counters["inductor"]["qmaxpool2d_matcher_count"] += 1
|
|
counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
|
|
return L[computation_op](*computation_args)
|
|
|
|
return qmaxpool2d
|
|
|
|
|
|
def _register_quantization_maxpool2d():
|
|
    # Currently, the default parameters are not in the FX graph generated by Dynamo export.
    # So, if the user defines nn.MaxPool2d with a different assignment of default parameters,
    # the generated graph will have a different number of input nodes and hence
    # a different pattern to be matched.
    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
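    # For example (an illustrative sketch, not code used by this pass):
    #   nn.MaxPool2d(3, stride=2)            -> exported call carries (x, kernel_size, stride)
    #   nn.MaxPool2d(3, stride=2, padding=1) -> exported call additionally carries padding
    # so a separate pattern is registered below for each possible argument list.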
max_pool2d_args_list = [
|
|
[
|
|
KeywordArg("stride"),
|
|
],
|
|
[
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
],
|
|
[
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
KeywordArg("dilation"),
|
|
],
|
|
[
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
KeywordArg("dilation"),
|
|
KeywordArg("ceil_mode"),
|
|
],
|
|
]
|
|
for max_pool2d_args in max_pool2d_args_list:
|
|
dequantize_maxpool2d_pattern = CallFunction(
|
|
aten.max_pool2d_with_indices.default,
|
|
get_dequantize_per_tensor_activation_pattern(),
|
|
KeywordArg("kernel_size"),
|
|
*max_pool2d_args,
|
|
)
|
|
dequantize_lowmem_maxpool2d_pattern = CallFunction(
|
|
prims._low_memory_max_pool_with_offsets.default,
|
|
get_dequantize_per_tensor_activation_pattern(),
|
|
KeywordArg("kernel_size"),
|
|
*max_pool2d_args,
|
|
KeywordArg("offset_dtype"),
|
|
)
|
|
dequantize_maxpool2d_get_item_pattern = CallFunction(
|
|
operator.getitem,
|
|
dequantize_maxpool2d_pattern,
|
|
Arg(),
|
|
)
|
|
dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
|
|
operator.getitem,
|
|
dequantize_lowmem_maxpool2d_pattern,
|
|
Arg(),
|
|
)
|
|
_register_quantized_maxpool2d_lowering(
|
|
generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
|
|
quantized.max_pool2d.default,
|
|
)
|
|
_register_quantized_maxpool2d_lowering(
|
|
generate_pattern_with_output_quant(
|
|
dequantize_lowmem_maxpool2d_get_item_pattern
|
|
),
|
|
quantized.max_pool2d.default,
|
|
)
|
|
|
|
|
|
def _is_input_output_same_scale_zp(check_node):
    def fn(match):
        # Ensure all the inputs and the output have the same scale and zero point
        # Step 1: Check inputs/output zero point
        # Get dequant nodes at input
        dequant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.dequantize_per_tensor.default
        )
        zero_points = [node.args[2] for node in dequant_nodes]
        # Get quant nodes at output
        quant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.quantize_per_tensor.default
        )
        assert len(quant_nodes) == 1, "expect only 1 quant node at output quant pattern"
        zero_points.append(quant_nodes[0].args[2])
        if not all(zero_point == zero_points[0] for zero_point in zero_points):
            return False

        # Step 2: Check inputs/output scale
        scales = [node.args[1] for node in dequant_nodes]
        scales.append(quant_nodes[0].args[1])
        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
            return False

        return True

    return fn

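# Note: the same-scale/zero-point requirement above exists because the qcat and
# qreshape lowerings below pass the quantized integer tensor directly through
# aten.cat / aten.reshape without re-quantizing, which only preserves values when
# the input and output quantization parameters match.
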
def _register_quantized_cat_lowering(
|
|
pattern,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_input_output_same_scale_zp(aten.cat.default),
|
|
)
|
|
def qcat(match: Match, inputs, dim, **kwargs):
|
|
        # inputs has the format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
uint8_inputs = [input[0] for input in inputs]
|
|
counters["inductor"]["qcat_matcher_count"] += 1
|
|
counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
|
|
return L[computation_op](uint8_inputs, dim)
|
|
|
|
return qcat
|
|
|
|
|
|
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
Arg(),
|
|
Arg(),
|
|
Arg(),
|
|
Arg(),
|
|
Arg(),
|
|
Arg(),
|
|
)
|
|
|
|
|
|
def _register_quantization_cat():
|
|
dequantize_cat_pattern = CallFunction(
|
|
aten.cat.default,
|
|
ListOf(_raw_dequantize_per_tensor_activation_pattern),
|
|
KeywordArg("dim"),
|
|
)
|
|
_register_quantized_cat_lowering(
|
|
generate_pattern_with_output_quant(dequantize_cat_pattern),
|
|
aten.cat,
|
|
)
|
|
|
|
|
|
def _register_quantized_reshape_lowering(
|
|
pattern,
|
|
computation_op,
|
|
):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
|
|
)
|
|
def qreshape(match: Match, *args, **kwargs):
|
|
qx = kwargs["x"]
|
|
shape = kwargs["shape"]
|
|
counters["inductor"]["qreshape_matcher_count"] += 1
|
|
counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
|
|
return L[computation_op](qx, shape)
|
|
|
|
return qreshape
|
|
|
|
|
|
def _register_quantization_reshape():
|
|
dequantize_reshape_pattern = CallFunction(
|
|
torch.ops.aten.reshape.default,
|
|
get_dequantize_per_tensor_activation_pattern(),
|
|
KeywordArg("shape"),
|
|
)
|
|
_register_quantized_reshape_lowering(
|
|
generate_pattern_with_output_quant(dequantize_reshape_pattern),
|
|
aten.reshape,
|
|
)
|
|
|
|
|
|
def _is_valid_concat_linear_int8_woq_optimization_pattern():
|
|
def fn(match):
|
|
if not config.cpp.enable_concat_linear:
|
|
return False
|
|
assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales"))
|
|
if not all(
|
|
hasattr(match.kwargs[key], "meta")
|
|
for key in ["x", "w1", "w2", "w3", "scales"]
|
|
):
|
|
return False
|
|
x = match.kwargs["x"].meta["val"]
|
|
w1 = match.kwargs["w1"].meta["val"]
|
|
w2 = match.kwargs["w2"].meta["val"]
|
|
w3 = match.kwargs["w3"].meta["val"]
|
|
scales = match.kwargs["scales"].meta["val"]
|
|
if len(match.kwargs["scales"].meta["val"].size()) > 1:
|
|
return False
|
|
num_scales = match.kwargs["scales"].meta["val"].numel()
|
|
w1_cols = match.kwargs["w1"].meta["val"].size()[0]
|
|
w2_cols = match.kwargs["w2"].meta["val"].size()[0]
|
|
w3_cols = match.kwargs["w3"].meta["val"].size()[0]
|
|
# Technically, the shapes of the three weights need not be equal.
|
|
# But currently, we only enable replacement in this case.
|
|
if w1_cols != w2_cols or w2_cols != w3_cols:
|
|
return False
|
|
if 3 * w1_cols != num_scales:
|
|
return False
|
|
return (
|
|
# For now, we only support woq mm kernels
|
|
# with x.type=bfloat16 and w.type=int8
|
|
x.dtype == torch.bfloat16
|
|
and w1.dtype == torch.int8
|
|
and w2.dtype == torch.int8
|
|
and w3.dtype == torch.int8
|
|
and scales.dtype == torch.bfloat16
|
|
# _weight_int8pack_mm kernel only supports cpu now
|
|
# TODO: add cuda kernel support instead of calling mul+sum
|
|
and x.device.type == "cpu"
|
|
and x.device == w1.device
|
|
and w1.device == w2.device
|
|
and w2.device == w3.device
|
|
and x.device == scales.device
|
|
)
|
|
|
|
return fn
|
|
|
|
|
|
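# Note on the check above: each of w1/w2/w3 contributes size(0) output channels, so
# after concatenation the per-channel `scales` vector must have exactly 3 * w1_cols
# elements, which is what the `3 * w1_cols != num_scales` comparison enforces
# (the three weights are required to have equal shapes here).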
def _is_valid_woq_optimization_pattern():
|
|
def fn(match):
|
|
assert all(k in match.kwargs for k in ("x", "weight", "scales"))
|
|
if not all(
|
|
hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"]
|
|
):
|
|
return False
|
|
x = match.kwargs["x"].meta["val"]
|
|
weight = match.kwargs["weight"].meta["val"]
|
|
scales = match.kwargs["scales"].meta["val"]
|
|
return (
|
|
# For now, we only support woq mm kernels
|
|
# with x.type=bfloat16 and w.type=int8
|
|
x.dtype == torch.bfloat16
|
|
and weight.dtype == torch.int8
|
|
and scales.dtype == torch.bfloat16
|
|
# _weight_int8pack_mm kernel only supports cpu now
|
|
# TODO: add cuda kernel support instead of calling mul+sum
|
|
and x.device.type == "cpu"
|
|
and x.device == weight.device
|
|
and x.device == scales.device
|
|
)
|
|
|
|
return fn
|
|
|
|
|
|
def _register_concat_linear_int8_woq_lowering(
|
|
pattern, computation_woq, computation_reshape
|
|
):
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(),
|
|
pass_number=4,
|
|
)
|
|
def woq(match: Match, *args, **kwargs):
|
|
x = kwargs["x"]
|
|
w1 = kwargs["w1"]
|
|
w2 = kwargs["w2"]
|
|
w3 = kwargs["w3"]
|
|
scales = kwargs["scales"]
|
|
counters["inductor"]["woq_matcher_count"] += 1
|
|
counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
|
|
out_features = (
|
|
w1.meta["val"].size()[0]
|
|
+ w2.meta["val"].size()[0]
|
|
+ w3.meta["val"].size()[0]
|
|
)
|
|
origin_x_size = tuple(x.meta["val"].size())
|
|
x_shape = [-1, origin_x_size[-1]]
|
|
out_shape = list(origin_x_size[:-1] + (out_features,))
|
|
mm_node_of_x = None
|
|
for candidate in iter(x.users.keys()):
|
|
if (
|
|
candidate.target == aten.mm.default
|
|
and list(candidate._input_nodes)[1].target == aten.cat.default
|
|
):
|
|
mm_node_of_x = candidate
|
|
break
|
|
assert mm_node_of_x is not None, "unable to find mm node"
|
|
_, cat_wgt_node = mm_node_of_x._input_nodes
|
|
scaling_node = next(iter(mm_node_of_x.users.keys()))
|
|
user_of_scaling_node = next(iter(scaling_node.users.keys()))
|
|
        # Some other pass is making changes that entail adding a node before it is
        # used, but this can only be detected when lint is run. Since
        # stable_topological_sort() is run before lint, that error was not being
        # discovered. We call stable_topological_sort here as a workaround.
stable_topological_sort(match.graph)
|
|
with match.graph.inserting_before(user_of_scaling_node):
|
|
new_cat_node = match.graph.call_function(
|
|
aten.cat.default,
|
|
args=([w1, w2, w3], 0),
|
|
)
|
|
x_reshape_node = match.graph.call_function(
|
|
computation_reshape, args=(x, x_shape)
|
|
)
|
|
new_woq_node = match.graph.call_function(
|
|
computation_woq,
|
|
args=(x_reshape_node, new_cat_node, scales),
|
|
)
|
|
new_woq_node.meta = copy.copy(x.meta)
|
|
output_reshape_node = match.graph.call_function(
|
|
computation_reshape, args=(new_woq_node, out_shape)
|
|
)
|
|
scaling_node.replace_all_uses_with(output_reshape_node)
|
|
match.graph.erase_node(scaling_node)
|
|
match.graph.erase_node(mm_node_of_x)
|
|
match.graph.erase_node(cat_wgt_node)
|
|
match.graph.lint()
|
|
|
|
return woq
|
|
|
|
|
|
def _register_woq_lowering(pattern, computation_woq, computation_reshape):
|
|
@register_lowering_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_woq_optimization_pattern(),
|
|
)
|
|
def woq(match: Match, *args, **kwargs):
|
|
x = kwargs["x"]
|
|
weight = kwargs["weight"]
|
|
scales = kwargs["scales"]
|
|
counters["inductor"]["woq_matcher_count"] += 1
|
|
counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
|
|
out_features = weight.get_size()[0]
|
|
origin_x_size = x.get_size()
|
|
x_shape = [-1, origin_x_size[-1]]
|
|
out_shape = origin_x_size[:-1] + [
|
|
out_features,
|
|
]
|
|
func1 = L[computation_reshape](x, x_shape)
|
|
func2 = L[computation_woq](func1, weight, scales)
|
|
return L[computation_reshape](func2, out_shape)
|
|
|
|
return woq
|
|
|
|
|
|
def _register_woq_mm_int8_pattern1():
|
|
# F.linear(x, weight.to(dtype=x.dtype)) * scales
|
|
# case of dispatching to mm, with x reshape
|
|
_woq_pattern = CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
aten.reshape.default,
|
|
CallFunction(
|
|
aten.mm.default,
|
|
CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
|
|
CallFunction(
|
|
aten.permute.default,
|
|
CallFunction(
|
|
prims.convert_element_type.default, KeywordArg("weight"), Arg()
|
|
),
|
|
Arg(),
|
|
),
|
|
),
|
|
Arg(),
|
|
),
|
|
KeywordArg("scales"),
|
|
)
|
|
_register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
|
|
|
|
|
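# For reference, the WOQ (weight-only quantization) patterns registered in this file
# correspond to eager code roughly of the following form (an illustrative sketch;
# the variable names are not part of this file):
#
#   # x: bfloat16 activation, weight: int8, scales: bfloat16 per-channel scales
#   y = torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
#
# which decomposes into reshape/expand + mm/bmm + mul subgraphs; the registered
# lowerings replace them with a single aten._weight_int8pack_mm call.
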
def _register_woq_mm_int8_pattern2():
|
|
# F.linear(x, weight.to(dtype=x.dtype)) * scales
|
|
# case of dispatching to mm, w/o x reshape
|
|
_woq_pattern = CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
aten.reshape.default,
|
|
CallFunction(
|
|
aten.mm.default,
|
|
KeywordArg("x"),
|
|
CallFunction(
|
|
aten.permute.default,
|
|
CallFunction(
|
|
prims.convert_element_type.default, KeywordArg("weight"), Arg()
|
|
),
|
|
Arg(),
|
|
),
|
|
),
|
|
Arg(),
|
|
),
|
|
KeywordArg("scales"),
|
|
)
|
|
_register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
|
|
|
|
|
def _register_woq_mm_int8_pattern3():
|
|
# F.linear(x, weight.to(dtype=x.dtype)) * scales
|
|
# case of dispatching to bmm
|
|
_woq_pattern = CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
aten.bmm.default,
|
|
CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
|
|
CallFunction(
|
|
aten.expand.default,
|
|
CallFunction(
|
|
aten.permute.default,
|
|
CallFunction(
|
|
prims.convert_element_type.default, KeywordArg("weight"), Arg()
|
|
),
|
|
Arg(),
|
|
),
|
|
Arg(),
|
|
),
|
|
),
|
|
KeywordArg("scales"),
|
|
)
|
|
_register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
|
|
|
|
|
def _register_woq_mm_int8_pattern4():
|
|
_woq_pattern = CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
aten.mm.default,
|
|
KeywordArg("x"),
|
|
CallFunction(
|
|
prims.convert_element_type.default,
|
|
CallFunction(
|
|
aten.permute.default,
|
|
KeywordArg("weight"),
|
|
Arg(),
|
|
),
|
|
Arg(),
|
|
),
|
|
),
|
|
KeywordArg("scales"),
|
|
)
|
|
_register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
|
|
|
|
|
def _register_int8_woq_concat_linear_pattern():
|
|
def _create_wgt_node(wgt_node_name: str):
|
|
return CallFunction(
|
|
prims.convert_element_type.default,
|
|
CallFunction(
|
|
aten.permute.default,
|
|
KeywordArg(wgt_node_name),
|
|
Arg(),
|
|
),
|
|
Arg(),
|
|
)
|
|
|
|
cat_wgt = CallFunction(
|
|
aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1
|
|
)
|
|
|
|
_woq_pattern = CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt),
|
|
KeywordArg("scales"),
|
|
)
|
|
_register_concat_linear_int8_woq_lowering(
|
|
_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape
|
|
)
|
|
|
|
|
|
def _register_quantization_lowerings():
|
|
_register_quantization_unary_lowering()
|
|
_register_quantization_binary_lowering()
|
|
_register_quantization_maxpool2d()
|
|
_register_quantization_cat()
|
|
_register_quantization_reshape()
|
|
|
|
|
|
def _register_woq_lowerings():
|
|
_register_woq_mm_int8_pattern1()
|
|
_register_woq_mm_int8_pattern2()
|
|
_register_woq_mm_int8_pattern3()
|
|
_register_woq_mm_int8_pattern4()
|
|
|
|
|
|
def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
|
|
def _inner(match):
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
dequant_pattern_end_node = match.output_node()
|
|
if dequant_pattern_end_node.target not in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
prims.convert_element_type.default,
|
|
aten.reshape.default,
|
|
]:
|
|
return False
|
|
|
|
if dequant_pattern_end_node.target is aten.reshape.default:
|
|
dequant_node = (
|
|
dequant_pattern_end_node.args[
|
|
0
|
|
] # pattern: linear <- reshape <- dequant
|
|
if dtype == torch.float32
|
|
else dequant_pattern_end_node.args[0].args[
|
|
0
|
|
] # pattern: linear <- reshape <- to_bf16 <- dequant
|
|
)
|
|
else:
|
|
dequant_node = (
|
|
dequant_pattern_end_node # pattern: linear <- dequant
|
|
if dtype == torch.float32
|
|
else dequant_pattern_end_node.args[
|
|
0
|
|
] # pattern: linear <- to_bf16 <- dequant
|
|
)
|
|
|
|
if (
|
|
dequant_node.target
|
|
in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
]
|
|
and len(list(dequant_pattern_end_node.users)) > 1
|
|
):
|
|
            # If the dequant pattern has more than 1 user, then do dequant promotion
return True
|
|
return False
|
|
|
|
return _inner
|
|
|
|
|
|
def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_dequant_promotion_pattern(dtype),
|
|
pass_number=pass_number,
|
|
)
|
|
def dequant_promotion(match: Match, *args, **kwargs):
|
|
# Dequant_promotion will transform
|
|
# graph 1:
|
|
# quant
|
|
# + - - - | - - - +
|
|
# | dequant |
|
|
# | / \ |
|
|
# | node1 node2 |
|
|
# + - | - - - | - +
|
|
# quant quant
|
|
# into:
|
|
# graph 2:
|
|
# quant
|
|
# + - - / - \ - - +
|
|
# |dequant dequant|
|
|
# | | | |
|
|
# | node1 node2 |
|
|
# + - | - - - | - +
|
|
# quant quant
|
|
# In graph 1, the dequant node is shared by node1 and node2,
|
|
# as a result, neither node1 nor node2 could form an int8
|
|
# fusion pattern.
|
|
# After this transformation, the graph 2 could hit the int8
|
|
# fusion pattern: dequant-node-quant, respectively for
|
|
# node1 and node2.
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
|
|
def clone_to_new_node(graph, source_node, user_node):
|
|
# Clone the source_node to a new node
|
|
# Replace user_node's input from source_node to new_node
|
|
assert source_node.op == "call_function", (
|
|
"clone_to_new_node only support node.op call_function"
|
|
)
|
|
with graph.inserting_before(user_node):
|
|
new_node = graph.call_function(
|
|
source_node.target,
|
|
args=source_node.args,
|
|
kwargs=source_node.kwargs,
|
|
)
|
|
new_node.meta = copy.copy(source_node.meta)
|
|
user_node.replace_input_with(source_node, new_node)
|
|
return new_node
|
|
|
|
# Find the start node and end node of a dequant pattern
|
|
# * End node should be the match.output_node()
|
|
# * Start node should be the node of dequantize_per_tensor
|
|
dequant_pattern_end_node = match.output_node()
|
|
assert dequant_pattern_end_node.target in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
prims.convert_element_type.default,
|
|
aten.reshape.default,
|
|
]
|
|
|
|
        # For a dequant pattern, we expect to see the node list as:
# * OPT(aten.reshape.default)
|
|
# * OPT(prims.convert_element_type.default) (to_bf16)
|
|
# * dequantize_per_tensor
|
|
def _find_first_node_in_dequant_pattern(_node):
|
|
if _node.target in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
]:
|
|
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
|
|
return _node
|
|
else:
|
|
                assert len(_node.args) >= 1, (
                    "In a dequant pattern, each node should have at least 1 arg."
                )
return _find_first_node_in_dequant_pattern(_node.args[0])
|
|
|
|
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
|
|
dequant_pattern_end_node
|
|
)
|
|
|
|
assert dequant_pattern_start_node.target in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
]
|
|
|
|
# Clone the dequant pattern for each user node
|
|
graph = match.graph
|
|
user_node_list = list(dequant_pattern_end_node.users)
|
|
for user_node in user_node_list[1:]:
|
|
_source_node = dequant_pattern_end_node
|
|
_user_node = user_node
|
|
while _source_node != dequant_pattern_start_node.args[0]:
|
|
_user_node = clone_to_new_node(graph, _source_node, _user_node)
|
|
_source_node = _source_node.args[0] # type: ignore[assignment]
|
|
|
|
counters["inductor"]["dequant_promotion_matcher_count"] += 1
|
|
counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
|
|
|
|
|
|
def _is_valid_dequant_conv_pattern(dtype):
|
|
def _inner(match):
|
|
        # Here we do some further checks to ensure that:
        # 1. It's a conv1d/2d node whose input/weight have dim 3 or 4, since we only support lowering of conv1d/2d now.
        # 2. The dequant pattern has the conv node as its only user.
        # If these conditions are not met, we will not
        # insert a weight prepack node into the matched pattern.
conv_node = match.output_node()
|
|
assert conv_node.target is aten.convolution.default
|
|
input_meta_value = conv_node.args[0].meta.get("val")
|
|
weight_meta_value = conv_node.args[1].meta.get("val")
|
|
for meta_value in [input_meta_value, weight_meta_value]:
|
|
if (
|
|
meta_value is None
|
|
or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu")
|
|
or meta_value.dim() not in [3, 4]
|
|
):
|
|
# Only support conv1d/2d now
|
|
return False
|
|
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
|
|
if dtype == torch.float32:
|
|
dequant_node = conv_node.args[0]
|
|
else:
|
|
convert_to_bf16 = conv_node.args[0]
|
|
dequant_node = convert_to_bf16.args[0]
|
|
|
|
if len(list(dequant_node.users)) != 1:
|
|
# Ensure the dequant pattern only has 1 user
|
|
# since we will delete the dequant pattern here
|
|
return False
|
|
return True
|
|
|
|
return _inner
|
|
|
|
|
|
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_dequant_conv_pattern(dtype),
|
|
pass_number=pass_number,
|
|
)
|
|
def qconv_weight_prepack(match: Match, *args, **kwargs):
|
|
"""
|
|
Match the pattern:
|
|
int8 activation
|
|
|
|
|
dequant_per_tensor
|
|
|
|
|
Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
|
|
|
|
Insert weight prepack node and change the pattern to:
|
|
int8 activation
|
|
|
|
|
onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight
|
|
"""
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
conv_node = match.output_node()
|
|
assert conv_node.target is aten.convolution.default
|
|
if dtype == torch.float32:
|
|
dequant_node = conv_node.args[0]
|
|
else:
|
|
convert_to_bf16 = conv_node.args[0]
|
|
dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr]
|
|
has_clone_to_channel_last_node_in_pattern = (
|
|
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
|
|
)
|
|
clone_node = (
|
|
conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
|
|
)
|
|
|
|
if dtype == torch.float32:
|
|
dequant_per_channel = (
|
|
clone_node.args[0] # type: ignore[union-attr]
|
|
if has_clone_to_channel_last_node_in_pattern
|
|
else conv_node.args[1]
|
|
)
|
|
else:
|
|
weight_to_bf16_node = (
|
|
clone_node.args[0] # type: ignore[union-attr]
|
|
if has_clone_to_channel_last_node_in_pattern
|
|
else conv_node.args[1]
|
|
)
|
|
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
|
|
|
|
assert (
|
|
dequant_per_channel.target # type: ignore[union-attr]
|
|
is quantized_decomposed.dequantize_per_channel.default
|
|
)
|
|
|
|
# Activation QParams
|
|
qx, x_zp, x_scale = (
|
|
kwargs["x"],
|
|
kwargs["x_zp"],
|
|
kwargs["x_scale"],
|
|
)
|
|
|
|
# Weight QParams
|
|
qw, w_scale, w_zp = (
|
|
kwargs["q_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
|
|
# Conv Params
|
|
bias, stride, padding, dilation, groups = (
|
|
kwargs["b"],
|
|
kwargs["stride"],
|
|
kwargs["padding"],
|
|
kwargs["dilation"],
|
|
kwargs["groups"],
|
|
)
|
|
|
|
x_shape = qx.meta.get("tensor_meta").shape
|
|
if has_free_symbols(x_shape):
|
|
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
|
x_shape = None
|
|
graph = match.graph
|
|
with graph.inserting_before(conv_node):
|
|
# Insert weight prepack node and the QConv node
|
|
packed_weight_inputs = (
|
|
qw,
|
|
w_scale,
|
|
x_scale,
|
|
x_zp,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
x_shape,
|
|
)
|
|
packed_weight_op = torch.ops.onednn.qconv_prepack
|
|
prepack_weight_node = graph.call_function(
|
|
packed_weight_op, args=packed_weight_inputs
|
|
)
|
|
|
|
new_args: tuple[Any, ...] = (
|
|
qx,
|
|
x_scale,
|
|
x_zp,
|
|
prepack_weight_node,
|
|
w_scale,
|
|
w_zp,
|
|
bias,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
1.0, # output_scale
|
|
0, # output_zero_point
|
|
dtype, # output_dtype
|
|
"none", # attr
|
|
[], # scalars
|
|
"", # algorithm
|
|
)
|
|
new_conv_node = graph.call_function(
|
|
torch.ops.onednn.qconv_pointwise.default, args=new_args
|
|
)
|
|
conv_node.replace_all_uses_with(new_conv_node)
|
|
new_conv_node.meta.update(conv_node.meta)
|
|
|
|
# Erase the original conv node
|
|
graph.erase_node(conv_node)
|
|
# Erase the dequant pattern
|
|
if dtype == torch.bfloat16:
|
|
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
|
|
graph.erase_node(dequant_node) # type: ignore[arg-type]
|
|
# Erase the dequant per channel pattern
|
|
if clone_node is not None:
|
|
graph.erase_node(clone_node) # type: ignore[arg-type]
|
|
if dtype == torch.bfloat16:
|
|
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
|
|
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
|
|
counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
|
|
match.nodes
|
|
)
|
|
|
|
|
|
def _generate_dequant_convolution_node_pattern(
|
|
_dequant_per_channel_pattern, dtype=torch.float32
|
|
):
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
dequant_convolution_node_pattern = CallFunction(
|
|
aten.convolution.default,
|
|
_may_generate_pattern_with_dtype_convert(
|
|
get_dequantize_per_tensor_activation_pattern(),
|
|
KeywordArg("autocast_act_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
_dequant_per_channel_pattern,
|
|
KeywordArg("b"),
|
|
KeywordArg("stride"),
|
|
KeywordArg("padding"),
|
|
KeywordArg("dilation"),
|
|
KeywordArg("is_transposed"),
|
|
KeywordArg("out_padding"),
|
|
KeywordArg("groups"),
|
|
)
|
|
return dequant_convolution_node_pattern
|
|
|
|
|
|
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
return (
|
|
_generate_dequant_convolution_node_pattern(
|
|
dequantize_per_channel_weight_pattern
|
|
if dtype == torch.float32
|
|
else dequantize_per_channel_to_bf16_weight_pattern,
|
|
dtype,
|
|
),
|
|
        # There is another pattern due to the pass of convert_conv_weights_to_channels_last
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
        # Depending on some heuristics, it may or may not insert a to(channels_last) node
        # between the convolution and the dequant_per_channel node
_generate_dequant_convolution_node_pattern(
|
|
dequantize_per_channel_clone_weight_pattern
|
|
if dtype == torch.float32
|
|
else dequantize_per_channel_to_bf16_clone_weight_pattern,
|
|
dtype,
|
|
),
|
|
)
|
|
|
|
|
|
def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
|
|
output_reshape_node = None
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
output_reshape_node = match.output_node()
|
|
assert output_reshape_node.target is aten.reshape.default
|
|
linear_node = output_reshape_node.args[0]
|
|
else:
|
|
linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
|
|
assert len(linear_nodes) == 1
|
|
linear_node = linear_nodes[0]
|
|
else:
|
|
linear_node = match.output_node()
|
|
|
|
assert linear_node.target in (
|
|
aten.addmm.default,
|
|
aten.mm.default,
|
|
aten.bmm.default,
|
|
)
|
|
return linear_node, output_reshape_node
|
|
|
|
|
|
def _get_linear_dq_node(
|
|
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
|
):
|
|
act_reshape_node = None
|
|
activation_to_bf16_node = None
|
|
act_expand_node = None
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
act_reshape_node = linear_node.args[input_index]
|
|
assert act_reshape_node.target is aten.reshape.default
|
|
if dtype == torch.float32:
|
|
# pattern: linear -> reshape -> dequant
|
|
dequant_node = act_reshape_node.args[0]
|
|
else:
|
|
# pattern: linear -> reshape -> to_bf16 -> dequant
|
|
activation_to_bf16_node = act_reshape_node.args[0]
|
|
dequant_node = activation_to_bf16_node.args[0]
|
|
else:
|
|
# bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
|
|
act_expand_node = linear_node.args[input_index]
|
|
assert act_expand_node.target is aten.expand.default
|
|
if dtype == torch.float32:
|
|
dequant_node = act_expand_node.args[0]
|
|
else:
|
|
activation_to_bf16_node = act_expand_node.args[0]
|
|
dequant_node = activation_to_bf16_node.args[0]
|
|
else:
|
|
if dtype == torch.float32:
|
|
# pattern: linear -> dequant
|
|
dequant_node = linear_node.args[input_index]
|
|
else:
|
|
# pattern: linear -> to_bf16 -> dequant
|
|
activation_to_bf16_node = linear_node.args[input_index]
|
|
dequant_node = activation_to_bf16_node.args[0]
|
|
return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
|
|
|
|
|
|
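# For reference, the activation-side shapes handled above are (in the notation used
# by the comments in this file):
#   * input dim == 2:                       linear <- [to_bf16 <-] dequant
#   * input dim > 2, contiguous input:      reshape <- linear <- reshape <- [to_bf16 <-] dequant
#   * input dim > 2, non-contiguous input:  bmm (+ add for bias) <- expand <- [to_bf16 <-] dequant
# where "linear" is aten.mm/aten.addmm, and the weight side is a permute of the
# per-channel dequant (optionally converted to bf16 and expanded for the bmm case).
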
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
|
|
def _inner(match):
|
|
# Check dequant pattern has only 1 user.
|
|
(
|
|
linear_node,
|
|
_,
|
|
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
|
|
|
|
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
(
|
|
dequant_node,
|
|
_,
|
|
_,
|
|
_,
|
|
) = _get_linear_dq_node(
|
|
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
|
)
|
|
|
|
assert dequant_node.target in [
|
|
quantized_decomposed.dequantize_per_tensor.default,
|
|
quantized_decomposed.dequantize_per_tensor.tensor,
|
|
]
|
|
|
|
if len(list(dequant_node.users)) != 1:
|
|
# Ensure the dequant pattern only has 1 user
|
|
# since we will delete the dequant pattern here
|
|
return False
|
|
|
|
# Extra check for bmm pattern
|
|
if input_dim_exceeds_two and not input_contiguous:
|
|
# Check for act
|
|
            # The act expand size should be exactly the same as the act size
act_expand_size = match.kwargs["act_expand_size"]
|
|
act_node = match.kwargs["x"]
|
|
if not (
|
|
hasattr(act_node, "meta")
|
|
and isinstance(act_node.meta.get("val", None), torch.Tensor)
|
|
and (act_node.meta["val"].size() == torch.Size(act_expand_size))
|
|
):
|
|
return False
|
|
|
|
# Check for wgt
|
|
# wgt permute dims should be [1, 0]
|
|
wgt_permute_dims = match.kwargs["permute_axes"]
|
|
if wgt_permute_dims != [1, 0]:
|
|
return False
|
|
|
|
            # Check the following wgt size items:
            # wgt before expand should have dim 2
            # Expand size should have dim 3
            # Expand size[0] should be the same as act size[0]
            # Expand size[1] should be the same as wgt size[1]
            # Expand size[2] should be the same as wgt size[0]
qweight_node = match.kwargs["q_weight"]
|
|
wgt_expand_size = match.kwargs["wgt_expand_size"]
|
|
if not (
|
|
hasattr(qweight_node, "meta")
|
|
and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
|
|
and len(qweight_node.meta["val"].size()) == 2
|
|
and len(wgt_expand_size) == 3
|
|
and wgt_expand_size[0] == act_node.meta["val"].size()[0]
|
|
and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
|
|
and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
|
|
):
|
|
return False
|
|
|
|
return True
|
|
|
|
return _inner
|
|
|
|
|
|
def _register_qlinear_weight_prepack_pass(
|
|
pattern,
|
|
pass_number,
|
|
dtype=torch.float32,
|
|
input_dim_exceeds_two=False,
|
|
input_contiguous=True,
|
|
):
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_dequant_linear_pattern(
|
|
dtype, input_dim_exceeds_two, input_contiguous
|
|
),
|
|
pass_number=pass_number,
|
|
)
|
|
def qlinear_weight_prepack(match: Match, *args, **kwargs):
|
|
"""
|
|
Match the pattern:
|
|
int8 activation
|
|
|
|
|
dequant_per_tensor
|
|
|
|
|
mm/addmm <- t <- dequant_per_channel <- int8_weight
|
|
|
|
Insert weight prepack node and change the pattern to:
|
|
int8 activation
|
|
|
|
|
onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
|
|
"""
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
(
|
|
linear_node,
|
|
output_reshape_node,
|
|
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
|
|
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
|
weight_index = input_index + 1
|
|
|
|
(
|
|
dequant_node,
|
|
act_reshape_node,
|
|
activation_to_bf16_node,
|
|
act_expand_node,
|
|
) = _get_linear_dq_node(
|
|
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
|
)
|
|
|
|
if input_dim_exceeds_two and not input_contiguous:
|
|
wgt_expand_node = linear_node.args[weight_index]
|
|
assert wgt_expand_node.target is aten.expand.default
|
|
t_node = wgt_expand_node.args[0]
|
|
else:
|
|
t_node = linear_node.args[weight_index]
|
|
|
|
if dtype == torch.float32:
|
|
dequant_per_channel = t_node.args[0]
|
|
else:
|
|
weight_to_bf16_node = t_node.args[0]
|
|
dequant_per_channel = weight_to_bf16_node.args[0]
|
|
assert (
|
|
dequant_per_channel.target
|
|
is quantized_decomposed.dequantize_per_channel.default
|
|
)
|
|
|
|
# Activation QParams
|
|
qx, x_zp, x_scale = (
|
|
kwargs["x"],
|
|
kwargs["x_zp"],
|
|
kwargs["x_scale"],
|
|
)
|
|
|
|
# Weight QParams
|
|
qw, w_scale, w_zp = (
|
|
kwargs["q_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
|
|
# Params
|
|
bias = kwargs["b"] if "b" in kwargs else None
|
|
|
|
x_shape = qx.meta.get("tensor_meta").shape
|
|
if has_free_symbols(x_shape):
|
|
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
|
x_shape = None
|
|
graph = match.graph
|
|
with graph.inserting_before(linear_node):
|
|
# Insert weight prepack node and the qlinear node
|
|
packed_weight_inputs = (
|
|
qw,
|
|
x_shape,
|
|
)
|
|
packed_weight_op = torch.ops.onednn.qlinear_prepack
|
|
prepack_weight_node = graph.call_function(
|
|
packed_weight_op, args=packed_weight_inputs
|
|
)
|
|
|
|
new_args: tuple[Any, ...] = (
|
|
qx,
|
|
x_scale,
|
|
x_zp,
|
|
prepack_weight_node,
|
|
w_scale,
|
|
w_zp,
|
|
bias,
|
|
1.0, # output_scale
|
|
0, # output_zero_point
|
|
dtype, # output_dtype
|
|
"none", # post op name
|
|
[], # post op args
|
|
"", # post op algorithm
|
|
)
|
|
Node = torch.fx.node.Node
|
|
if isinstance(x_scale, Node) and isinstance(x_zp, Node):
|
|
new_linear_node = graph.call_function(
|
|
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
|
|
)
|
|
else:
|
|
new_linear_node = graph.call_function(
|
|
torch.ops.onednn.qlinear_pointwise.default, args=new_args
|
|
)
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
output_reshape_node.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(output_reshape_node.meta)
|
|
else:
|
|
if bias:
|
|
output_add_node_for_bias = match.output_node()
|
|
assert output_add_node_for_bias.target is aten.add.Tensor
|
|
output_add_node_for_bias.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(output_add_node_for_bias.meta)
|
|
else:
|
|
linear_node.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(linear_node.meta)
|
|
else:
|
|
linear_node.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(linear_node.meta)
|
|
|
|
# Erase the original linear node
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
graph.erase_node(output_reshape_node)
|
|
elif not input_contiguous and bias:
|
|
graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
|
|
graph.erase_node(linear_node)
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
graph.erase_node(act_reshape_node)
|
|
else:
|
|
graph.erase_node(act_expand_node)
|
|
graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
|
|
if dtype == torch.bfloat16:
|
|
graph.erase_node(activation_to_bf16_node)
|
|
# Erase the dequant pattern
|
|
graph.erase_node(dequant_node)
|
|
# Erase the dequant per channel pattern
|
|
graph.erase_node(t_node)
|
|
if dtype == torch.bfloat16:
|
|
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
|
|
graph.erase_node(dequant_per_channel)
|
|
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
|
|
match.nodes
|
|
)
|
|
|
|
|
|
def _generate_dequant_linear_node_pattern(
|
|
_dequant_per_channel_pattern,
|
|
dtype=torch.float32,
|
|
input_dim_exceeds_two=False,
|
|
is_tensor_overload=False,
|
|
):
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
|
|
dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
|
|
CallFunction(
|
|
aten.addmm.default,
|
|
KeywordArg("b"),
|
|
_may_generate_pattern_with_reshape(
|
|
_may_generate_pattern_with_dtype_convert(
|
|
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
|
KeywordArg("autocast_act_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
KeywordArg("act_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
),
|
|
t_pattern,
|
|
),
|
|
KeywordArg("output_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
)
|
|
dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
|
|
CallFunction(
|
|
aten.mm.default,
|
|
_may_generate_pattern_with_reshape(
|
|
_may_generate_pattern_with_dtype_convert(
|
|
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
|
KeywordArg("autocast_act_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
KeywordArg("act_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
),
|
|
t_pattern,
|
|
),
|
|
KeywordArg("output_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
)
|
|
return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
|
|
|
|
|
|
def _generate_dequant_bmm_node_pattern(
|
|
_dequant_per_channel_pattern,
|
|
dtype=torch.float32,
|
|
with_bias=False,
|
|
is_tensor_overload=False,
|
|
):
|
|
    # When the activation of linear has more than 2 dims and is not contiguous
|
|
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
|
|
|
|
assert dtype in [torch.float32, torch.bfloat16]
|
|
dequant_bmm_pattern = CallFunction(
|
|
aten.bmm.default,
|
|
CallFunction(
|
|
aten.expand.default,
|
|
_may_generate_pattern_with_dtype_convert(
|
|
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
|
KeywordArg("autocast_act_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
KeywordArg("act_expand_size"),
|
|
),
|
|
CallFunction(
|
|
aten.expand.default,
|
|
t_pattern,
|
|
KeywordArg("wgt_expand_size"),
|
|
),
|
|
)
|
|
|
|
def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
|
|
if _with_bias:
|
|
return CallFunction(
|
|
aten.add.Tensor,
|
|
_dequant_bmm_pattern,
|
|
KeywordArg("b"),
|
|
)
|
|
else:
|
|
return _dequant_bmm_pattern
|
|
|
|
return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
|
|
|
|
|
|
def _generate_qlinear_weight_prepack_patterns(
|
|
dtype=torch.float32,
|
|
input_dim_exceeds_two=False,
|
|
input_contiguous=True,
|
|
with_bias=False,
|
|
is_tensor_overload=False,
|
|
):
|
|
if input_dim_exceeds_two and not input_contiguous:
|
|
return _generate_dequant_bmm_node_pattern(
|
|
dequantize_per_channel_weight_pattern,
|
|
dtype,
|
|
with_bias,
|
|
is_tensor_overload,
|
|
)
|
|
else:
|
|
return _generate_dequant_linear_node_pattern(
|
|
dequantize_per_channel_weight_pattern,
|
|
dtype,
|
|
input_dim_exceeds_two,
|
|
is_tensor_overload,
|
|
)
|
|
|
|
|
|
def _generate_linear_dynamic_fp16_pattern(
|
|
_dequant_weight_pattern,
|
|
input_dim_exceeds_two=False,
|
|
input_contiguous=True,
|
|
relu_fused=False,
|
|
):
|
|
dtype = torch.float32
|
|
t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype)
|
|
|
|
if input_dim_exceeds_two and not input_contiguous:
|
|
# pattern is
|
|
# x -> expand -> bmm (-> add) (-> relu)
|
|
# w -> dequant -> permute -> expand /
|
|
pattern_no_bias = CallFunction(
|
|
aten.bmm.default,
|
|
CallFunction(
|
|
aten.expand.default,
|
|
KeywordArg("x"),
|
|
KeywordArg("act_expand_size"),
|
|
),
|
|
CallFunction(
|
|
aten.expand.default,
|
|
t_pattern,
|
|
KeywordArg("wgt_expand_size"),
|
|
),
|
|
)
|
|
pattern_with_bias = CallFunction(
|
|
aten.add.Tensor,
|
|
pattern_no_bias,
|
|
KeywordArg("b"),
|
|
)
|
|
if relu_fused:
|
|
pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias)
|
|
pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias)
|
|
return pattern_with_bias, pattern_no_bias
|
|
|
|
x_pattern_with_reshape = _may_generate_pattern_with_reshape(
|
|
KeywordArg("x"),
|
|
KeywordArg("act_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
)
|
|
dequant_linear_bias_pattern = generate_pattern_with_unary(
|
|
_may_generate_pattern_with_reshape(
|
|
CallFunction(
|
|
aten.addmm.default,
|
|
KeywordArg("b"),
|
|
x_pattern_with_reshape,
|
|
t_pattern,
|
|
),
|
|
KeywordArg("output_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
),
|
|
aten.relu.default if relu_fused else None,
|
|
)
|
|
dequant_linear_no_bias_pattern = generate_pattern_with_unary(
|
|
_may_generate_pattern_with_reshape(
|
|
CallFunction(
|
|
aten.mm.default,
|
|
x_pattern_with_reshape,
|
|
t_pattern,
|
|
),
|
|
KeywordArg("output_reshape_size"),
|
|
input_dim_exceeds_two,
|
|
),
|
|
aten.relu.default if relu_fused else None,
|
|
)
|
|
return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
|
|
|
|
|
|
def _register_dequant_promotion():
|
|
dequant_pattern_cases = itertools.product(
|
|
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
|
)
|
|
for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
|
|
# 4 dequantization patterns will be matched based on the dtype and input dimension size.
|
|
# Case 1: int8-mixed-fp32, input dim size is 2
|
|
# Case 2: int8-mixed-fp32, input dim size exceeds 2
|
|
# Case 3: int8-mixed-bf16, input dim size is 2
|
|
# Case 4: int8-mixed-bf16, input dim size exceeds 2
|
|
# quant
|
|
# + - - - - | - - - - +
|
|
# | dequant |
|
|
# | | |
|
|
# | OPT(to_bf16) |
|
|
# | | |
|
|
# | OPT(reshape) |
|
|
# | / \ |
|
|
# | node1 node2 |
|
|
# + - - | - - - | - - +
|
|
# OPT(reshape) OPT(reshape)
|
|
# + - - | - - - | - - +
|
|
# OPT(to_fp32) OPT(to_fp32)
|
|
# + - - | - - - | - - +
|
|
# quant quant
|
|
_register_dequant_promotion_pass(
|
|
_may_generate_pattern_with_reshape(
|
|
_may_generate_pattern_with_dtype_convert(
|
|
get_dequantize_per_tensor_activation_pattern(
|
|
is_tensor_overload=is_tensor_overload
|
|
),
|
|
KeywordArg("autocast_act_dtype"),
|
|
dtype == torch.bfloat16,
|
|
),
|
|
KeywordArg("act_reshape_size"),
|
|
with_reshape=input_dim_exceeds_two,
|
|
),
|
|
pass_number=0,
|
|
dtype=dtype,
|
|
) # pass_number=0 to run before weight prepack
|
|
|
|
|
|
def _register_qconv_weight_prepack():
|
|
for dtype in [torch.float32, torch.bfloat16]:
|
|
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
|
|
for weight_prepack_pattern in weight_prepack_patterns:
|
|
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
|
|
_register_qconv_weight_prepack_pass(
|
|
weight_prepack_pattern, pass_number=1, dtype=dtype
|
|
)
|
|
|
|
|
|
def _register_qlinear_weight_prepack():
|
|
# 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous.
|
|
# Then convert the pattern into a QLinear node with int8_fp32/bf16.
|
|
# Case 1: int8-mixed-fp32, input dim size is 2
|
|
# Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
|
|
# Case 3: int8-mixed-bf16, input dim size is 2
|
|
# Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
|
|
|
|
# + - - - - | - - - - - - | - - - - - +
|
|
# | dq_per_tensor dq_per_channel |
|
|
# | | | |
|
|
# | OPT(to_bf16) OPT(to_bf16) |
|
|
# | | | |
|
|
# | OPT(reshape) permute |
|
|
# | \ / |
|
|
# | addmm/mm |
|
|
# | | |
|
|
# | OPT(reshape) |
|
|
|
|
# Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
|
|
# Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
|
|
|
|
# + - - - - | - - - - - - | - - - - - +
|
|
# | dq_per_tensor dq_per_channel |
|
|
# | | | |
|
|
# | OPT(to_bf16) OPT(to_bf16) |
|
|
# | | | |
|
|
# | expand permute |
|
|
# | \ | |
|
|
# | expand |
|
|
# | / |
|
|
# | bmm |
|
|
# | | |
|
|
# | OPT(add) |
|
|
|
|
linear_weight_prepack_cases = itertools.product(
|
|
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
|
)
|
|
|
|
# Step 1: register patterns from mm and addmm
|
|
for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
|
|
weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
|
|
dtype,
|
|
input_dim_exceeds_two,
|
|
is_tensor_overload=is_tensor_overload,
|
|
)
|
|
for weight_prepack_pattern in weight_prepack_patterns:
|
|
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
|
|
_register_qlinear_weight_prepack_pass(
|
|
weight_prepack_pattern,
|
|
pass_number=1,
|
|
dtype=dtype,
|
|
input_dim_exceeds_two=input_dim_exceeds_two,
|
|
)
|
|
|
|
# Step 2: register patterns from bmm
|
|
# Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
|
|
# refer to:
|
|
# https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
|
|
# in this case, we can convert it back to qlinear
|
|
for dtype, with_bias, is_tensor_overload in itertools.product(
|
|
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
|
):
|
|
bmm_pattern = _generate_qlinear_weight_prepack_patterns(
|
|
dtype=dtype,
|
|
input_dim_exceeds_two=True,
|
|
input_contiguous=False,
|
|
with_bias=with_bias,
|
|
is_tensor_overload=is_tensor_overload,
|
|
)
|
|
_register_qlinear_weight_prepack_pass(
|
|
bmm_pattern,
|
|
pass_number=1
|
|
if with_bias
|
|
else 2, # if with_bias, there is an output add, so we should try to match it firstly
|
|
dtype=dtype,
|
|
input_dim_exceeds_two=True,
|
|
input_contiguous=False,
|
|
)
|
|
|
|
|
|
def _register_linear_dynamic_fp16_weight_prepack_pass(
|
|
pattern,
|
|
pass_number,
|
|
input_dim_exceeds_two=False,
|
|
input_contiguous=True,
|
|
relu_fused=False,
|
|
):
|
|
def _extra_check_fn(match: Match):
|
|
return match.kwargs["dtype_fp16"] == torch.float16
|
|
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_extra_check_fn,
|
|
pass_number=pass_number,
|
|
)
|
|
def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs):
|
|
"""
|
|
Match the pattern:
|
|
fp32 activation
|
|
|
|
|
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
|
|
|
|
|
(reshape) <- (relu)
|
|
|
|
OR
|
|
|
|
fp32 activation
|
|
|
|
|
expand
|
|
|
|
|
bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
|
|
|
|
|
(add) <- (relu)
|
|
|
|
Insert weight prepack node and change the pattern to:
|
|
fp32 activation
|
|
|
|
|
onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
|
|
(or onednn.linear_relu_dynamic_fp16)
|
|
"""
|
|
# find params
|
|
x = kwargs["x"]
|
|
w = kwargs["w"]
|
|
bias = kwargs["b"] if "b" in kwargs else None
|
|
|
|
# find linear node
|
|
nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
|
|
linear_nodes = []
|
|
for node in nodes_to_find:
|
|
linear_nodes.extend(filter_nodes(match.nodes, node))
|
|
assert len(linear_nodes) == 1
|
|
linear_node = linear_nodes[0]
|
|
assert isinstance(linear_node, torch.fx.node.Node)
|
|
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
|
weight_index = input_index + 1
|
|
|
|
# find relu node
|
|
relu_node = None
|
|
if relu_fused:
|
|
relu_node = match.output_node()
|
|
assert isinstance(relu_node, torch.fx.node.Node)
|
|
|
|
# find reshape node, expand node and add node
|
|
(
|
|
act_reshape_node,
|
|
output_reshape_node,
|
|
expand_x_node,
|
|
expand_w_node,
|
|
add_bias_node,
|
|
) = (None, None, None, None, None)
|
|
t_node = None
|
|
if input_dim_exceeds_two:
|
|
if input_contiguous:
|
|
act_reshape_node = linear_node.args[input_index]
|
|
t_node = linear_node.args[weight_index]
|
|
output_reshape_node = next(iter(linear_node.users))
|
|
assert output_reshape_node.target is aten.reshape.default
|
|
else:
|
|
expand_x_node = linear_node.args[input_index]
|
|
expand_w_node = linear_node.args[weight_index]
|
|
assert isinstance(expand_w_node, torch.fx.node.Node)
|
|
t_node = expand_w_node.args[0]
|
|
if bias:
|
|
add_bias_node = next(iter(linear_node.users))
|
|
assert add_bias_node.target is aten.add.Tensor
|
|
else:
|
|
t_node = linear_node.args[weight_index]
|
|
assert isinstance(t_node, torch.fx.node.Node)
|
|
|
|
w_to_fp32_node = t_node.args[0]
|
|
assert (
|
|
isinstance(w_to_fp32_node, torch.fx.node.Node)
|
|
and w_to_fp32_node.target
|
|
is quantized_decomposed.convert_element_type.no_fuse
|
|
)
|
|
w_to_fp16_node = w_to_fp32_node.args[0]
|
|
assert (
|
|
isinstance(w_to_fp16_node, torch.fx.node.Node)
|
|
and w_to_fp16_node.target
|
|
is quantized_decomposed.convert_element_type.no_fuse
|
|
)
|
|
|
|
x_shape = x.meta.get("tensor_meta").shape
|
|
if has_free_symbols(x_shape):
|
|
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
|
x_shape = None
|
|
graph = match.graph
|
|
with graph.inserting_before(linear_node):
|
|
# Insert weight prepack node and the qlinear node
|
|
packed_weight_inputs = (
|
|
w,
|
|
x_shape,
|
|
)
|
|
packed_weight_op = torch.ops.onednn.linear_prepack_fp16
|
|
prepack_weight_node = graph.call_function(
|
|
packed_weight_op, args=packed_weight_inputs
|
|
)
|
|
|
|
# create new linear node and insert on graph
|
|
new_args: tuple[Any, ...] = (
|
|
x,
|
|
prepack_weight_node,
|
|
bias,
|
|
)
|
|
linear_op = (
|
|
torch.ops.onednn.linear_relu_dynamic_fp16.default
|
|
if relu_fused
|
|
else torch.ops.onednn.linear_dynamic_fp16.default
|
|
)
|
|
new_linear_node = graph.call_function(linear_op, args=new_args)
|
|
out_node = match.output_node()
|
|
out_node.replace_all_uses_with(new_linear_node)
|
|
|
|
# Erase the original nodes in the reverse order
|
|
new_linear_node.meta.update(out_node.meta)
|
|
if relu_node is not None:
|
|
graph.erase_node(relu_node)
|
|
if output_reshape_node is not None:
|
|
graph.erase_node(output_reshape_node)
|
|
if add_bias_node is not None:
|
|
graph.erase_node(add_bias_node)
|
|
graph.erase_node(linear_node)
|
|
if act_reshape_node is not None:
|
|
assert isinstance(act_reshape_node, torch.fx.node.Node)
|
|
graph.erase_node(act_reshape_node)
|
|
if expand_x_node is not None:
|
|
assert isinstance(expand_x_node, torch.fx.node.Node)
|
|
graph.erase_node(expand_x_node)
|
|
if expand_w_node is not None:
|
|
assert isinstance(expand_w_node, torch.fx.node.Node)
|
|
graph.erase_node(expand_w_node)
|
|
graph.erase_node(t_node)
|
|
graph.erase_node(w_to_fp32_node)
|
|
graph.erase_node(w_to_fp16_node)
|
|
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
|
|
match.nodes
|
|
)
|
|
|
|
|
|
def _register_linear_dynamic_fp16_weight_prepack():
|
|
to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
|
|
weight_pattern = CallFunction(
|
|
to_dtype_op,
|
|
CallFunction(
|
|
to_dtype_op,
|
|
KeywordArg("w"),
|
|
KeywordArg("dtype_fp16"),
|
|
),
|
|
KeywordArg("dtype_fp32"),
|
|
)
|
|
cases = itertools.product(
|
|
[False, True], # input_dim_exceeds_two
|
|
[True, False], # input_contiguous
|
|
[False, True], # relu fused
|
|
)
|
|
for input_dim_exceeds_two, input_contiguous, relu_fused in cases:
|
|
patterns = _generate_linear_dynamic_fp16_pattern(
|
|
weight_pattern,
|
|
input_dim_exceeds_two,
|
|
input_contiguous,
|
|
relu_fused,
|
|
)
|
|
for pattern in patterns:
|
|
_register_linear_dynamic_fp16_weight_prepack_pass(
|
|
pattern,
|
|
pass_number=0 if relu_fused else 1,
|
|
input_dim_exceeds_two=input_dim_exceeds_two,
|
|
input_contiguous=input_contiguous,
|
|
relu_fused=relu_fused,
|
|
)
|
|
|
|
|
|
def _register_smooth_quant_int_mm_pattern():
|
|
"""
|
|
The pattern is:
|
|
(no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape
|
|
or
|
|
(with bias) pattern_no_bias -> add (-> reshape -> reshape)
|
|
"""
|
|
|
|
    # When torch.compile'ing with dynamic=True, the expand node and the two trailing reshape nodes exist
|
|
# When torch.compile'ing with dynamic=False, they don't exist
|
|
def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True):
|
|
return CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
aten.mul.Tensor,
|
|
CallFunction(
|
|
prims.convert_element_type.default,
|
|
CallFunction(
|
|
aten._int_mm.default,
|
|
CallFunction(
|
|
aten.reshape.default,
|
|
KeywordArg("a"),
|
|
KeywordArg("in_shape"),
|
|
)
|
|
if reshape_a
|
|
else KeywordArg("a"),
|
|
KeywordArg("b"),
|
|
),
|
|
KeywordArg("dtype"),
|
|
),
|
|
(
|
|
CallFunction(
|
|
aten.expand.default,
|
|
KeywordArg("x_scale"),
|
|
Arg(),
|
|
)
|
|
if expand_a_scale
|
|
else KeywordArg("x_scale")
|
|
),
|
|
),
|
|
KeywordArg("w_scale"),
|
|
)
|
|
|
|
def _with_outer_reshape(pattern):
|
|
return CallFunction(
|
|
aten.reshape.default, pattern, KeywordArg("out_shape_no_bias")
|
|
)
|
|
|
|
# for torch.compile(dynamic=False)
|
|
pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False))
|
|
pattern_with_bias_1 = CallFunction(
|
|
aten.add.Tensor,
|
|
pattern_no_bias_1,
|
|
KeywordArg("bias"),
|
|
)
|
|
# for torch.compile(dynamic=True)
|
|
pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True))
|
|
pattern_with_bias_2 = CallFunction(
|
|
aten.reshape.default,
|
|
CallFunction(
|
|
aten.reshape.default,
|
|
CallFunction(
|
|
aten.add.Tensor,
|
|
pattern_no_bias_2,
|
|
KeywordArg("bias"),
|
|
),
|
|
Arg(),
|
|
),
|
|
KeywordArg("out_shape_with_bias"),
|
|
)
|
|
|
|
# The following patterns are for torchao int8_dynamic_activation_int8_weight linear,
|
|
# when both activation and weights are symmetrically quantized.
|
|
# In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used.
|
|
# Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias.
|
|
# Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op.
|
|
pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias(
|
|
expand_a_scale=False, reshape_a=False
|
|
)
|
|
pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias(
|
|
expand_a_scale=True, reshape_a=False
|
|
)
|
|
|
|
def _validate_pattern(match: Match):
|
|
if len(match.nodes) not in [4, 5, 6, 7, 10]:
|
|
return False
|
|
# Make sure weight is a constant
|
|
aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0]
|
|
if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node):
|
|
return False
|
|
if aten_int_mm_node.args[1].op != "get_attr":
|
|
return False
|
|
|
|
if len(match.nodes) == 10:
|
|
            # Check that the two trailing reshape nodes can be fused
|
|
if match.nodes[9].args[1] != match.nodes[6].args[1]:
|
|
return False
|
|
if len(match.nodes) == 10 or (
|
|
len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor
|
|
):
|
|
bias_idx = 7 if len(match.nodes) == 10 else 6
|
|
# Check bias shape
|
|
bias_node = match.nodes[bias_idx].args[1]
|
|
if not isinstance(bias_node, torch.fx.node.Node):
|
|
return False
|
|
if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr]
|
|
return False
|
|
return True
|
|
|
|
pattern_to_pass_number = {
|
|
pattern_no_bias_2: 0,
|
|
pattern_with_bias_2: 0,
|
|
pattern_no_bias_1: 1,
|
|
pattern_with_bias_1: 1,
|
|
pattern1_with_no_outer_or_act_reshape: 2,
|
|
pattern2_with_no_outer_or_act_reshape: 2,
|
|
}
|
|
for pattern, pass_number in pattern_to_pass_number.items():
|
|
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_validate_pattern,
|
|
pass_number=pass_number,
|
|
)
|
|
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
|
|
bias = kwargs.get("bias", None)
|
|
x = kwargs["a"]
|
|
weight = kwargs["b"]
|
|
dtype = kwargs["dtype"]
|
|
x_scale = kwargs["x_scale"]
|
|
w_scale = kwargs["w_scale"]
|
|
x_shape = x.meta.get("tensor_meta").shape
|
|
if has_free_symbols(x_shape):
|
|
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
|
x_shape = None
|
|
|
|
out_node = match.output_node()
|
|
with match.graph.inserting_before(out_node):
|
|
transpose_node = match.graph.call_function(
|
|
aten.permute.default, args=(weight, [1, 0])
|
|
)
|
|
contig_node = match.graph.call_function(
|
|
aten.contiguous.default, args=(transpose_node,)
|
|
)
|
|
packed_weight_inputs = (
|
|
contig_node,
|
|
x_shape,
|
|
)
|
|
packed_weight_op = torch.ops.onednn.qlinear_prepack
|
|
prepack_weight_node = match.graph.call_function(
|
|
packed_weight_op, args=packed_weight_inputs
|
|
)
|
|
|
|
dummy_zp = None
|
|
w_scale = match.graph.call_function(
|
|
prims.convert_element_type.default, args=(w_scale, torch.float32)
|
|
)
|
|
|
|
x_scale_shape = x_scale.meta.get("tensor_meta").shape
|
|
x_scale_is_scalar = False
|
|
if not has_free_symbols(x_scale_shape):
|
|
prod = 1
|
|
for d in x_scale_shape:
|
|
prod *= d
|
|
x_scale_is_scalar = prod == 1
|
|
|
|
new_args: tuple[Any, ...]
|
|
if x_scale_is_scalar:
|
|
# in this case, we can call onednn.qlinear directly
|
|
new_args = (
|
|
x,
|
|
x_scale,
|
|
dummy_zp, # x_zp
|
|
prepack_weight_node,
|
|
w_scale,
|
|
dummy_zp, # w_zp
|
|
bias,
|
|
1.0, # output_scale
|
|
0, # output_zero_point
|
|
dtype, # output_dtype
|
|
"none", # post op name
|
|
[], # post op args
|
|
"", # post op algorithm
|
|
)
|
|
new_linear_node = match.graph.call_function(
|
|
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
|
|
)
|
|
out_node.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(out_node.meta)
|
|
else:
|
|
# onednn.qlinear does not support per-channel quantization of x
|
|
# so in this case, we have to apply x scale and add bias ourselves after qlinear
|
|
in_shape = kwargs.get("in_shape", None)
|
|
if in_shape is None:
|
|
x_reshaped = x
|
|
else:
|
|
x_reshaped = match.graph.call_function(
|
|
aten.reshape.default, args=(x, in_shape)
|
|
)
|
|
new_args = (
|
|
x_reshaped,
|
|
1.0, # x_scale
|
|
0, # x_zp
|
|
prepack_weight_node,
|
|
w_scale,
|
|
dummy_zp, # w_zp
|
|
None, # bias
|
|
1.0, # output_scale
|
|
0, # output_zero_point
|
|
dtype, # output_dtype
|
|
"none", # post op name
|
|
[], # post op args
|
|
"", # post op algorithm
|
|
)
|
|
new_linear_node = match.graph.call_function(
|
|
torch.ops.onednn.qlinear_pointwise, args=new_args
|
|
)
|
|
# apply x scale
|
|
new_out_node = match.graph.call_function(
|
|
aten.mul.Tensor, args=(new_linear_node, x_scale)
|
|
)
|
|
|
|
# Add bias and reshape
|
|
has_outer_reshape = (
|
|
kwargs.get("out_shape_with_bias", None) is not None
|
|
or kwargs.get("out_shape_no_bias", None) is not None
|
|
)
|
|
|
|
if has_outer_reshape:
|
|
out_shape = kwargs.get(
|
|
"out_shape_with_bias", kwargs["out_shape_no_bias"]
|
|
)
|
|
if bias is not None:
|
|
new_out_node = match.graph.call_function(
|
|
aten.add.Tensor, args=(new_out_node, bias)
|
|
)
|
|
if has_outer_reshape:
|
|
new_out_node = match.graph.call_function(
|
|
aten.reshape.default,
|
|
args=(new_out_node, out_shape), # type: ignore[possibly-undefined]
|
|
)
|
|
else:
|
|
if has_outer_reshape:
|
|
new_out_node = match.graph.call_function(
|
|
aten.reshape.default,
|
|
args=(new_out_node, out_shape), # type: ignore[possibly-undefined]
|
|
)
|
|
out_node.replace_all_uses_with(new_out_node)
|
|
new_out_node.meta.update(out_node.meta)
|
|
for node in reversed(match.nodes):
|
|
match.graph.erase_node(node)
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
|
|
match.nodes
|
|
)
|
|
|
|
|
|
class PostOpAttr:
    def __init__(
        self,
        binary_op_name: str = "none",
        alpha=None,
        unary_op_name: str = "none",
        scalars_attr=None,
        algorithm_attr=None,
    ) -> None:
        self.binary_op_name = binary_op_name
        self.alpha = alpha if alpha else 1.0
        self.unary_op_name = unary_op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
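# Usage sketch: the fusion registrations below construct these descriptors,
# e.g. a plain ReLU unary post op and a sum + ReLU binary-unary post op:
#
#   PostOpAttr("none", None, "relu", [], "")
#   PostOpAttr("sum", 1.0, "relu", [], "")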
|
|
|
|
|
|
def _register_qconv_post_op_fusion_pass(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
post_op_attr,
|
|
):
|
|
has_binary_post_op = post_op_attr.binary_op_name != "none"
|
|
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op),
|
|
pass_number=pass_number,
|
|
)
|
|
def qconv(match: Match, *args, **kwargs):
|
|
# Activation QParams
|
|
x, x_scale, x_zp = (
|
|
kwargs["x"],
|
|
kwargs["x_scale"],
|
|
kwargs["x_zp"],
|
|
)
|
|
# Weight QParams
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
# Conv Params
|
|
b, stride, padding, dilation, groups = (
|
|
kwargs["b"],
|
|
kwargs["stride"],
|
|
kwargs["padding"],
|
|
kwargs["dilation"],
|
|
kwargs["groups"],
|
|
)
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
|
|
# Output QParams
|
|
o_inv_scale = (
|
|
kwargs["o_inv_scale"]
|
|
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
|
else 1.0
|
|
)
|
|
o_zero_point = (
|
|
kwargs["o_zp"]
|
|
if (output_dtype == torch.uint8 or output_dtype == torch.int8)
|
|
else 0
|
|
)
|
|
assert (
|
|
kwargs["postop_name"] == "none"
|
|
) # Expected no post op fused in weight prepack phase
|
|
if post_op_attr.unary_op_name == "hardtanh":
|
|
min_value = kwargs.get("min_value")
|
|
max_value = kwargs.get("max_value")
|
|
post_op_attr.scalars_attr = [min_value, max_value]
|
|
|
|
out_node = match.output_node()
|
|
with match.graph.inserting_before(out_node):
|
|
if not has_binary_post_op:
|
|
computation_args: tuple[Any, ...] = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
b,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
post_op_attr.unary_op_name,
|
|
post_op_attr.scalars_attr,
|
|
post_op_attr.algorithm_attr,
|
|
)
|
|
else:
|
|
accum = (
|
|
kwargs["accum"]
|
|
if output_dtype in [torch.uint8, torch.int8]
|
|
else kwargs["accum_after_dequant"]
|
|
)
|
|
accum_scale = (
|
|
kwargs["accum_scale"]
|
|
if output_dtype in [torch.uint8, torch.int8]
|
|
else 1.0
|
|
)
|
|
accum_zp = (
|
|
kwargs["accum_zp"]
|
|
if output_dtype in [torch.uint8, torch.int8]
|
|
else 0
|
|
)
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
accum,
|
|
b,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
groups,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
accum_scale,
|
|
accum_zp,
|
|
post_op_attr.binary_op_name,
|
|
post_op_attr.alpha,
|
|
post_op_attr.unary_op_name,
|
|
post_op_attr.scalars_attr,
|
|
post_op_attr.algorithm_attr,
|
|
)
|
|
new_conv_node = match.graph.call_function(
|
|
computation_op, args=computation_args
|
|
)
|
|
out_node.replace_all_uses_with(new_conv_node)
|
|
new_conv_node.meta.update(out_node.meta)
|
|
for node in reversed(match.nodes):
|
|
match.graph.erase_node(node)
|
|
count_key = (
|
|
"qconv2d_binary_matcher_count"
|
|
if has_binary_post_op
|
|
else "qconv_unary_matcher_count"
|
|
)
|
|
nodes_key = (
|
|
"qconv2d_binary_matcher_nodes"
|
|
if has_binary_post_op
|
|
else "qconv_unary_matcher_nodes"
|
|
)
|
|
counters["inductor"][count_key] += 1
|
|
counters["inductor"][nodes_key] += len(match.nodes)
|
|
|
|
return qconv
|
|
|
|
|
|
def _register_qconv_unary_fusion():
|
|
from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
|
|
|
|
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
|
|
# Priority 1 to match: QConv2d Unary pattern with int8 output
|
|
# If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
|
|
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
|
|
is_bf16 = original_pattern_output_dtype == torch.bfloat16
|
|
conv_unary_replace_patterns = {
|
|
PostOpAttr(
|
|
"none", None, "none", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
get_qconv_pt2e_pattern(1),
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "relu", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
generate_pattern_with_unary(
|
|
get_qconv_pt2e_pattern(1), aten.relu.default
|
|
),
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "hardtanh", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
_unary_fusion_pattern(
|
|
_hardtanh_fusion,
|
|
get_qconv_pt2e_pattern(1),
|
|
1,
|
|
is_bf16,
|
|
),
|
|
with_dtype_convert=is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "hardswish", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
_unary_fusion_pattern(
|
|
_hardswish_fusion,
|
|
get_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
with_dtype_convert=is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "swish", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
_unary_fusion_pattern(
|
|
_silu_fusion,
|
|
get_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
with_dtype_convert=is_bf16,
|
|
),
|
|
}
|
|
|
|
for unary_attr, patterns in conv_unary_replace_patterns.items():
|
|
# Register qconv2d pattern for ExternKernel Lowering
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
3, # pass_number
|
|
torch.ops.onednn.qconv_pointwise.default, # computation_op
|
|
unary_attr, # unary_attr
|
|
)
|
|
|
|
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
|
|
conv_unary_replace_float_out_patterns = {
|
|
PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
|
|
get_qconv_pt2e_pattern(1), aten.relu.default
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "hardtanh", [], ""
|
|
): _may_generate_pattern_with_dtype_convert(
|
|
_unary_fusion_pattern(
|
|
_hardtanh_fusion,
|
|
get_qconv_pt2e_pattern(1),
|
|
1,
|
|
is_bf16,
|
|
),
|
|
Arg(),
|
|
is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "hardswish", [], ""
|
|
): _may_generate_pattern_with_dtype_convert(
|
|
_unary_fusion_pattern(
|
|
_hardswish_fusion,
|
|
get_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
Arg(),
|
|
is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "swish", [], ""
|
|
): _may_generate_pattern_with_dtype_convert(
|
|
_unary_fusion_pattern(
|
|
_silu_fusion,
|
|
get_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
Arg(),
|
|
is_bf16,
|
|
),
|
|
}
|
|
|
|
for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
|
|
# Register qconv2d pattern for ExternKernel Lowering
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
4, # pass_number
|
|
torch.ops.onednn.qconv_pointwise.default, # computation_op
|
|
unary_attr, # unary_attr
|
|
)
|
|
|
|
|
|
def _register_qconv_binary_fusion():
|
|
for int8_mixed_bf16_with_inplace_add in [False, True]:
|
|
# Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
|
|
swap_binary_inputs_list = [False, True]
|
|
binary_replace_patterns = {}
|
|
for swap_inputs in swap_binary_inputs_list:
|
|
binary_replace_patterns.update(
|
|
{
|
|
PostOpAttr(
|
|
"sum", 1.0, "none", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qconv_pt2e_pattern(1),
|
|
dequantize_accum_pattern,
|
|
int8_mixed_bf16_with_inplace_add,
|
|
swap_inputs=swap_inputs,
|
|
),
|
|
),
|
|
PostOpAttr(
|
|
"sum", 1.0, "relu", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
generate_pattern_with_unary(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qconv_pt2e_pattern(1),
|
|
dequantize_accum_pattern,
|
|
int8_mixed_bf16_with_inplace_add,
|
|
swap_inputs=swap_inputs,
|
|
),
|
|
aten.relu.default,
|
|
),
|
|
),
|
|
}
|
|
)
|
|
|
|
for binary_unary_attr, patterns in binary_replace_patterns.items():
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
3, # pass_number
|
|
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
|
binary_unary_attr, # binary_unary_attr
|
|
)
|
|
|
|
# Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qconv_pt2e_pattern(1),
|
|
KeywordArg("accum_after_dequant"),
|
|
int8_mixed_bf16_with_inplace_add,
|
|
swap_inputs=swap_inputs,
|
|
),
|
|
aten.relu.default,
|
|
)
|
|
}
|
|
)
|
|
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
if int8_mixed_bf16_with_inplace_add:
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
3, # pass_number
|
|
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
|
binary_unary_attr, # binary_unary_attr
|
|
)
|
|
else:
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
4, # pass_number
|
|
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
|
binary_unary_attr, # binary_unary_attr
|
|
)
|
|
|
|
# Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr(
|
|
"sum", 1.0, "none", [], ""
|
|
): generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qconv_pt2e_pattern(1),
|
|
KeywordArg("accum_after_dequant"),
|
|
int8_mixed_bf16_with_inplace_add,
|
|
swap_inputs=swap_inputs,
|
|
),
|
|
}
|
|
)
|
|
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
_register_qconv_post_op_fusion_pass(
|
|
patterns,
|
|
4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number
|
|
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
|
binary_unary_attr, # binary_unary_attr
|
|
)
|
|
|
|
|
|
def _register_qlinear_post_op_fusion_pass(
|
|
pattern,
|
|
pass_number,
|
|
computation_op,
|
|
post_op_attr,
|
|
):
|
|
has_binary_post_op = post_op_attr.binary_op_name != "none"
|
|
|
|
@register_freezing_graph_pattern(
|
|
pattern,
|
|
extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op),
|
|
pass_number=pass_number,
|
|
)
|
|
def qlinear_post_op_fusion(match: Match, *args, **kwargs):
|
|
"""
|
|
Match the pattern:
|
|
qlinear - post op
|
|
"""
|
|
output_dtype = _get_pattern_output_dtype(match)
|
|
# Activation QParams
|
|
x, x_scale, x_zp = (
|
|
kwargs["x"],
|
|
kwargs["x_scale"],
|
|
kwargs["x_zp"],
|
|
)
|
|
# Weight QParams
|
|
packed_weight, w_scale, w_zp = (
|
|
kwargs["packed_weight"],
|
|
kwargs["w_scale"],
|
|
kwargs["w_zp"],
|
|
)
|
|
|
|
# bias
|
|
b = kwargs["b"] if "b" in kwargs else None
|
|
|
|
# Output QParams
|
|
o_inv_scale = (
|
|
kwargs["o_inv_scale"]
|
|
if (output_dtype in [torch.uint8, torch.int8])
|
|
else 1.0
|
|
)
|
|
o_zero_point = (
|
|
kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0
|
|
)
|
|
assert (
|
|
kwargs["postop_name"] == "none"
|
|
) # Expected no post op fused in weight prepack phase
|
|
|
|
out_node = match.output_node()
|
|
with match.graph.inserting_before(out_node):
|
|
if not has_binary_post_op:
|
|
computation_args: tuple[Any, ...] = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
b,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
post_op_attr.unary_op_name,
|
|
post_op_attr.scalars_attr,
|
|
post_op_attr.algorithm_attr,
|
|
)
|
|
else:
|
|
other = kwargs["other"] if "other" in kwargs else kwargs["accum"]
|
|
x2_scale = 1.0
|
|
x2_zp = 0
|
|
computation_args = (
|
|
x,
|
|
x_scale,
|
|
x_zp,
|
|
packed_weight,
|
|
w_scale,
|
|
w_zp,
|
|
other,
|
|
b,
|
|
o_inv_scale,
|
|
o_zero_point,
|
|
output_dtype,
|
|
x2_scale,
|
|
x2_zp,
|
|
post_op_attr.binary_op_name,
|
|
post_op_attr.alpha,
|
|
post_op_attr.unary_op_name,
|
|
post_op_attr.scalars_attr,
|
|
post_op_attr.algorithm_attr,
|
|
)
|
|
new_linear_node = match.graph.call_function(
|
|
computation_op, args=computation_args
|
|
)
|
|
out_node.replace_all_uses_with(new_linear_node)
|
|
new_linear_node.meta.update(out_node.meta)
|
|
for node in reversed(match.nodes):
|
|
match.graph.erase_node(node)
|
|
count_key = (
|
|
"qlinear_binary_matcher_count"
|
|
if has_binary_post_op
|
|
else "qlinear_unary_matcher_count"
|
|
)
|
|
nodes_key = (
|
|
"qlinear_binary_matcher_nodes"
|
|
if has_binary_post_op
|
|
else "qlinear_unary_matcher_nodes"
|
|
)
|
|
counters["inductor"][count_key] += 1
|
|
counters["inductor"][nodes_key] += len(match.nodes)
|
|
|
|
|
|
def _register_qlinear_unary_fusion():
|
|
from .mkldnn_fusion import (
|
|
_gelu_fusion_1 as _gelu_fusion_erf,
|
|
_gelu_fusion_2 as _gelu_fusion_tanh,
|
|
)
|
|
|
|
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
|
|
is_bf16 = original_pattern_output_dtype == torch.bfloat16
|
|
for x_scale_zp_are_tensors in (False, True):
|
|
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
|
|
computation_op = (
|
|
torch.ops.onednn.qlinear_pointwise.tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.default
|
|
)
|
|
# Priority 1 to match: QLinear Unary pattern with int8 output
|
|
linear_unary_replace_patterns = {
|
|
PostOpAttr(
|
|
"none", None, "none", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
qlinear_pattern,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "relu", [], ""
|
|
): generate_pattern_with_output_quant(
|
|
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "gelu", [], "none"
|
|
): generate_pattern_with_output_quant(
|
|
_unary_fusion_pattern(
|
|
_gelu_fusion_erf,
|
|
get_qlinear_pt2e_pattern(
|
|
x_scale_zp_are_tensors, 1 if is_bf16 else 2
|
|
),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
with_dtype_convert=is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "gelu", [], "tanh"
|
|
): generate_pattern_with_output_quant(
|
|
_unary_fusion_pattern(
|
|
_gelu_fusion_tanh,
|
|
get_qlinear_pt2e_pattern(
|
|
x_scale_zp_are_tensors, 1 if is_bf16 else 4
|
|
),
|
|
4,
|
|
is_bf16,
|
|
),
|
|
with_dtype_convert=is_bf16,
|
|
),
|
|
}
|
|
|
|
for unary_attr, patterns in linear_unary_replace_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
3, # pass_number
|
|
computation_op,
|
|
unary_attr, # unary_attr
|
|
)
|
|
|
|
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
|
|
linear_unary_replace_float_out_patterns = {
|
|
PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
|
|
qlinear_pattern, aten.relu.default
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "gelu", [], "none"
|
|
): _may_generate_pattern_with_dtype_convert(
|
|
_unary_fusion_pattern(
|
|
_gelu_fusion_erf,
|
|
get_qlinear_pt2e_pattern(
|
|
x_scale_zp_are_tensors, 1 if is_bf16 else 2
|
|
),
|
|
2,
|
|
is_bf16,
|
|
),
|
|
Arg(),
|
|
is_bf16,
|
|
),
|
|
PostOpAttr(
|
|
"none", None, "gelu", [], "tanh"
|
|
): _may_generate_pattern_with_dtype_convert(
|
|
_unary_fusion_pattern(
|
|
_gelu_fusion_tanh,
|
|
get_qlinear_pt2e_pattern(
|
|
x_scale_zp_are_tensors, 1 if is_bf16 else 4
|
|
),
|
|
4,
|
|
is_bf16,
|
|
),
|
|
Arg(),
|
|
is_bf16,
|
|
),
|
|
}
|
|
|
|
for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
4, # pass_number
|
|
computation_op,
|
|
unary_attr, # unary_attr
|
|
)
|
|
|
|
|
|
def _register_qlinear_binary_fusion():
|
|
r"""
|
|
Supported linear-binary(-unary) patterns
|
|
|
|
linear(X) extra input
|
|
\ /
|
|
Add
|
|
|
|
|
Optional(relu)
|
|
|
|
|
Y
|
|
|
|
1. int8-mixed-fp32
|
|
+---+---------------+-----------+------------------------------+---------+
|
|
| # | Add type | Quant out | Pattern | Post op |
|
|
+---+---------------+-----------+------------------------------+---------+
|
|
| 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add |
|
|
+---+---------------+-----------+------------------------------+---------+
|
|
| 2 | In-/out-place | No | linear + fp32 -> (relu) | sum |
|
|
+---+---------------+-----------+------------------------------+---------+
|
|
|
|
2. int8-mixed-bf16
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| # | X2 dtype | Add type | Quant out | Pattern | Post op |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add |
|
|
| | | In-place right| | | |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum |
|
|
| | | In-place right| | | |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
| 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add |
|
|
+---+----------+---------------+-----------+-----------------------------------------+---------+
|
|
|
|
Note
|
|
(1) The positions of linear and the extra input can be swapped.
|
|
(2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the
|
|
extra input, we don't match that pattern because we cannot match all these patterns in 3 passes.
|
|
"""
|
|
for x_scale_zp_are_tensors in (False, True):
|
|
qlinear_binary_op = (
|
|
torch.ops.onednn.qlinear_pointwise.binary_tensor
|
|
if x_scale_zp_are_tensors
|
|
else torch.ops.onednn.qlinear_pointwise.binary
|
|
)
|
|
unary_postop_list = ["none", "relu"]
|
|
unary_postop_dict = {
|
|
"none": None,
|
|
"relu": aten.relu.default,
|
|
}
|
|
convert_dtype_after_binary_list = [False, True]
|
|
|
|
# Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
|
|
# Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
|
|
# totally 3 patterns (2 are identical)
|
|
swap_binary_inputs_list = [False, True]
|
|
int8_mixed_bf16_list = [False, True]
|
|
combinations = itertools.product(
|
|
unary_postop_list,
|
|
int8_mixed_bf16_list,
|
|
swap_binary_inputs_list,
|
|
convert_dtype_after_binary_list,
|
|
)
|
|
qlinear_binary_replace_patterns = {}
|
|
for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
|
|
if not int8_mixed_bf16 and cvt_dtype_binary:
|
|
# No convert node after binary node if dtypes are all fp32
|
|
continue
|
|
qlinear_binary_replace_patterns.update(
|
|
{
|
|
PostOpAttr(
|
|
"add", 1.0, unary_op, [], ""
|
|
): generate_pattern_with_output_quant(
|
|
generate_pattern_with_unary(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
|
|
KeywordArg("other"),
|
|
# If fp32 extra input is inplace added to bf16 linear output,
|
|
# a to_bf16 node is inserted after binary
|
|
dtype_convert=cvt_dtype_binary,
|
|
swap_inputs=swap_inputs,
|
|
),
|
|
unary_postop_dict[unary_op],
|
|
),
|
|
)
|
|
}
|
|
)
|
|
for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
3, # pass_number
|
|
qlinear_binary_op, # computation_op
|
|
binary_unary_attr,
|
|
)
|
|
|
|
# Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
|
|
# Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
|
|
# totally 2 patterns (2 are identical)
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_binary_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
|
|
KeywordArg("accum"),
|
|
dtype_convert=False,
|
|
swap_inputs=swap_binary_inputs,
|
|
),
|
|
aten.relu.default,
|
|
),
|
|
}
|
|
)
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
4, # pass_number
|
|
qlinear_binary_op, # computation_op
|
|
binary_unary_attr,
|
|
)
|
|
# Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
|
|
# Covers case (6) of int8-mixed-bf16
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_binary_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary(
|
|
generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
|
|
KeywordArg("other"),
|
|
dtype_convert=True,
|
|
swap_inputs=swap_binary_inputs,
|
|
),
|
|
aten.relu.default,
|
|
),
|
|
}
|
|
)
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
4, # pass_number
|
|
qlinear_binary_op, # computation_op
|
|
binary_unary_attr,
|
|
)
|
|
|
|
# Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
|
|
# Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
|
|
# totally 2 patterns (2 are identical)
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_binary_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr(
|
|
"sum", 1.0, "none", [], ""
|
|
): generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
|
|
KeywordArg("accum"),
|
|
dtype_convert=False,
|
|
swap_inputs=swap_binary_inputs,
|
|
),
|
|
}
|
|
)
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
5, # pass_number
|
|
qlinear_binary_op, # computation_op
|
|
binary_unary_attr,
|
|
)
|
|
# Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
|
|
# Covers (6) of int8-mixed-bf16
|
|
binary_replace_float_out_patterns = {}
|
|
for swap_binary_inputs in swap_binary_inputs_list:
|
|
binary_replace_float_out_patterns.update(
|
|
{
|
|
PostOpAttr(
|
|
"add", 1.0, "none", [], ""
|
|
): generate_pattern_with_binary(
|
|
aten.add.Tensor,
|
|
get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
|
|
KeywordArg("other"),
|
|
dtype_convert=True,
|
|
swap_inputs=swap_binary_inputs,
|
|
),
|
|
}
|
|
)
|
|
for (
|
|
binary_unary_attr,
|
|
patterns,
|
|
) in binary_replace_float_out_patterns.items():
|
|
_register_qlinear_post_op_fusion_pass(
|
|
patterns,
|
|
5, # pass_number
|
|
qlinear_binary_op, # computation_op
|
|
binary_unary_attr,
|
|
)
|
|
|
|
|
|
@functools.cache
def _register_quantization_weight_pack_pass():
    # Step 1: Dequant promotion for int8-mixed-fp32/bf16
    _register_dequant_promotion()

    # Step 2: QConv weight prepack
    _register_qconv_weight_prepack()

    # Step 3: QLinear weight prepack
    _register_qlinear_weight_prepack()
    _register_linear_dynamic_fp16_weight_prepack()

    # Step 4: weight prepack for SmoothQuant from Torchao
    _register_smooth_quant_int_mm_pattern()

    # Step 5: QLinear post op Fusion
    if not torch.ops.mkldnn._is_mkldnn_acl_supported():
        # skip fusion on ARM
        _register_qconv_unary_fusion()
        _register_qconv_binary_fusion()
        _register_qlinear_unary_fusion()
        _register_qlinear_binary_fusion()
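# Hedged end-to-end sketch (commented out; this is the standard PT2E x86
# Inductor flow rather than anything defined in this file) showing when the
# registrations above typically take effect:
#
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#   from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
#       X86InductorQuantizer,
#       get_default_x86_inductor_quantization_config,
#   )
#
#   exported = torch.export.export(model.eval(), (example_input,)).module()
#   quantizer = X86InductorQuantizer().set_global(
#       get_default_x86_inductor_quantization_config()
#   )
#   prepared = prepare_pt2e(exported, quantizer)
#   prepared(example_input)            # calibration
#   converted = convert_pt2e(prepared)
#   with torch.no_grad():
#       # Inductor freezing must be enabled so these freezing-graph
#       # patterns actually run on the converted model.
#       compiled = torch.compile(converted)
#       compiled(example_input)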
|
|
|
|
|
|
def _is_valid_concat_linear_woq_int4_fusion(computation_nodes):
    computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default
    act = computation_nodes[0].args[0]
    wgt = computation_nodes[0].args[1]
    in_feature_size = wgt.meta.get("val").size(1)  # type: ignore[union-attr]
    group_size = computation_nodes[0].args[2]
    return len(computation_nodes) >= 2 and all(
        (
            node.target == computation_op
            and node.args[0] == act  # share same activation
            and (
                node.args[1].meta.get("val").size(1) == in_feature_size
            )  # same in feature size
            and (node.args[1] != wgt or gemm_idx == 0)
            and node.args[1].op == "get_attr"  # wgt are all constants
            and node.args[2] == group_size  # same group size
        )
        for gemm_idx, node in enumerate(computation_nodes)
    )
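# Hedged example of a group of nodes that passes the check above: several
# _weight_int4pack_mm_for_cpu GEMMs (e.g. Q/K/V projections) sharing the same
# activation, group size and input feature size (group size 128 is illustrative):
#
#   q = torch.ops.aten._weight_int4pack_mm_for_cpu(x, w_q, 128, scale_zp_q)
#   k = torch.ops.aten._weight_int4pack_mm_for_cpu(x, w_k, 128, scale_zp_k)
#   v = torch.ops.aten._weight_int4pack_mm_for_cpu(x, w_v, 128, scale_zp_v)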
|
|
|
|
|
|
def concat_linear_woq_int4(gm: torch.fx.GraphModule):
|
|
"""
|
|
Concat Linear optimization pass for WOQ int4
|
|
This pass fuses the original pattern:
|
|
def ...
|
|
return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...)
|
|
into a single operation:
|
|
def ...
|
|
concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp)
|
|
return split(concat_res, split_size_list)
|
|
"""
|
|
|
|
def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype):
|
|
# Concat the wgts and scale_zps, and repack the wgt
|
|
unpacked_wgts = []
|
|
for packed_wgt in packed_wgts:
|
|
# Get the unpacked weight list
|
|
# Same as https://github.com/pytorch/pytorch/pull/156174
|
|
K = packed_wgt.size(1) * 2
|
|
N = packed_wgt.size(0)
|
|
x = torch.eye(K).to(dtype=act_dtype)
|
|
qscales_and_zeros = (
|
|
torch.tensor([1.0, 8.0])
|
|
.to(dtype=act_dtype)
|
|
.expand(K // group_size, N, 2)
|
|
.contiguous()
|
|
)
|
|
unpacked_wgts.append(
|
|
torch.ops.aten._weight_int4pack_mm_for_cpu(
|
|
x,
|
|
packed_wgt,
|
|
group_size,
|
|
qscales_and_zeros,
|
|
)
|
|
.t()
|
|
.contiguous()
|
|
.to(torch.int32) # N, K
|
|
)
|
|
concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0)
|
|
repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
|
|
concat_unpacked_wgt, 1
|
|
)
|
|
concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous()
|
|
return repack_w, concat_scale_zp
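        # Note on the unpacking trick above: with unit scales and a zero point
        # of 8, _weight_int4pack_mm_for_cpu against an identity activation
        # returns the (zero-point-shifted) integer weight values, which can
        # then be concatenated along N and repacked once.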
|
|
|
|
graph = gm.graph
|
|
computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default
|
|
for node in graph.find_nodes(op="call_function", target=computation_op):
|
|
if (
|
|
not node._erased
|
|
and isinstance(node.meta.get("val"), torch.Tensor)
|
|
and node.meta["val"].device.type == "cpu"
|
|
):
|
|
act = node.args[0]
|
|
users = list(act.users)
|
|
if _is_valid_concat_linear_woq_int4_fusion(users):
|
|
with graph.inserting_before(node):
|
|
assert all(user.args[1].op == "get_attr" for user in users)
|
|
computation_node_0 = users[0]
|
|
packed_wgts = [getattr(gm, user.args[1].target) for user in users]
|
|
group_size = computation_node_0.args[2]
|
|
scale_zps = [getattr(gm, user.args[3].target) for user in users]
|
|
out_feature_size_list = [
|
|
packed_wgt.size(0) for packed_wgt in packed_wgts
|
|
]
|
|
repack_w, concat_scale_zp = concat_wgt(
|
|
packed_wgts, scale_zps, group_size, act.meta.get("val").dtype
|
|
)
|
|
repack_w_node_name = computation_node_0.args[1].target + "_concat"
|
|
concat_scale_zp_node_name = (
|
|
computation_node_0.args[3].target + "_concat"
|
|
)
|
|
gm.register_buffer(repack_w_node_name, repack_w)
|
|
setattr(gm, repack_w_node_name, repack_w)
|
|
gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp)
|
|
setattr(gm, concat_scale_zp_node_name, concat_scale_zp)
|
|
|
|
repack_w_node = graph.create_node(
|
|
"get_attr", repack_w_node_name, (), {}
|
|
)
|
|
with graph.inserting_after(repack_w_node):
|
|
concat_scale_zp_node = graph.create_node(
|
|
"get_attr", concat_scale_zp_node_name, (), {}
|
|
)
|
|
|
|
with graph.inserting_after(concat_scale_zp_node):
|
|
concat_int4_gemm_node = graph.create_node(
|
|
"call_function",
|
|
computation_op,
|
|
(
|
|
act,
|
|
repack_w_node,
|
|
group_size,
|
|
concat_scale_zp_node,
|
|
),
|
|
)
|
|
with graph.inserting_after(concat_int4_gemm_node):
|
|
split_node = graph.create_node(
|
|
"call_function",
|
|
torch.ops.aten.split_with_sizes.default,
|
|
(
|
|
concat_int4_gemm_node,
|
|
out_feature_size_list,
|
|
1, # split dim
|
|
),
|
|
)
|
|
with graph.inserting_after(split_node):
|
|
for gemm_idx, user in enumerate(users):
|
|
assert user.target == computation_op
|
|
get_item = graph.create_node(
|
|
"call_function",
|
|
operator.getitem,
|
|
(
|
|
split_node,
|
|
gemm_idx,
|
|
),
|
|
)
|
|
with graph.inserting_after(get_item):
|
|
clone_node = graph.create_node(
|
|
"call_function",
|
|
torch.ops.aten.clone.default,
|
|
(get_item,),
|
|
{"memory_format": torch.contiguous_format},
|
|
)
|
|
user.replace_all_uses_with(clone_node)
|
|
graph.erase_node(user)
|
|
|
|
|
|
def quant_lift_up(graph_module: torch.fx.GraphModule):
    """
    Lift the quant node up, before view-like nodes. This can benefit the
    performance of attention-like blocks. For example, we have the pattern:

             DQ
    DQ       LINEAR
    LINEAR   VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    Q        Q
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    We want to lift the quant nodes up, past the view-like nodes, so that they
    consume the output of the linear node directly:

             DQ
    DQ       LINEAR
    LINEAR   Q
    Q        VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    This produces a DQ->LINEAR->Q pattern which can be fused by the backend.
    """

    def is_view_op(node):
        return node.op == "call_function" and node.target in _VIEW_OPS
|
|
|
|
for node in graph_module.graph.nodes:
|
|
        # <TODO> Leslie: Here we verify that the quant node has exactly
        # one input FX node, with constant scalar values for scale and zero point.
        # If the input of the quant node has more than one input FX node,
        # extend the implementation to lift up all the connected nodes
        # before the view nodes to keep the topological order.
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target in _PER_TENSOR_QUANTIZE_OPS
|
|
and len(node.all_input_nodes) == 1
|
|
and is_view_op(node.all_input_nodes[0])
|
|
):
|
|
quant_node = node
|
|
input_node_of_quant = quant_node.args[0]
|
|
|
|
# Check the nodes along lift up path has only 1 user node
|
|
# Propagate view like node to find where to insert the new quant node
|
|
could_lift_up = True
|
|
current_node = quant_node
|
|
input_node = current_node.args[0]
|
|
while is_view_op(input_node):
|
|
if len(input_node.users) != 1:
|
|
could_lift_up = False
|
|
break
|
|
current_node = input_node
|
|
input_node = current_node.args[0]
|
|
|
|
# Further check the input node of the first view node has only 1 user node
|
|
if could_lift_up and len(input_node.users) == 1:
|
|
# Replace dequant's input from quant to quant's input
|
|
quant_node.replace_all_uses_with(input_node_of_quant)
|
|
# Insert the new quant node
|
|
with graph_module.graph.inserting_before(current_node):
|
|
new_quant_node = graph_module.graph.node_copy(quant_node)
|
|
input_node.replace_all_uses_with(new_quant_node)
|
|
|
|
# Update inputs of new_quant_node
|
|
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
|
|
if n == input_node_of_quant:
|
|
return input_node
|
|
else:
|
|
return n
|
|
|
|
new_args = map_arg(new_quant_node.args, maybe_replace_node)
|
|
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
|
|
new_quant_node.args = new_args # type: ignore[assignment]
|
|
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
|
|
graph_module.graph.erase_node(quant_node)
|
|
|
|
graph_module.graph.lint()
|
|
graph_module.recompile()
|