Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor][FX passes] Remove config.split_cat_fx_passes & Add config.experimental_patterns (#104208)
Summary:
TLDR:
* Remove config.split_cat_fx_passes and move the split/cat passes behind config.pattern_matcher (True by default).
* Add config.experimental_patterns (False by default).
* Going forward, general/universal patterns should be behind config.pattern_matcher; customized/immature patterns should be behind config.experimental_patterns.
More details at: https://docs.google.com/document/d/1P8uJTpOTdQpUbw56UxHol40tt-EPFTq1Qu38072E9aM/edit

Test Plan: Existing unit tests

Reviewed By: jansel, jackiexu1992

Differential Revision: D46752606

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104208
Approved by: https://github.com/williamwen42
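For orientation, a minimal sketch of how the two surviving knobs are toggled after this change (the function and tensor shape below are placeholders, not taken from the PR):

import torch
import torch._inductor.config as inductor_config

# After this PR: pattern_matcher also gates the split/cat passes (True by default),
# while experimental_patterns gates customized/immature patterns (False by default).
print(inductor_config.pattern_matcher)        # True
print(inductor_config.experimental_patterns)  # False

def f(x):
    # A split immediately followed by a cat; a candidate for the split/cat passes.
    return torch.cat(torch.split(x, 4, dim=1), dim=1)

# Opt out of all pattern match+replace passes (including split/cat) for one compile.
with inductor_config.patch(pattern_matcher=False):
    torch.compile(f)(torch.randn(2, 32))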
Committed by: PyTorch MergeBot · Parent: 2da6cae43c · Commit: 7bf27cf163
@@ -7,13 +7,7 @@ from torch.testing._internal.common_utils import IS_LINUX
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
 
-def patch(f):
-    f = torch._inductor.config.patch(split_cat_fx_passes=True)(f)
-    return f
-
-
 class TestSplitCatFxPasses(TestCase):
-    @patch
     def test_split_normalization(self):
         def arg_only(x):
             return [torch.relu(s) for s in torch.split(x, 2, 1)]
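The deleted helper just wrapped torch._inductor.config.patch, and the per-test @patch decorators go away because the split/cat passes now run whenever the pattern matcher is enabled (the default). A test that still wants to be explicit could patch the surviving flag directly; a sketch, with a hypothetical function name:

import torch
import torch._inductor.config

# pattern_matcher=True is already the default; patching it just makes the intent explicit.
@torch._inductor.config.patch(pattern_matcher=True)
def run_split_normalization_example():
    def arg_only(x):
        return [torch.relu(s) for s in torch.split(x, 2, 1)]
    # Compile and run once so Inductor's FX passes get a chance to fire.
    return torch.compile(arg_only)(torch.randn(4, 8))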
@@ -91,7 +85,6 @@ class TestSplitCatFxPasses(TestCase):
         )
         counters.clear()
 
-    @patch
     def test_consecutive_split_merge(self):
         def multi_split(x):
             return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)]
@@ -252,7 +245,6 @@ class TestSplitCatFxPasses(TestCase):
         )
         counters.clear()
 
-    @patch
     def test_split_cat_merge(self):
         def simple_split_cat(x):
             return torch.cat(torch.split(x, 4, dim=1), dim=1)
@@ -582,7 +574,7 @@ class TestSplitCatFxPasses(TestCase):
         )
         counters.clear()
 
-    @torch._inductor.config.patch(split_cat_fx_passes=False)
+    @torch._inductor.config.patch(pattern_matcher=False)
     def test_config_flag_is_respected(self):
         def split_with_cat(x):
             fs = torch.split(x, [4, 4, 24], dim=-1)
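test_config_flag_is_respected now keys off pattern_matcher instead of the removed flag. A rough sketch of the shape of that check, assuming counters comes from torch._dynamo.utils and reusing the scmerge_* counter names that appear later in this diff (the body of split_with_cat beyond the split call is guessed):

import torch
import torch._inductor.config
from torch._dynamo.utils import counters

@torch._inductor.config.patch(pattern_matcher=False)
def check_flag_is_respected():
    def split_with_cat(x):
        fs = torch.split(x, [4, 4, 24], dim=-1)
        return torch.cat(fs, dim=-1)

    counters.clear()
    torch.compile(split_with_cat)(torch.randn(2, 32))
    # With the pattern matcher off, no split/cat rewrites should have been counted.
    assert counters["inductor"]["scmerge_split_removed"] == 0
    assert counters["inductor"]["scmerge_cat_removed"] == 0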
@@ -612,7 +604,6 @@ class TestSplitCatFxPasses(TestCase):
             0,
         )
 
-    @patch
     def test_split_cat_merge_mutation(self):
         args = [
             torch.randn(2, 32, 32, 16),
@@ -631,7 +622,6 @@ class TestSplitCatFxPasses(TestCase):
         self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
         self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)
 
-    @patch
     def test_split_squeeze(self):
         def split_squeeze_stack(x):
             items = list(torch.split(x, 1, dim=1))
@@ -721,7 +711,6 @@ class TestSplitCatFxPasses(TestCase):
         )
         counters.clear()
 
-    @patch
     def test_unbind_stack(self):
         def unbind_stack(x):
             return torch.stack(torch.unbind(x, dim=1), 1)
@@ -45,8 +45,8 @@ epilogue_fusion_first = False
 # enable pattern match+replace optimizations
 pattern_matcher = True
 
-# Optimize away split cat patterns (Experimental)
-split_cat_fx_passes = True
+# enable experimental patterns for match+replace optimizations
+experimental_patterns = False
 
 # enable reordering pass
 reordering = True
@@ -12,7 +12,6 @@ from ..pattern_matcher import (
     CallFunction,
     CallFunctionVarArgs,
     CallMethodVarArgs,
-    config_flag,
     FailedMatch,
     get_arg_value,
     Ignored,
@@ -86,12 +85,10 @@ def normalize_split_base(match: Match, _get_split_args: Callable):
 @register_graph_pattern(
     CallFunctionVarArgs(torch.split, users=MULTIPLE),
     pass_dict=normalization_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     CallMethodVarArgs("split", users=MULTIPLE),
     pass_dict=normalization_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def normalize_split_default(match: Match, *args, **kwargs):
     return normalize_split_base(match, _get_split_args_default)
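Dropping extra_check=config_flag("split_cat_fx_passes") means these registrations are now gated only by whether the enclosing pass runs at all (i.e. by pattern_matcher). A pattern that should stay opt-in can use the same extra_check hook against the new flag; a minimal sketch using the pattern_matcher helpers this file imports, with a hypothetical pass dict and handler:

import torch
from torch._inductor.pattern_matcher import (
    CallFunctionVarArgs,
    Match,
    MULTIPLE,
    PatternMatcherPass,
    config_flag,
    register_graph_pattern,
)

# Hypothetical pass dict for illustration; the real passes live in this file.
my_experimental_pass = PatternMatcherPass()

@register_graph_pattern(
    CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
    pass_dict=my_experimental_pass,
    # Only match when the new, off-by-default flag is turned on.
    extra_check=config_flag("experimental_patterns"),
)
def normalize_unbind_sketch(match: Match, *args, **kwargs):
    # A real handler would rewrite match.nodes here; this sketch is a no-op.
    pass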
@@ -100,7 +97,6 @@ def normalize_split_default(match: Match, *args, **kwargs):
 @register_graph_pattern(
     CallFunctionVarArgs(torch.cat, users=MULTIPLE),
     pass_dict=normalization_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def normalize_cat_default(match: Match, *args, **kwargs):
     cat_node = match.nodes[0]
@@ -151,7 +147,6 @@ def find_next_users(split_node):
 @register_graph_pattern(
     CallMethodVarArgs("squeeze", users=MULTIPLE),
     pass_dict=normalization_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def normalize_squeeze_default(match: Match, *args, **kwargs):
     squeeze_node = match.nodes[0]
@@ -219,7 +214,6 @@ class TorchSplit(CallFunction):
         KeywordArg("next_split_sections"),
     ),
     pass_dict=merge_splits_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def merge_splits(
     match: Match,
@@ -814,7 +808,6 @@ class GetItem(CallFunction):
         ),
     ),
     pass_dict=split_cat_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     RepeatedExpr(
@@ -832,7 +825,6 @@ class GetItem(CallFunction):
         )
     ),
     pass_dict=split_cat_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def merge_split_squeeze(
     match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
@@ -884,21 +876,18 @@ getitem_unbind = ListOf(
 @register_graph_pattern(
     CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
     pass_dict=unbind_stack_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     CallFunction(
         [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
     ),
     pass_dict=unbind_stack_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     CallFunction(
         [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
     ),
     pass_dict=unbind_stack_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
     unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
@@ -927,7 +916,6 @@ getitem_split = ListOf(
         _users=MULTIPLE,
     ),
     pass_dict=split_cat_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -937,7 +925,6 @@ getitem_split = ListOf(
         _users=MULTIPLE,
     ),
     pass_dict=split_cat_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -947,7 +934,6 @@ getitem_split = ListOf(
         _users=MULTIPLE,
     ),
     pass_dict=split_cat_pass,
-    extra_check=config_flag("split_cat_fx_passes"),
 )
 def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split