[Optimus] Fix normalization pass in the aten IR (#157857)

Summary: We found there's a special case in recent APS model where the input tensor has smaller size compared to the split size. It will be automatically truncated in split.Tensor thus we add extra condition check for split_with_sizes when do the normalization.

Test Plan:
### unit
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_split_aten_normalization
```

Buck UI: https://www.internalfb.com/buck2/2ecd1ef8-8efe-4245-b4c8-282c23645b3c
Test UI: https://www.internalfb.com/intern/testinfra/testrun/7599824648585787
Network: Up: 3.9GiB  Down: 9.2GiB  (reSessionID-1396c91e-0dd2-457b-a49b-a6ab1f2a7d8f)
Loading targets.   Remaining      0/5344                                                                                                              99617 dirs read, 1074949 targets declared
Analyzing targets. Remaining      0/123279                                                                                                            4988547 actions, 5966764 artifacts declared
Executing actions. Remaining      0/728058                                                                                                            209:52:59.9s exec time total
Command: test.     Finished 12466 local, 209448 remote, 1226 cache (1% hit)                                                                           42:10.5s exec time cached (0%)
Time elapsed: 26:07.6s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

### E2E

before fix:
aps-afoc_apop_pt2_v0-db2fe0449a

after fix:
aps-afoc_apop_pt2_v0-755ad0cdc6

Rollback Plan:

Differential Revision: D77961394

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157857
Approved by: https://github.com/anijain2305
This commit is contained in:
Menglu Yu
2025-07-09 05:38:15 +00:00
committed by PyTorch MergeBot
parent effe376db0
commit e3f2597b45
2 changed files with 39 additions and 0 deletions

View File

@ -325,5 +325,38 @@ class TestSplitCatAten(TestCase):
counters.clear()
class TestSplitCatAtenNormalizationPasses(TestCase):
@torch._inductor.config.patch(
pre_grad_fusion_options={},
post_grad_fusion_options={
"normalization_aten_pass": {},
},
)
def test_split_aten_normalization(self):
def arg_only_size_same(x):
return torch.ops.aten.split.Tensor(x, 300, 1)
def arg_only_size_different(x):
return torch.ops.aten.split.Tensor(x, 320, 1)
args = [
torch.randn(4096, 300),
]
for fn, expected_split_norm_count in [
(arg_only_size_same, 1),
(arg_only_size_different, 1),
]:
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(
counters["inductor"]["normalization_aten_pass"],
expected_split_norm_count,
msg=f"for {fn}",
)
counters.clear()
if __name__ == "__main__":
run_tests()

View File

@ -1701,6 +1701,12 @@ def normalize_split_default_aten(match: Match, *args, **kwargs):
return
if split_dim < 0: # Normalize split dim
split_dim += split_input.meta["val"].dim()
# we also need to check the input of the split_node
# primals =torch.randn(4096, 300)
# split = torch.ops.aten.split.Tensor(primals, 320, 1) -> truncate to 300 automatically
# split_2 = torch.ops.aten.split_with_sizes.default(primals, [320], dim = 1) -> runtime error
split_input_size = split_input.meta["val"].shape[split_dim]
split_size = min(split_size, split_input_size)
split_section_list = [split_size] * (len(split_node.meta["val"]))
new_args = (split_input, split_section_list)
new_kwargs = {"dim": split_dim}