mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SR] Eliminate extra permutes around softmax calls (#76391)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76391 I've seen this pattern in many important internal models: ``` x = torch.permute(a, [0, 2, 1]) y = torch.softmax(x, 2) z = torch.permute(y, [0, 2, 1]) ``` This is equivalent to ``` z = torch.softmax(x, 1) ``` The `permute` ops can degrade performance, especially if copy variants are on. Add another pattern to our `EliminateExtraPermuteOpsPass` to handle this. ghstack-source-id: 155466506 Test Plan: New unit tests Reviewed By: navahgar, huiguoo Differential Revision: D35938289 fbshipit-source-id: 398b5528077b0b3f1c6fc5544e483803e96d68e9 (cherry picked from commit d742abd094d1fef23ca6a34703d97a6da2d14bd1)
This commit is contained in:
committed by
PyTorch MergeBot
parent
cac2733af1
commit
1fed6b7559
@ -1530,7 +1530,7 @@ TEST(ForceNonEmptyOutputs, TwoSubBlocks) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, FusesCorrectly) {
|
||||
TEST(EliminateExtraPermuteOps, FusesSumCorrectly) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x):
|
||||
y = torch.permute(x, (0, 2, 1))
|
||||
@ -1553,7 +1553,7 @@ TEST(EliminateExtraPermuteOps, FusesCorrectly) {
|
||||
EXPECT_EQ(dim->toIntList(), c10::List<int64_t>{1});
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseWrongDim) {
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseSumWrongDim) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x):
|
||||
y = torch.permute(x, (0, 2, 1))
|
||||
@ -1571,7 +1571,7 @@ TEST(EliminateExtraPermuteOps, DoesNotFuseWrongDim) {
|
||||
EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseSumNonConstantDim) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x, dim: int):
|
||||
y = torch.permute(x, (0, 2, 1))
|
||||
@ -1589,6 +1589,64 @@ TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
|
||||
EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, FusesSoftmaxCorrectly) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x):
|
||||
a = torch.permute(x, [0, 2, 1])
|
||||
b = torch.softmax(a, 2)
|
||||
c = torch.permute(b, [0, 2, 1])
|
||||
return c.clone()
|
||||
)JIT";
|
||||
torch::jit::Module mod("m");
|
||||
mod.define(src);
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
ConstantPropagation(graph);
|
||||
EliminateExtraPermuteOps(graph);
|
||||
graph->dump();
|
||||
|
||||
EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
|
||||
auto* softmax = getNodeWithKind(graph, "aten::softmax");
|
||||
ASSERT_NE(softmax, nullptr);
|
||||
auto dim = toIValue(softmax->input(1));
|
||||
ASSERT_TRUE(dim.has_value() && dim->isInt());
|
||||
EXPECT_EQ(dim->toInt(), 1);
|
||||
|
||||
std::vector<IValue> args{at::randn({3, 4, 5})};
|
||||
testStaticRuntime(src, args, /*args2=*/{}, /*use_allclose=*/true);
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongPermuteDim) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x):
|
||||
a = torch.permute(x, [0, 1, 2])
|
||||
b = torch.softmax(a, 2)
|
||||
c = torch.permute(b, [0, 1, 2])
|
||||
return c.clone()
|
||||
)JIT";
|
||||
torch::jit::Module mod("m");
|
||||
mod.define(src);
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
ConstantPropagation(graph);
|
||||
EliminateExtraPermuteOps(graph);
|
||||
EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
|
||||
}
|
||||
|
||||
TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongSoftmaxDim) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x):
|
||||
a = torch.permute(x, [0, 2, 1])
|
||||
b = torch.softmax(a, 0)
|
||||
c = torch.permute(b, [0, 2, 1])
|
||||
return c.clone()
|
||||
)JIT";
|
||||
torch::jit::Module mod("m");
|
||||
mod.define(src);
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
ConstantPropagation(graph);
|
||||
EliminateExtraPermuteOps(graph);
|
||||
EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
|
||||
}
|
||||
|
||||
TEST(UseSplitAndSqueeze, Fusion) {
|
||||
const auto src = R"IR(
|
||||
graph(%x: Tensor):
|
||||
|
Reference in New Issue
Block a user