mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76391

I've seen this pattern in many important internal models:
```
x = torch.permute(a, [0, 2, 1])
y = torch.softmax(x, 2)
z = torch.permute(y, [0, 2, 1])
```

This is equivalent to
```
z = torch.softmax(a, 1)
```

The `permute` ops can degrade performance, especially when copy variants are enabled. Add another pattern to our `EliminateExtraPermuteOpsPass` to handle this.

ghstack-source-id: 155466506

Test Plan: New unit tests

Reviewed By: navahgar, huiguoo

Differential Revision: D35938289

fbshipit-source-id: 398b5528077b0b3f1c6fc5544e483803e96d68e9
(cherry picked from commit d742abd094d1fef23ca6a34703d97a6da2d14bd1)
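
The equivalence described in the summary can be checked numerically in eager mode. The following is a minimal sketch, not the unit test added by this change: the tensor shape, the seed, and the names `z_pattern` and `z_fused` are assumptions chosen for illustration. Note that the single softmax is applied to the unpermuted input `a`, over dim 1.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
a = torch.randn(2, 3, 4)  # illustrative 3-D input; any (B, M, N) shape works

# Original pattern: swap the last two dims, softmax over the (new) last dim,
# then swap the dims back.
x = torch.permute(a, [0, 2, 1])
y = torch.softmax(x, 2)
z_pattern = torch.permute(y, [0, 2, 1])

# Rewritten form: one softmax over dim 1 of the unpermuted input,
# with no permute ops at all.
z_fused = torch.softmax(a, 1)

same = torch.allclose(z_pattern, z_fused)
print(same)
```

The rewrite works because `[0, 2, 1]` is its own inverse and softmax normalizes along a single dimension independently of how the other dimensions are ordered, so dim 2 of the permuted tensor corresponds to dim 1 of the original.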