diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index e6aa08601a0e..0ffe7cb37deb 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -635,64 +635,6 @@ class TestPatternMatcher(TestCase):
         self.assertEqual(res1, res2)
 
-    @skipIfRocm
-    def test_addmm_activation(self):
-        def fn_addmm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
-
-        def fn_addmm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.addmm(input, mat1, mat2), approximate="tanh"
-            )
-
-        def fn_add_mm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
-
-        def fn_add_mm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
-            )
-
-        args = [
-            torch.randn(20, device=GPU_TYPE),
-            torch.randn(10, 15, device=GPU_TYPE),
-            torch.randn(15, 20, device=GPU_TYPE),
-        ]
-
-        for fn, atol in (
-            (fn_addmm_relu, 1e-8),
-            (fn_add_mm_relu, 1e-8),
-            (fn_addmm_gelu, 1e-3),
-            (fn_add_mm_gelu, 1e-3),
-        ):
-            expected = fn(*args)
-            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
-            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
-            self.assertTrue("_addmm_activation" in code)
-
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(fn, options={"max_autotune_gemm": True}), *args
-            )
-            self.assertFalse("_addmm_activation" in code)
-
-        args_not_replaced = [
-            # addmm + activation with a rank-2 input
-            # is not fusable, hence not replaced
-            torch.randn(10, 20, device=GPU_TYPE),  # input
-            torch.randn(10, 15, device=GPU_TYPE),  # mat1
-            torch.randn(15, 20, device=GPU_TYPE),  # mat2
-        ]
-
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(
-                    fn,
-                ),
-                *args_not_replaced,
-            )
-            self.assertFalse("_addmm_activation" in code)
-
     @inductor_config.patch(
         {
             "max_autotune_gemm_backends": "ATEN",
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index d72079b83a09..db273b06c8e6 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -33,7 +33,6 @@ from ..pattern_matcher import (
     CallFunctionVarArgs,
     filter_nodes,
     fwd_only,
-    gen_register_replacement,
     get_arg_value,
     get_mutation_region_id,
     Ignored,
@@ -661,97 +660,6 @@ def lazy_init():
         extra_check=prepare_softmax_extra_check,
     )
 
-    register_addmm_activation_fusion()
-
-
-@functools.cache
-def register_addmm_activation_fusion():
-    shapes = [(5,), (3, 4), (4, 5)]
-    args_fp32 = [torch.empty(shape) for shape in shapes]
-    args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
-
-    for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
-        name = f"{pattern.__name__}_fp32"
-        gen_register_replacement(
-            name,
-            pattern,
-            addmm_relu_replacement,
-            args_fp32,
-            trace_fn=fwd_only,
-            pass_dicts=pass_patterns[2],
-            extra_check=is_valid_addmm_activation_fusion,
-        )
-
-    for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
-        for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
-            name = f"{pattern.__name__}_{dtype_suffix}"
-            gen_register_replacement(
-                name,
-                pattern,
-                addmm_gelu_replacement,
-                args,
-                trace_fn=fwd_only,
-                pass_dicts=pass_patterns[2],
-                extra_check=is_valid_addmm_activation_fusion,
-            )
-
-
-def is_valid_addmm_activation_fusion(match):
-    if config.max_autotune_gemm:
-        return False
-    inp = match.kwargs["input"].meta["val"]
-    mat1 = match.kwargs["mat1"].meta["val"]
-    mat2 = match.kwargs["mat2"].meta["val"]
-
-    # match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
-    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
-        return False
-
-    if not (mat1.dim() == 2 and mat2.dim() == 2):
-        return False
-
-    if inp.size(0) != mat2.size(1):
-        return False
-
-    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
-        return False
-
-    output = match.output_node()
-    # do not fuse if there are pointwise ops after
-    return not all(is_pointwise_use(use) for use in output.users)
-
-
-def addmm_gelu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
-
-
-def addmm_relu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.relu(output)
-
-
-def addmm_relu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.relu(output)
-
-
-def addmm_relu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
-
 
 def reorder_for_locality(graph: torch.fx.Graph):
     if torch.distributed.is_available():
@@ -1553,7 +1461,7 @@ def should_prefer_unfused_addmm(match):
 
 @register_graph_pattern(
     CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
-    pass_dict=pass_patterns[1],
+    pass_dict=pass_patterns[2],
     extra_check=should_prefer_unfused_addmm,
 )
 def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py
deleted file mode 100644
index 99f691e6fdd4..000000000000
--- a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py
deleted file mode 100644
index 288177ed37ac..000000000000
--- a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py
deleted file mode 100644
index 9deef11cf329..000000000000
--- a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py
deleted file mode 100644
index 4a3c47310511..000000000000
--- a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py
index b4bdf022202b..0861c882e3ff 100644
--- a/torchgen/fuse/gen_patterns.py
+++ b/torchgen/fuse/gen_patterns.py
@@ -2,7 +2,7 @@
 import os
 
 from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph, post_grad
+from torch._inductor.fx_passes import joint_graph
 
 
 if __name__ == "__main__":
@@ -17,4 +17,3 @@ if __name__ == "__main__":
     # to serialize the patterns as it goes.
     os.environ["PYTORCH_GEN_PATTERNS"] = "1"
     joint_graph.lazy_init()
-    post_grad.lazy_init()
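
Note (not part of the patch): the deleted pass rewrote an unfused mm + add + relu/tanh-gelu chain into a single aten._addmm_activation call whenever the bias was a contiguous 1-D tensor matching the cuBLASLt epilogue dispatch checked by the deleted is_valid_addmm_activation_fusion. Below is a minimal illustrative sketch of that rewrite, using the shapes from the deleted test; it assumes _addmm_activation is available for the device and dtype in use, and the function names are ours, not part of the patch.

    import torch

    aten = torch.ops.aten

    def addmm_relu_unfused(inp, mat1, mat2):
        # the shape matched by the deleted addmm_relu_pattern: mm + add + relu
        return torch.relu(torch.add(inp, torch.mm(mat1, mat2)))

    def addmm_relu_fused(inp, mat1, mat2):
        # the replacement the pattern emitted; use_gelu=True would select the
        # tanh-GELU epilogue instead (see the deleted addmm_gelu_replacement)
        return aten._addmm_activation(inp, mat1, mat2, beta=1, alpha=1, use_gelu=False)

    # rank-1 bias of size mat2.size(1), per the deleted extra check
    # (a rank-2 bias was deliberately left unreplaced)
    inp, mat1, mat2 = torch.randn(20), torch.randn(10, 15), torch.randn(15, 20)
    torch.testing.assert_close(
        addmm_relu_fused(inp, mat1, mat2), addmm_relu_unfused(inp, mat1, mat2)
    )

With the patch applied, Inductor no longer performs this rewrite; the unfused form is kept, and unfuse_bias_add_to_pointwise now runs in pass_patterns[2] instead of pass_patterns[1].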