mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[Jit] Fix schema of aten::split int[] version (#69745)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69745
Missed in D31935573 (6b44e75f6b).
Reviewed By: d1jang
Differential Revision: D31889867
fbshipit-source-id: 417bd0b15db4891dbd641b35a803553f11d0d756
			
			
This commit is contained in:
		
				
					committed by
					
						
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							9962bfb3c9
						
					
				
				
					commit
					91d16cb633
				
			@ -1532,6 +1532,61 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["z"]));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
 | 
			
		||||
  auto graph = std::make_shared<Graph>();
 | 
			
		||||
  std::unordered_map<std::string, Value*> vmap;
 | 
			
		||||
  auto graph_string = R"IR(
 | 
			
		||||
  graph():
 | 
			
		||||
    %x : Tensor = prim::MakeTestTensor()
 | 
			
		||||
    %0 : int = prim::Constant[value=0]()
 | 
			
		||||
    %1 : int = prim::Constant[value=1]()
 | 
			
		||||
    %2 : int = prim::Constant[value=2]()
 | 
			
		||||
    %y : Tensor = aten::add(%x, %x, %0)
 | 
			
		||||
    %lengths_list : int[] = prim::tolist(%1, %2)
 | 
			
		||||
    %a : Tensor[] = aten::split(%y, %lengths_list, %0)
 | 
			
		||||
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
 | 
			
		||||
    %b1 : Tensor = aten::flatten(%b, %0, %1)
 | 
			
		||||
    %c1 : Tensor = aten::flatten(%c, %0, %1)
 | 
			
		||||
    %d : Tensor = aten::add(%b1, %c1, %0)
 | 
			
		||||
    return (%d))IR";
 | 
			
		||||
 | 
			
		||||
  torch::jit::parseIR(graph_string, graph.get(), vmap);
 | 
			
		||||
  AliasDb aliasDb(
 | 
			
		||||
      graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
 | 
			
		||||
  auto graph = std::make_shared<Graph>();
 | 
			
		||||
  std::unordered_map<std::string, Value*> vmap;
 | 
			
		||||
  auto graph_string = R"IR(
 | 
			
		||||
  graph():
 | 
			
		||||
    %x : Tensor = prim::MakeTestTensor()
 | 
			
		||||
    %0 : int = prim::Constant[value=0]()
 | 
			
		||||
    %1 : int = prim::Constant[value=1]()
 | 
			
		||||
    %2 : int = prim::Constant[value=2]()
 | 
			
		||||
    %y : Tensor = aten::add(%x, %x, %0)
 | 
			
		||||
    %a : Tensor[] = aten::split(%y, %2, %0)
 | 
			
		||||
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
 | 
			
		||||
    %b1 : Tensor = aten::flatten(%b, %0, %1)
 | 
			
		||||
    %c1 : Tensor = aten::flatten(%c, %0, %1)
 | 
			
		||||
    %d : Tensor = aten::add(%b1, %c1, %0)
 | 
			
		||||
    return (%d))IR";
 | 
			
		||||
 | 
			
		||||
  torch::jit::parseIR(graph_string, graph.get(), vmap);
 | 
			
		||||
  AliasDb aliasDb(
 | 
			
		||||
      graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
 | 
			
		||||
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
 | 
			
		||||
  auto registry = torch::RegisterOperators().op(
 | 
			
		||||
      "foo::rand12(Tensor(a) arg1) -> Tensor(b)",
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user