Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] addmm + activation function fusion (#158137)
This PR implements a post_grad pass that fuses activation(add + mm) into `_addmm_activation`. Something similar was done previously in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion. Since then, however, the cuBLAS team has made significant performance improvements here. This post will be updated with more benchmarks, but preliminary benchmarks already show good results on the perf dashboard:

<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. This does not appear fixable with the current addmm_activation API.

Graph module before and after this pass:

Relu(addmm)
```
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)
```

Gelu (addmm)
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)
```

Benchmark setup:
- NGC PyTorch 25.06 container
- cuBLAS version: 12.9.1.4
- torch.compile run with dynamic=False and max_autotune

H100
```
============================================================
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):         0.0107 ms
Average Time per Iteration (torch compile):  0.0296 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):         0.0262 ms
Average Time per Iteration (torch compile):  0.0327 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):         0.1763 ms
Average Time per Iteration (torch compile):  0.2457 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):         1.5280 ms
Average Time per Iteration (torch compile):  1.9437 ms
```

A100
```
############################################################
Testing with dtype: float16
############################################################

============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas):         0.0313 ms
Average Time per Iteration (torch compile):  0.0643 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas):         0.1149 ms
Average Time per Iteration (torch compile):  0.1255 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas):         0.6297 ms
Average Time per Iteration (torch compile):  0.7547 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas):         4.3821 ms
Average Time per Iteration (torch compile):  5.0740 ms
```

Script
```py
import torch

torch.manual_seed(0)

warmup, numrun = 10, 100
sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]

    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size

        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup):
            func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun):
            func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)

        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup):
            func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun):
            func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)

        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
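For context, here is a minimal equivalence check (my sketch, not part of the PR) for the rewrite the pattern performs: `torch._addmm_activation` with its default ReLU epilogue against the unfused `relu(addmm(...))`, and with `use_gelu=True` against the tanh-approximated GELU, mirroring the ops used in the benchmark script above. The shapes, dtype, and tolerances are arbitrary choices of mine.

```py
# Sketch only: compares the fused aten op against the unfused reference that the
# post_grad pattern matches. Requires a CUDA device; shapes/tolerances are arbitrary.
import torch

device, dtype = "cuda", torch.float16
bias = torch.randn(512, device=device, dtype=dtype)        # broadcast over rows of the mm result
a = torch.randn(256, 128, device=device, dtype=dtype)
b = torch.randn(128, 512, device=device, dtype=dtype)

fused_relu = torch._addmm_activation(bias, a, b)                  # ReLU epilogue (default)
fused_gelu = torch._addmm_activation(bias, a, b, use_gelu=True)   # tanh-approx GELU epilogue

ref_relu = torch.relu(torch.addmm(bias, a, b))
ref_gelu = torch.nn.functional.gelu(torch.addmm(bias, a, b), approximate="tanh")

torch.testing.assert_close(fused_relu, ref_relu, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(fused_gelu, ref_gelu, rtol=1e-2, atol=1e-2)
```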
Committed by: PyTorch MergeBot
Parent: 1028c5e2d5
Commit: b9d7de3a09
```diff
@@ -2,7 +2,7 @@
 import os
 
 from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph
+from torch._inductor.fx_passes import joint_graph, post_grad
 
 
 if __name__ == "__main__":
@@ -17,3 +17,4 @@ if __name__ == "__main__":
     # to serialize the patterns as it goes.
     os.environ["PYTORCH_GEN_PATTERNS"] = "1"
     joint_graph.lazy_init()
+    post_grad.lazy_init()
```
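As a quick end-to-end check that the new pattern rewrites a graph, a sketch along these lines (mine, not from the PR) compiles an unfused bias + mm + ReLU in inference mode and dumps the post-grad graph. It assumes a CUDA build and that the `post_grad_graphs` TORCH_LOGS artifact is available; whether `_addmm_activation` actually appears may still depend on Inductor configuration.

```py
# Sketch: look for aten._addmm_activation in the post-grad graph after compiling
# an unfused bias + mm + relu. TORCH_LOGS is set before importing torch so the
# "post_grad_graphs" logging artifact is picked up. Inference only (no autograd).
import os
os.environ["TORCH_LOGS"] = "post_grad_graphs"

import torch

@torch.compile(dynamic=False)
def f(bias, a, b):
    return torch.relu(torch.addmm(bias, a, b))

with torch.no_grad():
    bias = torch.randn(256, device="cuda", dtype=torch.float16)
    a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(64, 256, device="cuda", dtype=torch.float16)
    f(bias, a, b)  # the logged graph should show _addmm_activation if the pattern matched
```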