Revert "[Inductor] addmm + activation function fusion (#158137)"

This reverts commit b9d7de3a094598c3dc0dd52e57bce30eb684c9d8.

Reverted https://github.com/pytorch/pytorch/pull/158137 on behalf of https://github.com/malfet because it broke inductor torchbench; see 663da17b62/1 ([comment](https://github.com/pytorch/pytorch/pull/158137#issuecomment-3191841298))
Author: PyTorch MergeBot
Date:   2025-08-15 15:34:09 +00:00
parent 663da17b62
commit 846963fa9b
7 changed files with 2 additions and 343 deletions
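For context, the reverted pass rewrote mm/addmm + bias + activation graphs into ATen's fused epilogue op aten._addmm_activation. A minimal sketch of the equivalence the patterns relied on (assumes a CUDA device, where the fused op can use cuBLASLt epilogues; shapes mirror the test below):

import torch

inp = torch.randn(20, device="cuda")  # rank-1 bias
mat1 = torch.randn(10, 15, device="cuda")
mat2 = torch.randn(15, 20, device="cuda")

eager = torch.relu(torch.addmm(inp, mat1, mat2))
fused = torch.ops.aten._addmm_activation(
    inp, mat1, mat2, beta=1, alpha=1, use_gelu=False
)
torch.testing.assert_close(fused, eager)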

test/inductor/test_pattern_matcher.py

@@ -635,64 +635,6 @@ class TestPatternMatcher(TestCase):
         self.assertEqual(res1, res2)
 
-    @skipIfRocm
-    def test_addmm_activation(self):
-        def fn_addmm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
-
-        def fn_addmm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.addmm(input, mat1, mat2), approximate="tanh"
-            )
-
-        def fn_add_mm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
-
-        def fn_add_mm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
-            )
-
-        args = [
-            torch.randn(20, device=GPU_TYPE),
-            torch.randn(10, 15, device=GPU_TYPE),
-            torch.randn(15, 20, device=GPU_TYPE),
-        ]
-
-        for fn, atol in (
-            (fn_addmm_relu, 1e-8),
-            (fn_add_mm_relu, 1e-8),
-            (fn_addmm_gelu, 1e-3),
-            (fn_add_mm_gelu, 1e-3),
-        ):
-            expected = fn(*args)
-            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
-            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
-            self.assertTrue("_addmm_activation" in code)
-
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(fn, options={"max_autotune_gemm": True}), *args
-            )
-            self.assertFalse("_addmm_activation" in code)
-
-        args_not_replaced = [
-            # addmm + activation with a rank-2 input
-            # is not fusable, hence not replaced
-            torch.randn(10, 20, device=GPU_TYPE),  # input
-            torch.randn(10, 15, device=GPU_TYPE),  # mat1
-            torch.randn(15, 20, device=GPU_TYPE),  # mat2
-        ]
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(
-                    fn,
-                ),
-                *args_not_replaced,
-            )
-            self.assertFalse("_addmm_activation" in code)
-
     @inductor_config.patch(
         {
             "max_autotune_gemm_backends": "ATEN",

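Note on the tolerances above: the ReLU cases are compared at atol=1e-8 while the GELU cases get atol=1e-3, presumably because the fused cuBLASLt epilogue and the compiled decomposition evaluate the tanh approximation with different operation orders and intermediate precisions. For reference, both sides compute the standard tanh-form GELU; a minimal reference implementation (a sketch for context, not the kernel code):

import math

import torch


def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # gelu(x, approximate="tanh") = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))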
torch/_inductor/fx_passes/post_grad.py

@@ -33,7 +33,6 @@ from ..pattern_matcher import (
     CallFunctionVarArgs,
     filter_nodes,
     fwd_only,
-    gen_register_replacement,
     get_arg_value,
     get_mutation_region_id,
     Ignored,
@@ -661,97 +660,6 @@ def lazy_init():
         extra_check=prepare_softmax_extra_check,
     )
 
-    register_addmm_activation_fusion()
-
-
-@functools.cache
-def register_addmm_activation_fusion():
-    shapes = [(5,), (3, 4), (4, 5)]
-    args_fp32 = [torch.empty(shape) for shape in shapes]
-    args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
-
-    for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
-        name = f"{pattern.__name__}_fp32"
-        gen_register_replacement(
-            name,
-            pattern,
-            addmm_relu_replacement,
-            args_fp32,
-            trace_fn=fwd_only,
-            pass_dicts=pass_patterns[2],
-            extra_check=is_valid_addmm_activation_fusion,
-        )
-
-    for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
-        for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
-            name = f"{pattern.__name__}_{dtype_suffix}"
-            gen_register_replacement(
-                name,
-                pattern,
-                addmm_gelu_replacement,
-                args,
-                trace_fn=fwd_only,
-                pass_dicts=pass_patterns[2],
-                extra_check=is_valid_addmm_activation_fusion,
-            )
-
-
-def is_valid_addmm_activation_fusion(match):
-    if config.max_autotune_gemm:
-        return False
-
-    inp = match.kwargs["input"].meta["val"]
-    mat1 = match.kwargs["mat1"].meta["val"]
-    mat2 = match.kwargs["mat2"].meta["val"]
-
-    # match the dispatch logic for cuBLASLt at aten/src/ATen/native/cuda/Blas.cpp
-    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
-        return False
-    if not (mat1.dim() == 2 and mat2.dim() == 2):
-        return False
-    if inp.size(0) != mat2.size(1):
-        return False
-    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
-        return False
-
-    output = match.output_node()
-    # do not fuse if there are pointwise ops after
-    return not all(is_pointwise_use(use) for use in output.users)
-
-
-def addmm_gelu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
-
-
-def addmm_relu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.relu(output)
-
-
-def addmm_relu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.relu(output)
-
-
-def addmm_relu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
-
-
 def reorder_for_locality(graph: torch.fx.Graph):
     if torch.distributed.is_available():
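For readers skimming the revert: is_valid_addmm_activation_fusion above gates the rewrite on the same conditions under which ATen's addmm dispatches to a cuBLASLt fused epilogue. Restated on concrete tensors (an illustrative sketch, not code from the patch):

def cublaslt_epilogue_eligible(inp, mat1, mat2) -> bool:
    # bias must be a contiguous CUDA vector matching the output's last dim,
    # and the operands must be 2-D and share a single dtype
    return (
        inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()
        and mat1.dim() == 2 and mat2.dim() == 2
        and inp.size(0) == mat2.size(1)
        and inp.dtype == mat1.dtype == mat2.dtype
    )

This is also why the test's rank-2 args_not_replaced input above must stay unfused.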
@@ -1553,7 +1461,7 @@ def should_prefer_unfused_addmm(match):
 @register_graph_pattern(
     CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
-    pass_dict=pass_patterns[1],
+    pass_dict=pass_patterns[2],
     extra_check=should_prefer_unfused_addmm,
 )
 def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):

torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py

@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
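Reading the serialized chain with x standing for the add result (add_Tensor), this is the tanh-approximated GELU spelled out node by node; the Ignored() arguments are the traced constants (presumably 0.5, 0.044715, sqrt(2/pi), and 1):

# mul_Tensor    = x * 0.5
# mul_Tensor_1  = x * x
# mul_Tensor_2  = (x * x) * x                  i.e. x**3
# mul_Tensor_3  = x**3 * 0.044715
# add_Tensor_1  = x + 0.044715 * x**3
# mul_Tensor_4  = sqrt(2/pi) * (x + 0.044715 * x**3)
# add_Tensor_2  = tanh(mul_Tensor_4) + 1
# pattern root  = (0.5 * x) * (1 + tanh(...))  == gelu(x, approximate="tanh")

The bf16 variant is the same chain, except x is first upcast to fp32 (the convert_element_type node) and the pattern root casts the result back.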

torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py

@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
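This file serializes the same GELU chain as the previous one; the only difference from addmm_gelu_pattern is the operand order of the initial add (input + mm rather than mm + input), which is why each ordering needs its own serialized pattern.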

torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py

@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
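Only an fp32 serialization exists for the ReLU patterns (here and in the operand-swapped file that follows), presumably because relu does not change dtype: unlike the GELU decomposition, a bf16 trace introduces no convert_element_type nodes, so the same pattern matches both dtypes. This would also explain why register_addmm_activation_fusion above registers the relu patterns with fp32 example inputs only.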

torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py

@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)

torchgen/fuse/gen_patterns.py

@@ -2,7 +2,7 @@
 import os
 
 from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph, post_grad
+from torch._inductor.fx_passes import joint_graph
 
 if __name__ == "__main__":
@@ -17,4 +17,3 @@ if __name__ == "__main__":
     # to serialize the patterns as it goes.
     os.environ["PYTORCH_GEN_PATTERNS"] = "1"
     joint_graph.lazy_init()
-    post_grad.lazy_init()
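With the revert applied, gen_patterns.py traces only the joint-graph patterns again; the removed post_grad.lazy_init() call existed solely so that, with PYTORCH_GEN_PATTERNS=1 set, the script would also serialize the addmm activation patterns deleted above.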