[Jit] Fix schema of aten::split int[] version (#69745)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69745

Missed in D31935573 (6b44e75f6b).

Reviewed By: d1jang

Differential Revision: D31889867

fbshipit-source-id: 417bd0b15db4891dbd641b35a803553f11d0d756
This commit is contained in:
Hao Lu
2021-12-10 02:32:26 -08:00
committed by Facebook GitHub Bot
parent 9962bfb3c9
commit 91d16cb633
3 changed files with 57 additions and 2 deletions

View File

@ -1532,6 +1532,61 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["z"]));
}
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%0 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%y : Tensor = aten::add(%x, %x, %0)
%lengths_list : int[] = prim::tolist(%1, %2)
%a : Tensor[] = aten::split(%y, %lengths_list, %0)
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
%b1 : Tensor = aten::flatten(%b, %0, %1)
%c1 : Tensor = aten::flatten(%c, %0, %1)
%d : Tensor = aten::add(%b1, %c1, %0)
return (%d))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
}
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%0 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%y : Tensor = aten::add(%x, %x, %0)
%a : Tensor[] = aten::split(%y, %2, %0)
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
%b1 : Tensor = aten::flatten(%b, %0, %1)
%c1 : Tensor = aten::flatten(%c, %0, %1)
%d : Tensor = aten::add(%b1, %c1, %0)
return (%d))IR";
torch::jit::parseIR(graph_string, graph.get(), vmap);
AliasDb aliasDb(
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
}
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
auto registry = torch::RegisterOperators().op(
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",