mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] addmm + activation function fusion (#158137)
PR implements a pass in post_grad to fuse activation(add + mm) This was previously done similarly here #106912 but was reverted for performance reasons. it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and let inductor handle the fusion. however since then cuBLAS team has made a lot of perf improvements on this, will update this post with more benchmarks but preliminary benchmark show good results perf dash board <img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" /> Relu works with both training and inference but gelu only works with inference mode due to some fundamental limitations since gelu's derivative depends on input and relu's doesnt. don't think this is fixable with the current addmm_activation API Graph module before and after this pass Relu(addmm) ``` graph(): %primals_1 : [num_users=1] = placeholder[target=primals_1] %primals_2 : [num_users=2] = placeholder[target=primals_2] %primals_3 : [num_users=2] = placeholder[target=primals_3] %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {}) %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {}) %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {}) %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {}) return (relu, primals_2, le, permute_1) graph(): %primals_1 : [num_users=1] = placeholder[target=primals_1] %primals_2 : [num_users=2] = placeholder[target=primals_2] %primals_3 : [num_users=2] = placeholder[target=primals_3] %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {}) %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {}) %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {}) return (_addmm_activation_default, primals_2, le, permute_1) ``` Gelu (addmm) ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {}) %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {}) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {}) %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {}) %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {}) %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {}) %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {}) %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {}) return (mul_5,) graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True}) return (_addmm_activation_default,) ``` Benchmark setup: NGC pytorch 25.06 container cublas version: 12.9.1.4 torch.compile ran with dynamic = False and max_autotune H100 ``` Testing with M=1024, N=1024, K=1024, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.0107 ms Average Time per Iteration (torch compile): 0.0296 ms ============================================================ Testing with M=2048, N=2048, K=2048, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.0262 ms Average Time per Iteration (torch compile): 0.0327 ms ============================================================ Testing with M=4096, N=4096, K=4096, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.1763 ms Average Time per Iteration (torch compile): 0.2457 ms ============================================================ Testing with M=8192, N=8192, K=8192, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 1.5280 ms Average Time per Iteration (torch compile): 1.9437 ms ``` A100 ``` ############################################################ Testing with dtype: float16 ############################################################ ============================================================ Testing with M=1024, N=1024, K=1024, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.0313 ms Average Time per Iteration (torch compile): 0.0643 ms ============================================================ Testing with M=2048, N=2048, K=2048, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.1149 ms Average Time per Iteration (torch compile): 0.1255 ms ============================================================ Testing with M=4096, N=4096, K=4096, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.6297 ms Average Time per Iteration (torch compile): 0.7547 ms ============================================================ Testing with M=8192, N=8192, K=8192, dtype=float16 ============================================================ Average Time per Iteration (cublas): 4.3821 ms Average Time per Iteration (torch compile): 5.0740 ms ``` Script ```py import torch torch.manual_seed(0) warmup, numrun= 10, 100 sizes = [1024, 2048, 4096, 8192] dtypes = [torch.float16, torch.bfloat16, torch.float32] device = torch.device("cuda") for dtype in dtypes: dtype_name = str(dtype).split('.')[-1] print(f"\n{'#'*60}") print(f"Testing with dtype: {dtype_name}") print(f"{'#'*60}") for size in sizes: M, N, K = size, size, size print(f"\n{'='*60}") print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}") print(f"{'='*60}") A = torch.randn(M, K, device=device, dtype=dtype) B = torch.randn(K, N, device=device, dtype=dtype) C = torch.randn(M, device=device, dtype=dtype) def func1(): return torch._addmm_activation(C, A, B, use_gelu=True) def func2(): return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh") func2_compiled = torch.compile( func2, dynamic=False, options={ "force_disable_caches": True, "max_autotune": True, "max_autotune_gemm": True, "max_autotune_gemm_backends": "TRITON", "autotune_fallback_to_aten": False, } ) for _ in range(warmup): func1() torch.cuda.synchronize(device=device) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) total_time_ms = 0.0 start_event.record() for _ in range(numrun): func1() end_event.record() torch.cuda.synchronize(device=device) total_time_ms += start_event.elapsed_time(end_event) avg_time_ms = total_time_ms / numrun print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms") for _ in range(warmup): func2_compiled() torch.cuda.synchronize(device=device) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) total_time_ms = 0.0 start_event.record() for _ in range(numrun): func2_compiled() end_event.record() torch.cuda.synchronize(device=device) total_time_ms += start_event.elapsed_time(end_event) avg_time_ms = total_time_ms / numrun print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
1028c5e2d5
commit
b9d7de3a09
@ -635,6 +635,64 @@ class TestPatternMatcher(TestCase):
|
||||
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
@skipIfRocm
|
||||
def test_addmm_activation(self):
|
||||
def fn_addmm_relu(input, mat1, mat2):
|
||||
return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
|
||||
|
||||
def fn_addmm_gelu(input, mat1, mat2):
|
||||
return torch.nn.functional.gelu(
|
||||
torch.addmm(input, mat1, mat2), approximate="tanh"
|
||||
)
|
||||
|
||||
def fn_add_mm_relu(input, mat1, mat2):
|
||||
return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
|
||||
|
||||
def fn_add_mm_gelu(input, mat1, mat2):
|
||||
return torch.nn.functional.gelu(
|
||||
torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
|
||||
)
|
||||
|
||||
args = [
|
||||
torch.randn(20, device=GPU_TYPE),
|
||||
torch.randn(10, 15, device=GPU_TYPE),
|
||||
torch.randn(15, 20, device=GPU_TYPE),
|
||||
]
|
||||
|
||||
for fn, atol in (
|
||||
(fn_addmm_relu, 1e-8),
|
||||
(fn_add_mm_relu, 1e-8),
|
||||
(fn_addmm_gelu, 1e-3),
|
||||
(fn_add_mm_gelu, 1e-3),
|
||||
):
|
||||
expected = fn(*args)
|
||||
actual, (code,) = run_and_get_code(torch.compile(fn), *args)
|
||||
torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
|
||||
self.assertTrue("_addmm_activation" in code)
|
||||
|
||||
for fn in (fn_addmm_relu, fn_addmm_gelu):
|
||||
actual, (code,) = run_and_get_code(
|
||||
torch.compile(fn, options={"max_autotune_gemm": True}), *args
|
||||
)
|
||||
self.assertFalse("_addmm_activation" in code)
|
||||
|
||||
args_not_replaced = [
|
||||
# addmm + activation with a rank-2 input
|
||||
# is not fusable, hence not replaced
|
||||
torch.randn(10, 20, device=GPU_TYPE), # input
|
||||
torch.randn(10, 15, device=GPU_TYPE), # mat1
|
||||
torch.randn(15, 20, device=GPU_TYPE), # mat2
|
||||
]
|
||||
|
||||
for fn in (fn_addmm_relu, fn_addmm_gelu):
|
||||
actual, (code,) = run_and_get_code(
|
||||
torch.compile(
|
||||
fn,
|
||||
),
|
||||
*args_not_replaced,
|
||||
)
|
||||
self.assertFalse("_addmm_activation" in code)
|
||||
|
||||
@inductor_config.patch(
|
||||
{
|
||||
"max_autotune_gemm_backends": "ATEN",
|
||||
|
@ -33,6 +33,7 @@ from ..pattern_matcher import (
|
||||
CallFunctionVarArgs,
|
||||
filter_nodes,
|
||||
fwd_only,
|
||||
gen_register_replacement,
|
||||
get_arg_value,
|
||||
get_mutation_region_id,
|
||||
Ignored,
|
||||
@ -660,6 +661,97 @@ def lazy_init():
|
||||
extra_check=prepare_softmax_extra_check,
|
||||
)
|
||||
|
||||
register_addmm_activation_fusion()
|
||||
|
||||
|
||||
@functools.cache
|
||||
def register_addmm_activation_fusion():
|
||||
shapes = [(5,), (3, 4), (4, 5)]
|
||||
args_fp32 = [torch.empty(shape) for shape in shapes]
|
||||
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
|
||||
|
||||
for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
|
||||
name = f"{pattern.__name__}_fp32"
|
||||
gen_register_replacement(
|
||||
name,
|
||||
pattern,
|
||||
addmm_relu_replacement,
|
||||
args_fp32,
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[2],
|
||||
extra_check=is_valid_addmm_activation_fusion,
|
||||
)
|
||||
|
||||
for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
|
||||
for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
|
||||
name = f"{pattern.__name__}_{dtype_suffix}"
|
||||
gen_register_replacement(
|
||||
name,
|
||||
pattern,
|
||||
addmm_gelu_replacement,
|
||||
args,
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[2],
|
||||
extra_check=is_valid_addmm_activation_fusion,
|
||||
)
|
||||
|
||||
|
||||
def is_valid_addmm_activation_fusion(match):
|
||||
if config.max_autotune_gemm:
|
||||
return False
|
||||
inp = match.kwargs["input"].meta["val"]
|
||||
mat1 = match.kwargs["mat1"].meta["val"]
|
||||
mat2 = match.kwargs["mat2"].meta["val"]
|
||||
|
||||
# match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
|
||||
if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
|
||||
return False
|
||||
|
||||
if not (mat1.dim() == 2 and mat2.dim() == 2):
|
||||
return False
|
||||
|
||||
if inp.size(0) != mat2.size(1):
|
||||
return False
|
||||
|
||||
if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
|
||||
return False
|
||||
|
||||
output = match.output_node()
|
||||
# do not fuse if there are pointwise ops after
|
||||
return not all(is_pointwise_use(use) for use in output.users)
|
||||
|
||||
|
||||
def addmm_gelu_pattern(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(output, input)
|
||||
return aten.gelu(output, approximate="tanh")
|
||||
|
||||
|
||||
def addmm_gelu_pattern_2(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(input, output)
|
||||
return aten.gelu(output, approximate="tanh")
|
||||
|
||||
|
||||
def addmm_gelu_replacement(input, mat1, mat2):
|
||||
return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
|
||||
|
||||
|
||||
def addmm_relu_pattern(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(input, output)
|
||||
return aten.relu(output)
|
||||
|
||||
|
||||
def addmm_relu_pattern_2(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(output, input)
|
||||
return aten.relu(output)
|
||||
|
||||
|
||||
def addmm_relu_replacement(input, mat1, mat2):
|
||||
return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
|
||||
|
||||
|
||||
def reorder_for_locality(graph: torch.fx.Graph):
|
||||
if torch.distributed.is_available():
|
||||
@ -1461,7 +1553,7 @@ def should_prefer_unfused_addmm(match):
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
|
||||
pass_dict=pass_patterns[2],
|
||||
pass_dict=pass_patterns[1],
|
||||
extra_check=should_prefer_unfused_addmm,
|
||||
)
|
||||
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
|
||||
|
@ -0,0 +1,59 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
|
||||
|
||||
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
|
||||
addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
|
@ -0,0 +1,59 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
|
||||
|
||||
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
|
||||
addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
|
@ -0,0 +1,36 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
|
||||
addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
|
@ -0,0 +1,36 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
|
||||
addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
|
@ -2,7 +2,7 @@
|
||||
import os
|
||||
|
||||
from torch._inductor import pattern_matcher
|
||||
from torch._inductor.fx_passes import joint_graph
|
||||
from torch._inductor.fx_passes import joint_graph, post_grad
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -17,3 +17,4 @@ if __name__ == "__main__":
|
||||
# to serialize the patterns as it goes.
|
||||
os.environ["PYTORCH_GEN_PATTERNS"] = "1"
|
||||
joint_graph.lazy_init()
|
||||
post_grad.lazy_init()
|
||||
|
Reference in New Issue
Block a user