mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Implement torch.nn.Embedding / EmbeddingBag in PyTorch C++ API (#26358)
Summary: added more variables to EmbeddingOptions and updated EmbeddingImpl reset, forward functions. Also added EmbeddingBag. ----- This PR is BC-breaking in the following way: Previously, `EmbeddingOptions` supports `count` and `dimension` as options arguments. After this PR, they are renamed to `num_embeddings` and `embedding_dim` respectively. Pull Request resolved: https://github.com/pytorch/pytorch/pull/26358 Differential Revision: D17714337 Pulled By: yf225 fbshipit-source-id: f9f969c68e4bece106b92f8e2e02ac39c8455fb7
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							b96f49885f
						
					
				
				
					commit
					a37be201c1
				
			@ -800,6 +800,21 @@ TEST_F(ModulesTest, EmbeddingList) {
 | 
			
		||||
  ASSERT_EQ(y.size(2), 4);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, EmbeddingFromPretrained) {
 | 
			
		||||
  auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
 | 
			
		||||
  Embedding embedding = torch::nn::Embedding::from_pretrained(weight);
 | 
			
		||||
  auto input = torch::tensor({1}, torch::kLong);
 | 
			
		||||
  ASSERT_TRUE(torch::allclose(embedding(input), torch::tensor({4.0000, 5.1000, 6.3000})));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, EmbeddingBagFromPretrained) {
 | 
			
		||||
  auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
 | 
			
		||||
  EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight);
 | 
			
		||||
  auto input = torch::zeros({{1, 2}}, torch::kLong);
 | 
			
		||||
  input[0] = torch::tensor({1, 0});
 | 
			
		||||
  ASSERT_TRUE(torch::allclose(embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500})));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, Dropout) {
 | 
			
		||||
  Dropout dropout(0.5);
 | 
			
		||||
  torch::Tensor x = torch::ones(100, torch::requires_grad());
 | 
			
		||||
@ -1290,8 +1305,29 @@ TEST_F(ModulesTest, PrettyPrintBatchNorm) {
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, PrettyPrintEmbedding) {
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(Embedding(10, 2)),
 | 
			
		||||
      "torch::nn::Embedding(count=10, dimension=2)");
 | 
			
		||||
      c10::str(Embedding(EmbeddingOptions(10, 2))),
 | 
			
		||||
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))),
 | 
			
		||||
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))),
 | 
			
		||||
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))),
 | 
			
		||||
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)");
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))),
 | 
			
		||||
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)");
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))),
 | 
			
		||||
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode("sum"))),
 | 
			
		||||
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=sum)");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) {
 | 
			
		||||
@ -1339,7 +1375,7 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) {
 | 
			
		||||
    TestModule()
 | 
			
		||||
        : torch::nn::Module("TestModule"),
 | 
			
		||||
          fc(register_module("fc", torch::nn::Linear(4, 5))),
 | 
			
		||||
          table(register_module("table", torch::nn::Embedding(10, 2))),
 | 
			
		||||
          table(register_module("table", torch::nn::Embedding(EmbeddingOptions(10, 2)))),
 | 
			
		||||
          inner(register_module("inner", std::make_shared<InnerTestModule>())) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -1352,10 +1388,10 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) {
 | 
			
		||||
      c10::str(TestModule{}),
 | 
			
		||||
      "TestModule(\n"
 | 
			
		||||
      "  (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
 | 
			
		||||
      "  (table): torch::nn::Embedding(count=10, dimension=2)\n"
 | 
			
		||||
      "  (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
 | 
			
		||||
      "  (inner): InnerTestModule(\n"
 | 
			
		||||
      "    (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
 | 
			
		||||
      "    (table): torch::nn::Embedding(count=10, dimension=2)\n"
 | 
			
		||||
      "    (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
 | 
			
		||||
      "  )\n"
 | 
			
		||||
      ")");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user