[SR] Eliminate extra permutes around softmax calls (#76391)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76391

I've seen this pattern in many important internal models:

```
x = torch.permute(a, [0, 2, 1])
y = torch.softmax(x, 2)
z = torch.permute(y, [0, 2, 1])
```

Since the permutation `[0, 2, 1]` is its own inverse and maps dimension 1 of `a` to dimension 2 of `x`, this is equivalent to
```
z = torch.softmax(a, 1)
```

The `permute` ops can degrade performance, especially when copy variants are enabled. This change adds another pattern to our `EliminateExtraPermuteOpsPass` to handle this case.
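
As a quick sanity check (illustrative only, not part of this diff), the equivalence can be verified in eager-mode PyTorch; the shape below mirrors the one used in the new unit test:

```
# Verify that permute -> softmax -> inverse permute equals a single softmax
# over the corresponding dimension of the original tensor.
import torch

a = torch.randn(3, 4, 5)

# Original pattern: permute, softmax over the permuted dim, permute back.
x = torch.permute(a, [0, 2, 1])
y = torch.softmax(x, 2)
z_unfused = torch.permute(y, [0, 2, 1])

# Fused form: softmax directly over dim 1 of the input.
z_fused = torch.softmax(a, 1)

assert torch.allclose(z_unfused, z_fused)
```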
ghstack-source-id: 155466506

Test Plan: New unit tests

Reviewed By: navahgar, huiguoo

Differential Revision: D35938289

fbshipit-source-id: 398b5528077b0b3f1c6fc5544e483803e96d68e9
(cherry picked from commit d742abd094d1fef23ca6a34703d97a6da2d14bd1)
Author: Mike Iovine
Date: 2022-05-04 15:56:47 -07:00
Committed by: PyTorch MergeBot
Parent: cac2733af1
Commit: 1fed6b7559
2 changed files with 128 additions and 18 deletions


@@ -1530,7 +1530,7 @@ TEST(ForceNonEmptyOutputs, TwoSubBlocks) {
   }
 }
-TEST(EliminateExtraPermuteOps, FusesCorrectly) {
+TEST(EliminateExtraPermuteOps, FusesSumCorrectly) {
   const auto src = R"JIT(
     def forward(self, x):
         y = torch.permute(x, (0, 2, 1))
@@ -1553,7 +1553,7 @@ TEST(EliminateExtraPermuteOps, FusesCorrectly) {
   EXPECT_EQ(dim->toIntList(), c10::List<int64_t>{1});
 }
-TEST(EliminateExtraPermuteOps, DoesNotFuseWrongDim) {
+TEST(EliminateExtraPermuteOps, DoesNotFuseSumWrongDim) {
   const auto src = R"JIT(
     def forward(self, x):
         y = torch.permute(x, (0, 2, 1))
@@ -1571,7 +1571,7 @@ TEST(EliminateExtraPermuteOps, DoesNotFuseWrongDim) {
   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
 }
-TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
+TEST(EliminateExtraPermuteOps, DoesNotFuseSumNonConstantDim) {
   const auto src = R"JIT(
     def forward(self, x, dim: int):
         y = torch.permute(x, (0, 2, 1))
@@ -1589,6 +1589,64 @@ TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
 }
+TEST(EliminateExtraPermuteOps, FusesSoftmaxCorrectly) {
+  const auto src = R"JIT(
+    def forward(self, x):
+        a = torch.permute(x, [0, 2, 1])
+        b = torch.softmax(a, 2)
+        c = torch.permute(b, [0, 2, 1])
+        return c.clone()
+  )JIT";
+  torch::jit::Module mod("m");
+  mod.define(src);
+  auto graph = mod.get_method("forward").graph();
+  ConstantPropagation(graph);
+  EliminateExtraPermuteOps(graph);
+  graph->dump();
+  EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
+  auto* softmax = getNodeWithKind(graph, "aten::softmax");
+  ASSERT_NE(softmax, nullptr);
+  auto dim = toIValue(softmax->input(1));
+  ASSERT_TRUE(dim.has_value() && dim->isInt());
+  EXPECT_EQ(dim->toInt(), 1);
+  std::vector<IValue> args{at::randn({3, 4, 5})};
+  testStaticRuntime(src, args, /*args2=*/{}, /*use_allclose=*/true);
+}
+TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongPermuteDim) {
+  const auto src = R"JIT(
+    def forward(self, x):
+        a = torch.permute(x, [0, 1, 2])
+        b = torch.softmax(a, 2)
+        c = torch.permute(b, [0, 1, 2])
+        return c.clone()
+  )JIT";
+  torch::jit::Module mod("m");
+  mod.define(src);
+  auto graph = mod.get_method("forward").graph();
+  ConstantPropagation(graph);
+  EliminateExtraPermuteOps(graph);
+  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
+}
+TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongSoftmaxDim) {
+  const auto src = R"JIT(
+    def forward(self, x):
+        a = torch.permute(x, [0, 2, 1])
+        b = torch.softmax(a, 0)
+        c = torch.permute(b, [0, 2, 1])
+        return c.clone()
+  )JIT";
+  torch::jit::Module mod("m");
+  mod.define(src);
+  auto graph = mod.get_method("forward").graph();
+  ConstantPropagation(graph);
+  EliminateExtraPermuteOps(graph);
+  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
+}
 TEST(UseSplitAndSqueeze, Fusion) {
   const auto src = R"IR(
     graph(%x: Tensor):