Compare commits

...

2 Commits

Author SHA1 Message Date
43e32bdcfc add test cases 2025-10-29 10:54:14 +00:00
0cc9ee9820 add sum support for qlinear_binary pattern 2025-10-21 12:09:38 +00:00
7 changed files with 257 additions and 53 deletions

View File: test/inductor/test_cpu_select_algorithm.py

@@ -1954,14 +1954,17 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
return
B = (2, batch_size) if input_3d else (batch_size,)
input = torch.randn(*B, in_features).to(dtype=torch.float32)
input2 = torch.randn(*B, in_features).to(dtype=torch.float32)
input3 = torch.randn(*B, out_features).to(dtype=torch.float32)
other = torch.randn(*B, out_features).to(dtype=dtype)
# Avoid hitting qlinear inplace sum fusion
if input_3d:
other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype)
else:
other2 = torch.randn(1, *B, out_features).to(dtype=dtype)
other_clone = other.clone()
class M(torch.nn.Module):
def __init__(self, bias, input_3d):
super().__init__()
@@ -1973,7 +1976,6 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
def forward(self, x, other, other2):
res = self.epilogue(self.linear(x) + other)
# Avoid hitting qlinear inplace sum fusion
if self.input_3d:
other2 = other2.view(2, other2.size(0) // 2, other2.size(1))
else:
@@ -1981,11 +1983,29 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
res = self.epilogue2(self.linear2(res) + other2)
return res
class M2(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.epilogue = _get_epilogue(epilogue)
self.linear2 = torch.nn.Linear(out_features, out_features, bias)
self.epilogue2 = _get_epilogue(epilogue)
def forward(self, x0, x1, other):
# test qlinear sum -> qlinear sum
res = self.epilogue(self.linear(x0) + other)
res = self.epilogue2(self.linear2(x1) + res)
return res
counters.clear()
ref_quantized_mod = _generate_qdq_quantized_model(
M(bias=bias, input_3d=input_3d).eval(),
(input, other, other2),
)
ref_quantized_mod2 = _generate_qdq_quantized_model(
M2(bias=bias).eval(),
(input2, input3, other_clone),
)
atol, rtol = 5e-2, 5e-2
with (
patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)),
@@ -1994,6 +2014,9 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
):
ref_res = ref_quantized_mod(input, other, other2)
cfn = torch.compile(ref_quantized_mod)
ref_res2 = ref_quantized_mod2(input2, input3, other_clone)
cfn2 = torch.compile(ref_quantized_mod2)
res = cfn(input, other, other2)
self.assertEqual(
res,
@@ -2003,7 +2026,18 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
res2 = cfn2(input2, input3, other_clone)
self.assertEqual(
res2,
ref_res2,
atol=atol,
rtol=rtol,
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 4)
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"],
0,
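For readers without the full test file in view: the new M2 path exercises a chain where the first qlinear-sum's output feeds the second qlinear's sum input. A minimal eager-mode sketch of that dataflow (plain float modules, no quantization; names mirror the test, shapes are illustrative):

import torch

in_features, out_features, batch_size = 16, 8, 4

class ChainedSum(torch.nn.Module):
    # Mirrors M2: linear + sum feeding another linear + sum.
    def __init__(self, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias)
        self.linear2 = torch.nn.Linear(out_features, out_features, bias)

    def forward(self, x0, x1, other):
        res = self.linear(x0) + other   # first sum
        return self.linear2(x1) + res   # second sum consumes the first sum's output

x0 = torch.randn(batch_size, in_features)
x1 = torch.randn(batch_size, out_features)
other = torch.randn(batch_size, out_features)
print(ChainedSum()(x0, x1, other).shape)  # torch.Size([4, 8])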

View File: test/inductor/test_mkldnn_pattern_matcher.py

@@ -2343,6 +2343,51 @@ class TestPatternMatcher(TestPatternMatcherBase):
matcher_check_fn=matcher_check_fn,
)
def _qlinear_sum_test_helper(
self,
inputs,
device="cpu",
int8_mixed_bf16=False,
matcher_check_fn=None,
bias=True,
):
class M(torch.nn.Module):
def __init__(self, use_bias):
super().__init__()
self.linear = torch.nn.Linear(4, 4, use_bias)
self.linear2 = torch.nn.Linear(4, 4, use_bias)
def forward(self, x, other):
# test qlinear sum -> qlinear sum
res = self.linear(x) + other
res = self.linear2(x) + res
return res
mod = M(bias).eval().to(device=device)
assert isinstance(inputs, tuple)
def __convert_tensor_to_device(input, device):
return input.to(device=device) if isinstance(input, torch.Tensor) else input
inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs)
def _default_matcher_check_fn():
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
)
self._test_common(
mod,
inputs,
matcher_check_fn=(
matcher_check_fn
if matcher_check_fn is not None
else _default_matcher_check_fn
),
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
)
def _qlinear_test_helper(
self,
inputs,
@@ -3058,6 +3103,18 @@ class TestPatternMatcher(TestPatternMatcherBase):
is_dynamic=is_dynamic,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qlinear_sum_cpu(self):
for bias in [True, False]:
use_bf16 = [True, False] if is_mkldnn_bf16_supported("cpu") else [False]
for int8_mixed_bf16 in use_bf16:
self._qlinear_sum_test_helper(
(torch.randn((2, 2, 4)), torch.randn(2, 2, 4)),
bias=bias,
int8_mixed_bf16=int8_mixed_bf16
)
def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"):
dtype = torch.float8_e4m3fn
qlinear_prepack = torch.ops.onednn.qlinear_prepack
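The helper's M reuses the same x for both linears, so the second sum accumulates into the first sum's result. Why in-place sum fusion is only legal when the accumulator has no other live users, in plain torch terms (illustrative, not the patch's code):

import torch

linear = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
other = torch.randn(2, 4)

# Out-of-place "add": `other` survives and can still be read afterwards.
res_add = linear(x) + other

# In-place "sum": the accumulator is overwritten, which is only safe when
# no other node still needs its original contents.
acc = other.clone()
acc.add_(linear(x))

torch.testing.assert_close(res_add, acc)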

View File: torch/_inductor/codegen/cpp_wrapper_cpu.py

@@ -29,6 +29,7 @@ from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides, IndentedBuffer, Kernel
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .wrapper import (
codegen_reinterpret_view_helper,
EnterSubgraphLine,
ExitSubgraphLine,
PythonWrapperCodegen,
@@ -1752,6 +1753,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""Returns a newly-created, temporary RAII tensor handle containing the
reinterpreted tensor data. Callers of this function are responsible for saving
the handle if persistent access is needed."""
d_size, d_stride, d_offset, d_dtype, collapsible = codegen_reinterpret_view_helper(data)
dim = str(len(size))
original_offset = offset
offset = self.codegen_sizevar(offset)
@@ -1797,26 +1801,31 @@
]
return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs
if (
size == data.layout.size
and stride == data.layout.stride
and original_offset == data.layout.offset
):
# pure dtypeview
if dtype is not None and dtype != data.dtype:
# Use collapsed layout info (mirrors the Python wrapper logic)
collapsed = collapsible and original_offset == d_offset
if collapsed:
same_layout = (size == d_size and stride == d_stride)
base_dtype = d_dtype
else:
same_layout = (
size == data.layout.size
and stride == data.layout.stride
and original_offset == data.layout.offset
)
base_dtype = data.dtype
if same_layout:
# pure dtype view or simple handle dup
if dtype is not None and dtype != base_dtype:
final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name())
else:
final_tensor_str, tmp_call_strs = create_new_tensor_handle()
call_strs.extend(tmp_call_strs)
else:
# firstly create reinterpretview
# need reinterpret
final_tensor_str = create_reinterpret_call()
if dtype is not None and dtype != data.dtype:
# wrap it with dtypeview
final_tensor_str, tmp_call_strs = create_dtypeview_call(
final_tensor_str
)
if dtype is not None and dtype != base_dtype:
final_tensor_str, tmp_call_strs = create_dtypeview_call(final_tensor_str)
call_strs.extend(tmp_call_strs)
for line in call_strs:
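At tensor level, the two terminal cases above correspond to a pure dtype view versus a reinterpret followed by a dtype view. A small torch illustration of the distinction (not the generated C++):

import torch

t = torch.arange(8, dtype=torch.float32)

# Pure dtype view: same size/stride/offset, only the element type changes.
v1 = t.view(torch.int32)
assert v1.data_ptr() == t.data_ptr()

# Reinterpret (new size/stride/offset over the same storage), then a dtype view.
v2 = torch.as_strided(t, (2, 2), (2, 1), 2).view(torch.int32)
print(v2.shape, v2.dtype)  # torch.Size([2, 2]) torch.int32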

View File: torch/_inductor/codegen/wrapper.py

@@ -135,6 +135,39 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
return False
def codegen_reinterpret_view_helper(data):
"""
Collapse a chain of (ReinterpretView | TensorBox | StorageBox) wrappers over a base
buffer when every layer has the same offset as the innermost (base) buffer.
Returns:
(size, stride, offset, dtype, collapsible: bool)
"""
if isinstance(data, ir.Buffer):
lay = data.get_layout()
return lay.size, lay.stride, lay.offset, lay.dtype, True
layouts: list[Any] = []
cur = data
while isinstance(cur, (ir.TensorBox, ir.StorageBox, ir.ReinterpretView)):
lay = cur.get_layout()
if lay is None:
return None, None, None, None, False
layouts.append(lay)
cur = cur.data # unwrap
if not isinstance(cur, ir.Buffer):
return None, None, None, None, False
# All wrapper offsets must match base offset to be collapsible
for lay in layouts:
if lay.offset != cur.get_layout().offset:
return None, None, None, None, False
base_lay = cur.get_layout()
return base_lay.size, base_lay.stride, base_lay.offset, base_lay.dtype, True
# TODO: Move to a well known place
TritonMetaParams = dict[str, int]
TritonGrid = Union[
@@ -1964,25 +1997,36 @@ class PythonWrapperCodegen(CodeGen):
writeline: Callable[..., None],
dtype=None,
) -> str:
if (
size == data.layout.size
and stride == data.layout.stride
and offset == data.layout.offset
):
if dtype is not None and dtype != data.dtype:
return f"aten.view.dtype({data.get_name()}, {dtype})"
else:
return f"{data.get_name()}"
d_size, d_stride, d_offset, d_dtype, collapsible = codegen_reinterpret_view_helper(data)
def apply_reinterpret(name, tgt_size, tgt_stride, tgt_offset, cast_dtype, base_dtype):
s = self.codegen_python_shape_tuple(tgt_size)
st = self.codegen_python_shape_tuple(tgt_stride)
off = self.codegen_sizevar(tgt_offset)
expr = f"reinterpret_tensor({name}, {s}, {st}, {off})"
if cast_dtype is not None and cast_dtype != base_dtype:
return f"aten.view.dtype({expr}, {cast_dtype})"
return expr
name = data.get_name()
collapsed = collapsible and offset == d_offset
if collapsed:
same_layout = (size == d_size and stride == d_stride)
base_dtype = d_dtype
else:
size = self.codegen_python_shape_tuple(size)
stride = self.codegen_python_shape_tuple(stride)
offset = self.codegen_sizevar(offset)
if dtype is not None and dtype != data.dtype:
return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})"
else:
return (
f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
)
same_layout = (
size == data.layout.size
and stride == data.layout.stride
and offset == data.layout.offset
)
base_dtype = data.dtype
if same_layout:
if dtype is not None and dtype != base_dtype:
return f"aten.view.dtype({name}, {dtype})"
return f"{name}"
return apply_reinterpret(name, size, stride, offset, dtype, base_dtype)
def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]):
self.writeline(f"{dst}.copy_({src}, {non_blocking})")
@@ -3080,7 +3124,7 @@ class PythonWrapperCodegen(CodeGen):
if (
name in V.graph.removed_buffers
or name in self.allocated
or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer))
or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer, ir.InputBuffer))
):
return
self.allocated.add(name)
@@ -3105,7 +3149,16 @@ class PythonWrapperCodegen(CodeGen):
box = layout.view.data
assert isinstance(box, ir.StorageBox), type(box)
input_buffer = box.data
assert isinstance(input_buffer, ir.Buffer), type(box)
assert isinstance(input_buffer, (ir.Buffer, ir.ReinterpretView)), type(input_buffer)
if isinstance(input_buffer, ir.ReinterpretView):
def unwrap_views(target) -> ir.Buffer:
if isinstance(target, ir.BaseView):
return unwrap_views(target.unwrap_view())
if isinstance(target, ir.MutableBox):
return unwrap_views(target.data)
assert isinstance(target, ir.Buffer), type(target)
return target
input_buffer = unwrap_views(input_buffer)
self.codegen_allocation(input_buffer)
self.writeline(ReinterpretLine(self, input_buffer, buffer, layout))
return
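The collapse rule the new helper implements is: walk the wrapper layers and collapse only when every layer's offset equals the base buffer's. A self-contained sketch with stand-in classes (not Inductor's ir types):

from dataclasses import dataclass

@dataclass
class Layout:
    size: tuple
    stride: tuple
    offset: int
    dtype: str

@dataclass
class Buffer:
    layout: Layout

@dataclass
class View:
    layout: Layout
    data: object  # wrapped Buffer or another View

def collapse(node):
    layouts = []
    while isinstance(node, View):
        layouts.append(node.layout)
        node = node.data
    if not isinstance(node, Buffer):
        return None  # unknown wrapper: not collapsible
    base = node.layout
    if any(l.offset != base.offset for l in layouts):
        return None  # some layer re-offsets the storage
    return base

buf = Buffer(Layout((4, 4), (4, 1), 0, "f32"))
print(collapse(View(Layout((16,), (1,), 0, "f32"), buf)))  # base layout
print(collapse(View(Layout((15,), (1,), 1, "f32"), buf)))  # None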

View File: torch/_inductor/fx_passes/mkldnn_fusion.py

@@ -9,7 +9,7 @@ from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.utils._ordered_set import OrderedSet
from .. import ir
from .. import ir, mkldnn_ir
from ..lowering import lowerings as L
from ..pattern_matcher import (
Arg,
@@ -760,6 +760,37 @@ if torch._C._has_mkldnn:
or len(_other.get_inputs_that_alias_output()) > 0
)
def _qlinear_binary_can_be_inplace(_other):
if isinstance(_other.data, ir.BaseView):
try:
# The sum can be done in place when _other is the 2D-to-3D view of a
# CppTemplateBuffer/QLinearPointwiseBinaryPT2E: once such a view exists,
# the buffer is never used directly, only through the view.
if isinstance(
_other.data.data.data, # type: ignore[attr-defined]
(ir.CppTemplateBuffer, mkldnn_ir.QLinearPointwiseBinaryPT2E)
):
return True
else:
# Special case seen in the ViT model:
# QLinearPointwiseBinaryPT2E(sum) -> QLinearPointwiseBinaryPT2E(sum) -> ...
# i.e. the output of the previous QLinearPointwiseBinaryPT2E is the x2
# input of the current one. Walk V.graph.operations to check whether
# _other is a view of a previous QLinearPointwiseBinaryPT2E's output
# (its inputs[6]).
for op in V.graph.operations:
if (
isinstance(op, mkldnn_ir.QLinearPointwiseBinaryPT2E)
and _other.data.data.data == op.inputs[6] # type: ignore[attr-defined]
):
return True
return False
except AttributeError:
return False
elif len(_other.get_inputs_that_alias_output()) > 0:
return False
else:
return True
def _register_binary_unary_maybe_inplace_fusion_lowering(
pattern,
computation_op,
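Why a 2D-to-3D view of the producer's output is still a valid in-place target: the view aliases the same storage, so accumulating into the underlying buffer is exactly what the view's consumers observe. In plain torch terms (illustrative only):

import torch

out2d = torch.zeros(4, 3)    # stand-in for the GEMM template's 2D output
out3d = out2d.view(2, 2, 3)  # the 2D -> 3D reshape Inductor inserts
assert out3d.data_ptr() == out2d.data_ptr()

out2d.add_(torch.ones(4, 3))  # in-place sum into the 2D buffer
print(out3d.sum().item())     # 12.0 -- visible through the 3D view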

View File: torch/_inductor/fx_passes/quantization.py

@@ -26,6 +26,7 @@ from ..pattern_matcher import (
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
from ..virtualized import V
aten = torch.ops.aten
@@ -494,6 +495,9 @@ def _is_valid_qlinear_lowering_pattern():
def fn(match):
if len(match.nodes) != 1:
return False
return match.nodes[0].target in (
torch.ops.onednn.qlinear_pointwise.default,
torch.ops.onednn.qlinear_pointwise.tensor,
@@ -598,22 +602,19 @@
o_zero_point = kwargs["output_zero_point"]
x2.realize()
from .mkldnn_fusion import _can_be_inplace
from .mkldnn_fusion import _qlinear_binary_can_be_inplace
binary_op_name = kwargs["binary_op_name"]
alpha = kwargs["alpha"]
unary_op_name = kwargs["unary_op_name"]
unary_op_args = kwargs["unary_op_args"]
unary_op_algorithm = kwargs["unary_op_algorithm"]
if binary_op_name == "sum" and not _can_be_inplace(x2):
# When we enable the GEMM Template, the output of QLinear
# will be reshaped from 2D back to 3D if the input is 3D.
# This causes _can_be_inplace(x2) to return False if x2 happens
# to be the output of QLinear in this scenario.
# Change the post op from sum to binary add for this case.
# Refer to test case:
# test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
if (
binary_op_name == "sum"
# Keep "sum" when x2 can be updated in place. _qlinear_binary_can_be_inplace
# also accepts the ViT-style case where x2 is only a view of the previous
# QLinearPointwiseBinaryPT2E/CppTemplateBuffer output.
and (not _qlinear_binary_can_be_inplace(x2))
):
binary_op_name = "add"
computation_args = (
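The lowering-side decision then reduces to: keep the in-place sum post-op only when the helper approves x2, otherwise degrade to an out-of-place add. A condensed sketch of that control flow (free-standing function for illustration, not the patch's code):

def resolve_binary_op(binary_op_name, x2, can_be_inplace):
    # `can_be_inplace` stands in for _qlinear_binary_can_be_inplace.
    if binary_op_name == "sum" and not can_be_inplace(x2):
        # x2 may alias another live tensor (or be an opaque view chain);
        # an in-place accumulate would clobber it.
        return "add"
    return binary_op_name

print(resolve_binary_op("sum", object(), lambda t: False))  # add
print(resolve_binary_op("sum", object(), lambda t: True))   # sum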

View File: torch/_inductor/mkldnn_lowerings.py

@@ -989,7 +989,7 @@ def register_onednn_fusion_ops():
x_size = x.get_size()
x2_size = x2.get_size()
assert len(x_size) == len(x2_size)
if len(x_size) > 2 and binary_attr == "add":
if len(x_size) > 2 and binary_attr in ["add", "sum"]:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
x2 = view(x2, [-1, x2_size[-1]])
@@ -1056,9 +1056,10 @@
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: list[ChoiceCaller] = []
if (
config.max_autotune or config.max_autotune_gemm
) and binary_attr == "add": # <TODO> Support inplace sum fusion
if (config.max_autotune or config.max_autotune_gemm) and binary_attr in [
"add",
"sum",
]:
*_, layout, x, packed_weight, x2 = mm_args(
x, packed_weight, x2, layout=layout, out_dtype=output_dtype
)
@@ -1283,7 +1284,25 @@
layout,
input_gen_fns=input_gen_fns,
)
if len(x_size) > 2 and binary_attr == "add":
if (
isinstance(result.data.data, ir.CppTemplateBuffer)
and binary_attr == "sum"
):
# When binary_attr is "sum", x2 is updated in place, so rewrite the
# result's layout as a non-owning view of x2.
assert result.data.data.layout == x2.get_layout()
result = ir.TensorBox.create(
ir.CppTemplateBuffer(
layout=ir.NonOwningLayout(
ir.ReinterpretView(data=x2, layout=x2.get_layout())
),
inputs=result.data.data.inputs, # type: ignore[arg-type]
make_kernel_render=result.data.data.make_kernel_render, # type: ignore[arg-type]
template=result.data.data.template,
choice=result.data.data.choice,
)
)
if len(x_size) > 2 and binary_attr in ["add", "sum"]:
result = view(result, (*x_size[:-1], result.get_size()[-1]))
return result
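For reference, the shape normalization that now also applies to "sum": collapse the leading batch dims so the 2D GEMM template can consume the input, then restore them on the way out. A torch-level sketch:

import torch

x = torch.randn(2, 5, 16)  # 3D activation (B, M, K)
w = torch.randn(16, 8)     # weight (K, N)

x2d = x.view(-1, x.size(-1))                # (B*M, K) for the 2D GEMM
y2d = x2d @ w                               # (B*M, N)
y = y2d.view(*x.shape[:-1], y2d.size(-1))   # back to (B, M, N)
print(y.shape)  # torch.Size([2, 5, 8])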