Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[Inductor] addmm + activation function fusion (#158137)"
This reverts commit b9d7de3a094598c3dc0dd52e57bce30eb684c9d8.
Reverted https://github.com/pytorch/pytorch/pull/158137 on behalf of https://github.com/malfet due to Broke inductor torchbench, see 663da17b62/1 ([comment](https://github.com/pytorch/pytorch/pull/158137#issuecomment-3191841298))
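For context: the reverted pass rewrote addmm followed by ReLU or tanh-approximated GELU into the single fused kernel aten._addmm_activation. A minimal sketch of the equivalence the rewrite relied on (illustrative only, not part of this diff; assumes a CUDA device):

import torch

# Illustrative check (not from the tree): the eager composition the pattern
# matched, next to the fused ATen kernel it rewrote to.
inp = torch.randn(20, device="cuda")       # 1-D bias, as the fusion required
mat1 = torch.randn(10, 15, device="cuda")
mat2 = torch.randn(15, 20, device="cuda")

eager = torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))
fused = torch.ops.aten._addmm_activation(
    inp, mat1, mat2, beta=1, alpha=1, use_gelu=False
)
torch.testing.assert_close(fused, eager)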
@@ -635,64 +635,6 @@ class TestPatternMatcher(TestCase):
         self.assertEqual(res1, res2)
 
-    @skipIfRocm
-    def test_addmm_activation(self):
-        def fn_addmm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
-
-        def fn_addmm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.addmm(input, mat1, mat2), approximate="tanh"
-            )
-
-        def fn_add_mm_relu(input, mat1, mat2):
-            return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
-
-        def fn_add_mm_gelu(input, mat1, mat2):
-            return torch.nn.functional.gelu(
-                torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
-            )
-
-        args = [
-            torch.randn(20, device=GPU_TYPE),
-            torch.randn(10, 15, device=GPU_TYPE),
-            torch.randn(15, 20, device=GPU_TYPE),
-        ]
-
-        for fn, atol in (
-            (fn_addmm_relu, 1e-8),
-            (fn_add_mm_relu, 1e-8),
-            (fn_addmm_gelu, 1e-3),
-            (fn_add_mm_gelu, 1e-3),
-        ):
-            expected = fn(*args)
-            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
-            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
-            self.assertTrue("_addmm_activation" in code)
-
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(fn, options={"max_autotune_gemm": True}), *args
-            )
-            self.assertFalse("_addmm_activation" in code)
-
-        args_not_replaced = [
-            # addmm + activation with a rank-2 input
-            # is not fusable, hence not replaced
-            torch.randn(10, 20, device=GPU_TYPE),  # input
-            torch.randn(10, 15, device=GPU_TYPE),  # mat1
-            torch.randn(15, 20, device=GPU_TYPE),  # mat2
-        ]
-
-        for fn in (fn_addmm_relu, fn_addmm_gelu):
-            actual, (code,) = run_and_get_code(
-                torch.compile(
-                    fn,
-                ),
-                *args_not_replaced,
-            )
-            self.assertFalse("_addmm_activation" in code)
-
     @inductor_config.patch(
         {
             "max_autotune_gemm_backends": "ATEN",
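The removed test asserted that the fusion actually fired by compiling each helper and grepping Inductor's generated code, via run_and_get_code from torch._inductor.utils. A condensed sketch of that idiom (f is a hypothetical stand-in for the removed helpers; assumes a CUDA device):

import torch
from torch._inductor.utils import run_and_get_code

def f(inp, mat1, mat2):
    return torch.nn.functional.relu(torch.addmm(inp, mat1, mat2))

args = [
    torch.randn(20, device="cuda"),
    torch.randn(10, 15, device="cuda"),
    torch.randn(15, 20, device="cuda"),
]

# run_and_get_code returns the compiled result plus the generated source,
# which the test then searched for the fused op's name.
out, (code,) = run_and_get_code(torch.compile(f), *args)
print("_addmm_activation" in code)  # True while the fusion pass was in place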
@@ -33,7 +33,6 @@ from ..pattern_matcher import (
     CallFunctionVarArgs,
     filter_nodes,
     fwd_only,
-    gen_register_replacement,
     get_arg_value,
     get_mutation_region_id,
     Ignored,
@@ -661,97 +660,6 @@ def lazy_init():
         extra_check=prepare_softmax_extra_check,
     )
 
-    register_addmm_activation_fusion()
-
-
-@functools.cache
-def register_addmm_activation_fusion():
-    shapes = [(5,), (3, 4), (4, 5)]
-    args_fp32 = [torch.empty(shape) for shape in shapes]
-    args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
-
-    for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
-        name = f"{pattern.__name__}_fp32"
-        gen_register_replacement(
-            name,
-            pattern,
-            addmm_relu_replacement,
-            args_fp32,
-            trace_fn=fwd_only,
-            pass_dicts=pass_patterns[2],
-            extra_check=is_valid_addmm_activation_fusion,
-        )
-
-    for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
-        for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
-            name = f"{pattern.__name__}_{dtype_suffix}"
-            gen_register_replacement(
-                name,
-                pattern,
-                addmm_gelu_replacement,
-                args,
-                trace_fn=fwd_only,
-                pass_dicts=pass_patterns[2],
-                extra_check=is_valid_addmm_activation_fusion,
-            )
-
-
-def is_valid_addmm_activation_fusion(match):
-    if config.max_autotune_gemm:
-        return False
-    inp = match.kwargs["input"].meta["val"]
-    mat1 = match.kwargs["mat1"].meta["val"]
-    mat2 = match.kwargs["mat2"].meta["val"]
-
-    # match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
-    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
-        return False
-
-    if not (mat1.dim() == 2 and mat2.dim() == 2):
-        return False
-
-    if inp.size(0) != mat2.size(1):
-        return False
-
-    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
-        return False
-
-    output = match.output_node()
-    # do not fuse if there are pointwise ops after
-    return not all(is_pointwise_use(use) for use in output.users)
-
-
-def addmm_gelu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.gelu(output, approximate="tanh")
-
-
-def addmm_gelu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
-
-
-def addmm_relu_pattern(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(input, output)
-    return aten.relu(output)
-
-
-def addmm_relu_pattern_2(input, mat1, mat2):
-    output = aten.mm(mat1, mat2)
-    output = aten.add(output, input)
-    return aten.relu(output)
-
-
-def addmm_relu_replacement(input, mat1, mat2):
-    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
-
-
 def reorder_for_locality(graph: torch.fx.Graph):
     if torch.distributed.is_available():
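The is_valid_addmm_activation_fusion check above mirrors the cuBLASLt dispatch conditions it cites: the fused epilogue applies only to a contiguous 1-D CUDA bias broadcast over the rows of mat1 @ mat2, with all three operands in one dtype. A standalone paraphrase on plain tensors (an illustration, not code from the tree):

import torch

def would_use_fused_epilogue(inp, mat1, mat2):
    # Restates the removed extra_check's shape/dtype gating; the graph-level
    # pointwise-user check has no tensor-level analogue and is omitted here.
    return (
        inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()
        and mat1.dim() == 2 and mat2.dim() == 2
        and inp.size(0) == mat2.size(1)
        and inp.dtype == mat1.dtype == mat2.dtype
    )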
@@ -1553,7 +1461,7 @@ def should_prefer_unfused_addmm(match):
 
 @register_graph_pattern(
     CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
-    pass_dict=pass_patterns[1],
+    pass_dict=pass_patterns[2],
     extra_check=should_prefer_unfused_addmm,
 )
 def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
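Note the hunk above also moves unfuse_bias_add_to_pointwise back from pass_patterns[1] to pass_patterns[2]. That pass rewrites in the opposite direction of the reverted fusion: it splits addmm into mm plus a pointwise add so the bias add can be absorbed into a neighboring pointwise kernel. The arithmetic of that rewrite, as a sketch (the real pass operates on FX graphs, not tensors):

import torch

aten = torch.ops.aten

def unfused(inp, mat1, mat2):
    # A plain matmul followed by a broadcasted bias add; numerically the
    # same as torch.addmm(inp, mat1, mat2).
    return aten.add.Tensor(aten.mm.default(mat1, mat2), inp)

inp, mat1, mat2 = torch.randn(20), torch.randn(10, 15), torch.randn(15, 20)
torch.testing.assert_close(unfused(inp, mat1, mat2), torch.addmm(inp, mat1, mat2))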
@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
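The node chain in the deleted file is the traced form of the tanh approximation of GELU: the fp32 variant matches it directly, while the bf16 variant matches the same chain around a convert_element_type upcast. Mapping the nodes back to the closed form makes the pattern legible (standard formula, shown here as a sketch):

import math

import torch

def gelu_tanh_decomposed(x):
    # mul_Tensor_1/mul_Tensor_2 build x**3, mul_Tensor_3 scales it by
    # 0.044715, add_Tensor_1/mul_Tensor_4 form sqrt(2/pi) * (x + 0.044715*x**3),
    # and the tanh/add/mul tail assembles 0.5 * x * (1 + tanh(...)).
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))

x = torch.randn(8)
torch.testing.assert_close(
    gelu_tanh_decomposed(x),
    torch.nn.functional.gelu(x, approximate="tanh"),
)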
@@ -1,59 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
-
-
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
-mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
-mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
-mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
-mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
-add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
-mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
-tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
-add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
-mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
-addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
-addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
@@ -1,36 +0,0 @@
-# mypy: ignore-errors
-
-# noqa: F401, E501
-# This is an auto-generated file. Please do not modify it by hand.
-# To re-generate, run:
-# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
-
-import torch
-import torch._inductor
-import operator
-
-aten = torch.ops.aten
-prims = torch.ops.prims
-
-from torch._inductor.pattern_matcher import (
-   Arg,
-   CallFunction,
-   CallFunctionVarArgs,
-   CallMethod,
-   CallMethodVarArgs,
-   CallModule,
-   CallModuleVarArgs,
-   ExclusiveKeywordArg,
-   Ignored,
-   KeywordArg,
-   ListOf,
-   MultiOutputPattern,
-   PatternExpr,
-   RepeatedExpr,
-   _TargetArgsExpr,
-   _TargetExpr,
-   _TargetExprVarArgs,
-)
-mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
-add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
-addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
@@ -2,7 +2,7 @@
 import os
 
 from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph, post_grad
+from torch._inductor.fx_passes import joint_graph
 
 
 if __name__ == "__main__":
@@ -17,4 +17,3 @@ if __name__ == "__main__":
     # to serialize the patterns as it goes.
     os.environ["PYTORCH_GEN_PATTERNS"] = "1"
     joint_graph.lazy_init()
-    post_grad.lazy_init()
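Per the header comment in the deleted files, the serialized patterns are regenerated by running torchgen/fuse/gen_patterns.py: with PYTORCH_GEN_PATTERNS=1 set, gen_register_replacement writes each traced pattern out as a side effect of lazy_init(). A condensed sketch of the flow this hunk trims post_grad out of (assumes a source checkout):

import os

# Opt in to pattern serialization; gen_patterns.py sets the same variable.
os.environ["PYTORCH_GEN_PATTERNS"] = "1"

from torch._inductor.fx_passes import joint_graph

# Tracing the registered patterns regenerates the auto-generated files (the
# serialized pattern modules deleted above) as a side effect.
joint_graph.lazy_init()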