mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
add torch.concat to normalization pass (#156574)
Summary: In the normalization pass, we also add torch.concat to it to normalize it as torch.cat Test Plan: ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_passes -- test_cat_normalization ``` Buck UI: https://www.internalfb.com/buck2/597fd4f1-0aa7-4372-8a66-5a690d9b63a4 Test UI: https://www.internalfb.com/intern/testinfra/testrun/1688850152284203 Network: Up: 84KiB Down: 34KiB (reSessionID-3916e009-7117-41ce-b6f9-089873aa50dd) Executing actions. Remaining 0/3 1.1s exec time total Command: test. Finished 2 local Time elapsed: 3:47.1s Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 Rollback Plan: Differential Revision: D77125331 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156574 Approved by: https://github.com/Mingming-Ding
This commit is contained in:
committed by
PyTorch MergeBot
parent
1155c53e7d
commit
640703d95f
@ -115,6 +115,33 @@ class TestSplitCatFxPasses(TestCase):
|
||||
)
|
||||
counters.clear()
|
||||
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={
|
||||
"normalization_pass": {},
|
||||
},
|
||||
post_grad_fusion_options={},
|
||||
)
|
||||
def test_cat_normalization(self):
|
||||
def caoncat_only(x):
|
||||
return torch.concat(list(torch.split(x, 2, 1)), dim=1)
|
||||
|
||||
args = [
|
||||
torch.randn(2, 32),
|
||||
]
|
||||
for fn, dynamic, expected_cat_norm_count in [
|
||||
(caoncat_only, False, 2),
|
||||
]:
|
||||
expected = fn(*args)
|
||||
actual = torch.compile(fn, dynamic=dynamic)(*args)
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["normalization_pass"],
|
||||
expected_cat_norm_count,
|
||||
msg=f"for {fn}",
|
||||
)
|
||||
counters.clear()
|
||||
|
||||
@patch
|
||||
def test_consecutive_split_merge(self):
|
||||
def multi_split(x):
|
||||
|
@ -302,7 +302,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
|
||||
CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE),
|
||||
pass_dict=construct_pattern_matcher_pass("normalization_pass"),
|
||||
)
|
||||
def normalize_cat_default(match: Match, *args, **kwargs):
|
||||
@ -347,6 +347,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
|
||||
cat_node.args == new_args
|
||||
and cat_node.kwargs == new_kwargs
|
||||
and cat_node.op == "call_function"
|
||||
and cat_node.target == torch.cat
|
||||
):
|
||||
return
|
||||
|
||||
|
Reference in New Issue
Block a user