[Inductor][FX passes] Remove config.split_cat_fx_passes & Add config.experimental_patterns (#104208)

Summary:
TLDR:
* Remove config.split_cat_fx_passes, and move split cat passes behind config.pattern_matcher (True by default)
* Add config.experimental_patterns (False by default).
* In the future, general/universal patterns should behind config.pattern_matcher; customized/unmatured patterns should behind config.experimental_patterns.

More details at:
https://docs.google.com/document/d/1P8uJTpOTdQpUbw56UxHol40tt-EPFTq1Qu38072E9aM/edit

Test Plan: Existing unit tests

Reviewed By: jansel, jackiexu1992

Differential Revision: D46752606

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104208
Approved by: https://github.com/williamwen42
This commit is contained in:
Yanbo Liang
2023-06-27 20:08:40 +00:00
committed by PyTorch MergeBot
parent 2da6cae43c
commit 7bf27cf163
3 changed files with 3 additions and 28 deletions

View File

@ -7,13 +7,7 @@ from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
def patch(f):
f = torch._inductor.config.patch(split_cat_fx_passes=True)(f)
return f
class TestSplitCatFxPasses(TestCase):
@patch
def test_split_normalization(self):
def arg_only(x):
return [torch.relu(s) for s in torch.split(x, 2, 1)]
@ -91,7 +85,6 @@ class TestSplitCatFxPasses(TestCase):
)
counters.clear()
@patch
def test_consecutive_split_merge(self):
def multi_split(x):
return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)]
@ -252,7 +245,6 @@ class TestSplitCatFxPasses(TestCase):
)
counters.clear()
@patch
def test_split_cat_merge(self):
def simple_split_cat(x):
return torch.cat(torch.split(x, 4, dim=1), dim=1)
@ -582,7 +574,7 @@ class TestSplitCatFxPasses(TestCase):
)
counters.clear()
@torch._inductor.config.patch(split_cat_fx_passes=False)
@torch._inductor.config.patch(pattern_matcher=False)
def test_config_flag_is_respected(self):
def split_with_cat(x):
fs = torch.split(x, [4, 4, 24], dim=-1)
@ -612,7 +604,6 @@ class TestSplitCatFxPasses(TestCase):
0,
)
@patch
def test_split_cat_merge_mutation(self):
args = [
torch.randn(2, 32, 32, 16),
@ -631,7 +622,6 @@ class TestSplitCatFxPasses(TestCase):
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)
@patch
def test_split_squeeze(self):
def split_squeeze_stack(x):
items = list(torch.split(x, 1, dim=1))
@ -721,7 +711,6 @@ class TestSplitCatFxPasses(TestCase):
)
counters.clear()
@patch
def test_unbind_stack(self):
def unbind_stack(x):
return torch.stack(torch.unbind(x, dim=1), 1)

View File

@ -45,8 +45,8 @@ epilogue_fusion_first = False
# enable pattern match+replace optimizations
pattern_matcher = True
# Optimize away split cat patterns (Experimental)
split_cat_fx_passes = True
# enable experimental patterns for match+replace optimizations
experimental_patterns = False
# enable reordering pass
reordering = True

View File

@ -12,7 +12,6 @@ from ..pattern_matcher import (
CallFunction,
CallFunctionVarArgs,
CallMethodVarArgs,
config_flag,
FailedMatch,
get_arg_value,
Ignored,
@ -86,12 +85,10 @@ def normalize_split_base(match: Match, _get_split_args: Callable):
@register_graph_pattern(
CallFunctionVarArgs(torch.split, users=MULTIPLE),
pass_dict=normalization_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
CallMethodVarArgs("split", users=MULTIPLE),
pass_dict=normalization_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_split_default(match: Match, *args, **kwargs):
return normalize_split_base(match, _get_split_args_default)
@ -100,7 +97,6 @@ def normalize_split_default(match: Match, *args, **kwargs):
@register_graph_pattern(
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
pass_dict=normalization_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
cat_node = match.nodes[0]
@ -151,7 +147,6 @@ def find_next_users(split_node):
@register_graph_pattern(
CallMethodVarArgs("squeeze", users=MULTIPLE),
pass_dict=normalization_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_squeeze_default(match: Match, *args, **kwargs):
squeeze_node = match.nodes[0]
@ -219,7 +214,6 @@ class TorchSplit(CallFunction):
KeywordArg("next_split_sections"),
),
pass_dict=merge_splits_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def merge_splits(
match: Match,
@ -814,7 +808,6 @@ class GetItem(CallFunction):
),
),
pass_dict=split_cat_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
RepeatedExpr(
@ -832,7 +825,6 @@ class GetItem(CallFunction):
)
),
pass_dict=split_cat_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def merge_split_squeeze(
match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
@ -884,21 +876,18 @@ getitem_unbind = ListOf(
@register_graph_pattern(
CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
pass_dict=unbind_stack_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
CallFunction(
[torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
),
pass_dict=unbind_stack_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
CallFunction(
[torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
),
pass_dict=unbind_stack_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
@ -927,7 +916,6 @@ getitem_split = ListOf(
_users=MULTIPLE,
),
pass_dict=split_cat_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
CallFunction(
@ -937,7 +925,6 @@ getitem_split = ListOf(
_users=MULTIPLE,
),
pass_dict=split_cat_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
CallFunction(
@ -947,7 +934,6 @@ getitem_split = ListOf(
_users=MULTIPLE,
),
pass_dict=split_cat_pass,
extra_check=config_flag("split_cat_fx_passes"),
)
def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split