mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Optimus] Fix normalization pass in the aten IR (#157857)
Summary: We found there's a special case in recent APS model where the input tensor has smaller size compared to the split size. It will be automatically truncated in split.Tensor thus we add extra condition check for split_with_sizes when do the normalization. Test Plan: ### unit ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_split_aten_normalization ``` Buck UI: https://www.internalfb.com/buck2/2ecd1ef8-8efe-4245-b4c8-282c23645b3c Test UI: https://www.internalfb.com/intern/testinfra/testrun/7599824648585787 Network: Up: 3.9GiB Down: 9.2GiB (reSessionID-1396c91e-0dd2-457b-a49b-a6ab1f2a7d8f) Loading targets. Remaining 0/5344 99617 dirs read, 1074949 targets declared Analyzing targets. Remaining 0/123279 4988547 actions, 5966764 artifacts declared Executing actions. Remaining 0/728058 209:52:59.9s exec time total Command: test. Finished 12466 local, 209448 remote, 1226 cache (1% hit) 42:10.5s exec time cached (0%) Time elapsed: 26:07.6s Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ### E2E before fix: aps-afoc_apop_pt2_v0-db2fe0449a after fix: aps-afoc_apop_pt2_v0-755ad0cdc6 Rollback Plan: Differential Revision: D77961394 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157857 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
effe376db0
commit
e3f2597b45
@ -325,5 +325,38 @@ class TestSplitCatAten(TestCase):
|
||||
counters.clear()
|
||||
|
||||
|
||||
class TestSplitCatAtenNormalizationPasses(TestCase):
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={},
|
||||
post_grad_fusion_options={
|
||||
"normalization_aten_pass": {},
|
||||
},
|
||||
)
|
||||
def test_split_aten_normalization(self):
|
||||
def arg_only_size_same(x):
|
||||
return torch.ops.aten.split.Tensor(x, 300, 1)
|
||||
|
||||
def arg_only_size_different(x):
|
||||
return torch.ops.aten.split.Tensor(x, 320, 1)
|
||||
|
||||
args = [
|
||||
torch.randn(4096, 300),
|
||||
]
|
||||
for fn, expected_split_norm_count in [
|
||||
(arg_only_size_same, 1),
|
||||
(arg_only_size_different, 1),
|
||||
]:
|
||||
expected = fn(*args)
|
||||
actual = torch.compile(fn)(*args)
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["normalization_aten_pass"],
|
||||
expected_split_norm_count,
|
||||
msg=f"for {fn}",
|
||||
)
|
||||
counters.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -1701,6 +1701,12 @@ def normalize_split_default_aten(match: Match, *args, **kwargs):
|
||||
return
|
||||
if split_dim < 0: # Normalize split dim
|
||||
split_dim += split_input.meta["val"].dim()
|
||||
# we also need to check the input of the split_node
|
||||
# primals =torch.randn(4096, 300)
|
||||
# split = torch.ops.aten.split.Tensor(primals, 320, 1) -> truncate to 300 automatically
|
||||
# split_2 = torch.ops.aten.split_with_sizes.default(primals, [320], dim = 1) -> runtime error
|
||||
split_input_size = split_input.meta["val"].shape[split_dim]
|
||||
split_size = min(split_size, split_input_size)
|
||||
split_section_list = [split_size] * (len(split_node.meta["val"]))
|
||||
new_args = (split_input, split_section_list)
|
||||
new_kwargs = {"dim": split_dim}
|
||||
|
Reference in New Issue
Block a user