add torch.concat to normalization pass (#156574)

Summary: In the normalization pass, we also add torch.concat to it to normalize it as torch.cat

Test Plan:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_passes -- test_cat_normalization
```

Buck UI: https://www.internalfb.com/buck2/597fd4f1-0aa7-4372-8a66-5a690d9b63a4
Test UI: https://www.internalfb.com/intern/testinfra/testrun/1688850152284203
Network: Up: 84KiB  Down: 34KiB  (reSessionID-3916e009-7117-41ce-b6f9-089873aa50dd)
Executing actions. Remaining     0/3                                                                                              1.1s exec time total
Command: test.     Finished 2 local
Time elapsed: 3:47.1s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

Rollback Plan:

Differential Revision: D77125331

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156574
Approved by: https://github.com/Mingming-Ding
This commit is contained in:
Menglu Yu
2025-06-27 06:07:26 +00:00
committed by PyTorch MergeBot
parent 1155c53e7d
commit 640703d95f
2 changed files with 29 additions and 1 deletions

View File

@ -115,6 +115,33 @@ class TestSplitCatFxPasses(TestCase):
)
counters.clear()
@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
},
post_grad_fusion_options={},
)
def test_cat_normalization(self):
def caoncat_only(x):
return torch.concat(list(torch.split(x, 2, 1)), dim=1)
args = [
torch.randn(2, 32),
]
for fn, dynamic, expected_cat_norm_count in [
(caoncat_only, False, 2),
]:
expected = fn(*args)
actual = torch.compile(fn, dynamic=dynamic)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(
counters["inductor"]["normalization_pass"],
expected_cat_norm_count,
msg=f"for {fn}",
)
counters.clear()
@patch
def test_consecutive_split_merge(self):
def multi_split(x):

View File

@ -302,7 +302,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
@register_graph_pattern(
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE),
pass_dict=construct_pattern_matcher_pass("normalization_pass"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
@ -347,6 +347,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
cat_node.args == new_args
and cat_node.kwargs == new_kwargs
and cat_node.op == "call_function"
and cat_node.target == torch.cat
):
return