[Inductor] addmm + activation function fusion (#158137)

This PR implements a post_grad pass that fuses activation(add + mm) into `_addmm_activation`.

A similar pass was previously added in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/_addmm_activation and lets Inductor handle the fusion.

However, the cuBLAS team has made significant performance improvements in this area since then. I will update this post with more benchmarks, but the preliminary benchmarks below show good results.

Perf dashboard:
<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input, while ReLU's does not. I don't think this is fixable with the current `_addmm_activation` API.
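
For intuition, here is a minimal standalone sketch (mine, not part of this PR) of that asymmetry: ReLU's gradient is recoverable from the activation output alone (hence the `le` mask over the ReLU output in the training graph below), while the tanh-GELU derivative explicitly contains the pre-activation input, which the fused epilogue does not hand back to autograd.

```py
import torch

x = torch.randn(8, dtype=torch.float64)

# ReLU backward only needs the *output*: grad_in = grad_out * (out > 0).
relu_out = torch.relu(x)
relu_grad = (relu_out > 0).to(x.dtype)

# tanh-GELU: 0.5 * x * (1 + tanh(c * (x + 0.044715 * x**3))), c = sqrt(2/pi).
# Its derivative contains x itself, so the pre-activation addmm result would
# have to be saved for backward.
c = 0.7978845608028654
u = c * (x + 0.044715 * x**3)
gelu_grad = 0.5 * (1 + torch.tanh(u)) + 0.5 * x * (1 - torch.tanh(u) ** 2) * c * (
    1 + 3 * 0.044715 * x**2
)

# cross-check both hand-derived gradients against autograd
xr = x.clone().requires_grad_(True)
torch.relu(xr).sum().backward()
torch.testing.assert_close(xr.grad, relu_grad)

xg = x.clone().requires_grad_(True)
torch.nn.functional.gelu(xg, approximate="tanh").sum().backward()
torch.testing.assert_close(xg.grad, gelu_grad)
```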

Graph module before and after this pass

ReLU (addmm)
```
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)
```
GELU (addmm)
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)
```
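
As a sanity check on the rewrite itself, the fused op and the unfused composition it replaces can be compared numerically. This is my own sketch (not from the PR); it needs a CUDA device since `_addmm_activation` dispatches to cuBLAS, and it uses the same shapes and GELU tolerance as the new test:

```py
import torch

torch.manual_seed(0)
inp = torch.randn(20, device="cuda")   # rank-1, contiguous bias (required for the fusion)
mat1 = torch.randn(10, 15, device="cuda")
mat2 = torch.randn(15, 20, device="cuda")

# ReLU epilogue (use_gelu defaults to False)
torch.testing.assert_close(
    torch._addmm_activation(inp, mat1, mat2),
    torch.relu(torch.addmm(inp, mat1, mat2)),
)

# tanh-GELU epilogue; looser atol, matching the tolerance used in the new test
torch.testing.assert_close(
    torch._addmm_activation(inp, mat1, mat2, use_gelu=True),
    torch.nn.functional.gelu(torch.addmm(inp, mat1, mat2), approximate="tanh"),
    atol=1e-3,
    rtol=0,
)
```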

Benchmark setup:
- NGC PyTorch 25.06 container
- cuBLAS version: 12.9.1.4
- torch.compile run with `dynamic=False` and max-autotune

H100
```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0107 ms
Average Time per Iteration (torch compile):	 0.0296 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0262 ms
Average Time per Iteration (torch compile):	 0.0327 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.1763 ms
Average Time per Iteration (torch compile):	 0.2457 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 1.5280 ms
Average Time per Iteration (torch compile):	 1.9437 ms
```

A100
```
############################################################
Testing with dtype: float16
############################################################

============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.0313 ms
Average Time per Iteration (torch compile):	 0.0643 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.1149 ms
Average Time per Iteration (torch compile):	 0.1255 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.6297 ms
Average Time per Iteration (torch compile):	 0.7547 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas):	 4.3821 ms
Average Time per Iteration (torch compile):	 5.0740 ms
```

Script
```py
import torch
torch.manual_seed(0)

warmup, numrun = 10, 100

sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]

device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            # fused path: cuBLAS addmm with a tanh-GELU epilogue in a single call
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            # unfused path: mm + add + gelu, left for torch.compile/Inductor to fuse
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup): func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup): func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```
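
To confirm the pattern actually fires for a given function, the generated code can be inspected the same way the new test does. The sketch below is adapted from that test; note the pass intentionally bails out when `max_autotune_gemm` is enabled or when the bias is not a contiguous rank-1 CUDA tensor.

```py
import torch
from torch._inductor.utils import run_and_get_code

def fn(inp, mat1, mat2):
    return torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))

args = [
    torch.randn(20, device="cuda"),      # rank-1 bias -> eligible for fusion
    torch.randn(10, 15, device="cuda"),
    torch.randn(15, 20, device="cuda"),
]

_, (code,) = run_and_get_code(torch.compile(fn), *args)
assert "_addmm_activation" in code  # the post_grad pass rewrote relu(addmm) to the fused op
```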

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
Commit b9d7de3a09 (parent 1028c5e2d5), authored by AaronWang04 and committed by PyTorch MergeBot on 2025-08-14. 7 changed files with 343 additions and 2 deletions:

```
@@ -635,6 +635,64 @@ class TestPatternMatcher(TestCase):
        self.assertEqual(res1, res2)

    @skipIfRocm
    def test_addmm_activation(self):
        def fn_addmm_relu(input, mat1, mat2):
            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))

        def fn_addmm_gelu(input, mat1, mat2):
            return torch.nn.functional.gelu(
                torch.addmm(input, mat1, mat2), approximate="tanh"
            )

        def fn_add_mm_relu(input, mat1, mat2):
            return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))

        def fn_add_mm_gelu(input, mat1, mat2):
            return torch.nn.functional.gelu(
                torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
            )

        args = [
            torch.randn(20, device=GPU_TYPE),
            torch.randn(10, 15, device=GPU_TYPE),
            torch.randn(15, 20, device=GPU_TYPE),
        ]

        for fn, atol in (
            (fn_addmm_relu, 1e-8),
            (fn_add_mm_relu, 1e-8),
            (fn_addmm_gelu, 1e-3),
            (fn_add_mm_gelu, 1e-3),
        ):
            expected = fn(*args)
            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
            self.assertTrue("_addmm_activation" in code)

        for fn in (fn_addmm_relu, fn_addmm_gelu):
            actual, (code,) = run_and_get_code(
                torch.compile(fn, options={"max_autotune_gemm": True}), *args
            )
            self.assertFalse("_addmm_activation" in code)

        args_not_replaced = [
            # addmm + activation with a rank-2 input
            # is not fusable, hence not replaced
            torch.randn(10, 20, device=GPU_TYPE),  # input
            torch.randn(10, 15, device=GPU_TYPE),  # mat1
            torch.randn(15, 20, device=GPU_TYPE),  # mat2
        ]
        for fn in (fn_addmm_relu, fn_addmm_gelu):
            actual, (code,) = run_and_get_code(
                torch.compile(
                    fn,
                ),
                *args_not_replaced,
            )
            self.assertFalse("_addmm_activation" in code)

    @inductor_config.patch(
        {
            "max_autotune_gemm_backends": "ATEN",
```

```
@@ -33,6 +33,7 @@ from ..pattern_matcher import (
    CallFunctionVarArgs,
    filter_nodes,
    fwd_only,
    gen_register_replacement,
    get_arg_value,
    get_mutation_region_id,
    Ignored,
@@ -660,6 +661,97 @@ def lazy_init():
        extra_check=prepare_softmax_extra_check,
    )

    register_addmm_activation_fusion()


@functools.cache
def register_addmm_activation_fusion():
    shapes = [(5,), (3, 4), (4, 5)]
    args_fp32 = [torch.empty(shape) for shape in shapes]
    args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]

    for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
        name = f"{pattern.__name__}_fp32"
        gen_register_replacement(
            name,
            pattern,
            addmm_relu_replacement,
            args_fp32,
            trace_fn=fwd_only,
            pass_dicts=pass_patterns[2],
            extra_check=is_valid_addmm_activation_fusion,
        )

    for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
        for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
            name = f"{pattern.__name__}_{dtype_suffix}"
            gen_register_replacement(
                name,
                pattern,
                addmm_gelu_replacement,
                args,
                trace_fn=fwd_only,
                pass_dicts=pass_patterns[2],
                extra_check=is_valid_addmm_activation_fusion,
            )


def is_valid_addmm_activation_fusion(match):
    if config.max_autotune_gemm:
        return False

    inp = match.kwargs["input"].meta["val"]
    mat1 = match.kwargs["mat1"].meta["val"]
    mat2 = match.kwargs["mat2"].meta["val"]

    # match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
        return False
    if not (mat1.dim() == 2 and mat2.dim() == 2):
        return False
    if inp.size(0) != mat2.size(1):
        return False
    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
        return False

    output = match.output_node()
    # do not fuse if there are pointwise ops after
    return not all(is_pointwise_use(use) for use in output.users)


def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(output, input)
    return aten.gelu(output, approximate="tanh")


def addmm_gelu_pattern_2(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output, approximate="tanh")


def addmm_gelu_replacement(input, mat1, mat2):
    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)


def addmm_relu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.relu(output)


def addmm_relu_pattern_2(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(output, input)
    return aten.relu(output)


def addmm_relu_replacement(input, mat1, mat2):
    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)


def reorder_for_locality(graph: torch.fx.Graph):
    if torch.distributed.is_available():
@@ -1461,7 +1553,7 @@ def should_prefer_unfused_addmm(match):
@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
-   pass_dict=pass_patterns[2],
+   pass_dict=pass_patterns[1],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
```

```
@@ -0,0 +1,59 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor
import operator

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
```

```
@@ -0,0 +1,59 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor
import operator

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
```

```
@@ -0,0 +1,36 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor
import operator

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
```

```
@@ -0,0 +1,36 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor
import operator

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)

mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
```

```
@@ -2,7 +2,7 @@
import os

from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph
+from torch._inductor.fx_passes import joint_graph, post_grad

if __name__ == "__main__":
@@ -17,3 +17,4 @@ if __name__ == "__main__":
    # to serialize the patterns as it goes.
    os.environ["PYTORCH_GEN_PATTERNS"] = "1"
    joint_graph.lazy_init()
    post_grad.lazy_init()
```