mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[Inductor] addmm + activation function fusion (#158137)"
This reverts commit b9d7de3a094598c3dc0dd52e57bce30eb684c9d8.
Reverted https://github.com/pytorch/pytorch/pull/158137 on behalf of https://github.com/malfet due to Broke inductor torchbench, see 663da17b62/1
([comment](https://github.com/pytorch/pytorch/pull/158137#issuecomment-3191841298))
This commit is contained in:
@ -635,64 +635,6 @@ class TestPatternMatcher(TestCase):
|
||||
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
@skipIfRocm
|
||||
def test_addmm_activation(self):
|
||||
def fn_addmm_relu(input, mat1, mat2):
|
||||
return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
|
||||
|
||||
def fn_addmm_gelu(input, mat1, mat2):
|
||||
return torch.nn.functional.gelu(
|
||||
torch.addmm(input, mat1, mat2), approximate="tanh"
|
||||
)
|
||||
|
||||
def fn_add_mm_relu(input, mat1, mat2):
|
||||
return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
|
||||
|
||||
def fn_add_mm_gelu(input, mat1, mat2):
|
||||
return torch.nn.functional.gelu(
|
||||
torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
|
||||
)
|
||||
|
||||
args = [
|
||||
torch.randn(20, device=GPU_TYPE),
|
||||
torch.randn(10, 15, device=GPU_TYPE),
|
||||
torch.randn(15, 20, device=GPU_TYPE),
|
||||
]
|
||||
|
||||
for fn, atol in (
|
||||
(fn_addmm_relu, 1e-8),
|
||||
(fn_add_mm_relu, 1e-8),
|
||||
(fn_addmm_gelu, 1e-3),
|
||||
(fn_add_mm_gelu, 1e-3),
|
||||
):
|
||||
expected = fn(*args)
|
||||
actual, (code,) = run_and_get_code(torch.compile(fn), *args)
|
||||
torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
|
||||
self.assertTrue("_addmm_activation" in code)
|
||||
|
||||
for fn in (fn_addmm_relu, fn_addmm_gelu):
|
||||
actual, (code,) = run_and_get_code(
|
||||
torch.compile(fn, options={"max_autotune_gemm": True}), *args
|
||||
)
|
||||
self.assertFalse("_addmm_activation" in code)
|
||||
|
||||
args_not_replaced = [
|
||||
# addmm + activation with a rank-2 input
|
||||
# is not fusable, hence not replaced
|
||||
torch.randn(10, 20, device=GPU_TYPE), # input
|
||||
torch.randn(10, 15, device=GPU_TYPE), # mat1
|
||||
torch.randn(15, 20, device=GPU_TYPE), # mat2
|
||||
]
|
||||
|
||||
for fn in (fn_addmm_relu, fn_addmm_gelu):
|
||||
actual, (code,) = run_and_get_code(
|
||||
torch.compile(
|
||||
fn,
|
||||
),
|
||||
*args_not_replaced,
|
||||
)
|
||||
self.assertFalse("_addmm_activation" in code)
|
||||
|
||||
@inductor_config.patch(
|
||||
{
|
||||
"max_autotune_gemm_backends": "ATEN",
|
||||
|
@ -33,7 +33,6 @@ from ..pattern_matcher import (
|
||||
CallFunctionVarArgs,
|
||||
filter_nodes,
|
||||
fwd_only,
|
||||
gen_register_replacement,
|
||||
get_arg_value,
|
||||
get_mutation_region_id,
|
||||
Ignored,
|
||||
@ -661,97 +660,6 @@ def lazy_init():
|
||||
extra_check=prepare_softmax_extra_check,
|
||||
)
|
||||
|
||||
register_addmm_activation_fusion()
|
||||
|
||||
|
||||
@functools.cache
|
||||
def register_addmm_activation_fusion():
|
||||
shapes = [(5,), (3, 4), (4, 5)]
|
||||
args_fp32 = [torch.empty(shape) for shape in shapes]
|
||||
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
|
||||
|
||||
for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
|
||||
name = f"{pattern.__name__}_fp32"
|
||||
gen_register_replacement(
|
||||
name,
|
||||
pattern,
|
||||
addmm_relu_replacement,
|
||||
args_fp32,
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[2],
|
||||
extra_check=is_valid_addmm_activation_fusion,
|
||||
)
|
||||
|
||||
for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
|
||||
for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
|
||||
name = f"{pattern.__name__}_{dtype_suffix}"
|
||||
gen_register_replacement(
|
||||
name,
|
||||
pattern,
|
||||
addmm_gelu_replacement,
|
||||
args,
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[2],
|
||||
extra_check=is_valid_addmm_activation_fusion,
|
||||
)
|
||||
|
||||
|
||||
def is_valid_addmm_activation_fusion(match):
|
||||
if config.max_autotune_gemm:
|
||||
return False
|
||||
inp = match.kwargs["input"].meta["val"]
|
||||
mat1 = match.kwargs["mat1"].meta["val"]
|
||||
mat2 = match.kwargs["mat2"].meta["val"]
|
||||
|
||||
# match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
|
||||
if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
|
||||
return False
|
||||
|
||||
if not (mat1.dim() == 2 and mat2.dim() == 2):
|
||||
return False
|
||||
|
||||
if inp.size(0) != mat2.size(1):
|
||||
return False
|
||||
|
||||
if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
|
||||
return False
|
||||
|
||||
output = match.output_node()
|
||||
# do not fuse if there are pointwise ops after
|
||||
return not all(is_pointwise_use(use) for use in output.users)
|
||||
|
||||
|
||||
def addmm_gelu_pattern(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(output, input)
|
||||
return aten.gelu(output, approximate="tanh")
|
||||
|
||||
|
||||
def addmm_gelu_pattern_2(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(input, output)
|
||||
return aten.gelu(output, approximate="tanh")
|
||||
|
||||
|
||||
def addmm_gelu_replacement(input, mat1, mat2):
|
||||
return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
|
||||
|
||||
|
||||
def addmm_relu_pattern(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(input, output)
|
||||
return aten.relu(output)
|
||||
|
||||
|
||||
def addmm_relu_pattern_2(input, mat1, mat2):
|
||||
output = aten.mm(mat1, mat2)
|
||||
output = aten.add(output, input)
|
||||
return aten.relu(output)
|
||||
|
||||
|
||||
def addmm_relu_replacement(input, mat1, mat2):
|
||||
return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
|
||||
|
||||
|
||||
def reorder_for_locality(graph: torch.fx.Graph):
|
||||
if torch.distributed.is_available():
|
||||
@ -1553,7 +1461,7 @@ def should_prefer_unfused_addmm(match):
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
|
||||
pass_dict=pass_patterns[1],
|
||||
pass_dict=pass_patterns[2],
|
||||
extra_check=should_prefer_unfused_addmm,
|
||||
)
|
||||
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
|
||||
|
@ -1,59 +0,0 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
|
||||
|
||||
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
|
||||
addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
|
@ -1,59 +0,0 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
|
||||
|
||||
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
|
||||
add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
|
||||
tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
|
||||
add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
|
||||
addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
|
@ -1,36 +0,0 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
|
||||
addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
|
@ -1,36 +0,0 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
# noqa: F401, E501
|
||||
# This is an auto-generated file. Please do not modify it by hand.
|
||||
# To re-generate, run:
|
||||
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
import operator
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
from torch._inductor.pattern_matcher import (
|
||||
Arg,
|
||||
CallFunction,
|
||||
CallFunctionVarArgs,
|
||||
CallMethod,
|
||||
CallMethodVarArgs,
|
||||
CallModule,
|
||||
CallModuleVarArgs,
|
||||
ExclusiveKeywordArg,
|
||||
Ignored,
|
||||
KeywordArg,
|
||||
ListOf,
|
||||
MultiOutputPattern,
|
||||
PatternExpr,
|
||||
RepeatedExpr,
|
||||
_TargetArgsExpr,
|
||||
_TargetExpr,
|
||||
_TargetExprVarArgs,
|
||||
)
|
||||
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
|
||||
add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
|
||||
addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
|
@ -2,7 +2,7 @@
|
||||
import os
|
||||
|
||||
from torch._inductor import pattern_matcher
|
||||
from torch._inductor.fx_passes import joint_graph, post_grad
|
||||
from torch._inductor.fx_passes import joint_graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -17,4 +17,3 @@ if __name__ == "__main__":
|
||||
# to serialize the patterns as it goes.
|
||||
os.environ["PYTORCH_GEN_PATTERNS"] = "1"
|
||||
joint_graph.lazy_init()
|
||||
post_grad.lazy_init()
|
||||
|
Reference in New Issue
Block a user