// Mirror of https://github.com/pytorch/pytorch.git
// (synced 2025-11-04 16:04:58 +08:00).
//
// Commit message:
//   * Better forward methods in C++ API
//     capitalize error message in test_torch.test_flatten
//     Support for operator()
//   * Add operator() to Functional
//   * Get rid of SigmoidLinear
//   * Add BoundFunction to FunctionalImpl
//   * Remove macro from conv because it makes errors more nasty
//
// File: 197 lines, 5.5 KiB, C++.

#include <catch.hpp>
 | 
						|
 | 
						|
#include <torch/nn/modules.h>
 | 
						|
#include <torch/nn/modules/linear.h>
 | 
						|
#include <torch/nn/modules/sequential.h>
 | 
						|
#include <torch/tensor.h>
 | 
						|
 | 
						|
#include <memory>
 | 
						|
#include <vector>
 | 
						|
 | 
						|
using namespace torch::nn;
 | 
						|
 | 
						|
using Catch::StartsWith;
 | 
						|
 | 
						|
TEST_CASE("sequential") {
 | 
						|
  SECTION("construction from shared pointer") {
 | 
						|
    struct M : torch::nn::Module {
 | 
						|
      explicit M(int value_) : value(value_) {}
 | 
						|
      int value;
 | 
						|
      int forward() {
 | 
						|
        return value;
 | 
						|
      }
 | 
						|
    };
 | 
						|
    Sequential sequential(
 | 
						|
        std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
 | 
						|
    REQUIRE(sequential.size() == 3);
 | 
						|
  }
 | 
						|
  SECTION("construction from concrete type") {
 | 
						|
    struct M : torch::nn::Module {
 | 
						|
      explicit M(int value_) : value(value_) {}
 | 
						|
      int value;
 | 
						|
      int forward() {
 | 
						|
        return value;
 | 
						|
      }
 | 
						|
    };
 | 
						|
 | 
						|
    Sequential sequential(M(1), M(2), M(3));
 | 
						|
    REQUIRE(sequential.size() == 3);
 | 
						|
  }
 | 
						|
  SECTION("construction from module holders") {
 | 
						|
    struct MImpl : torch::nn::Module {
 | 
						|
      explicit MImpl(int value_) : value(value_) {}
 | 
						|
      int forward() {
 | 
						|
        return value;
 | 
						|
      }
 | 
						|
      int value;
 | 
						|
    };
 | 
						|
 | 
						|
    struct M : torch::nn::ModuleHolder<MImpl> {
 | 
						|
      using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
 | 
						|
      using torch::nn::ModuleHolder<MImpl>::get;
 | 
						|
    };
 | 
						|
 | 
						|
    Sequential sequential(M(1), M(2), M(3));
 | 
						|
    REQUIRE(sequential.size() == 3);
 | 
						|
  }
 | 
						|
  SECTION("push_back") {
 | 
						|
    struct M : torch::nn::Module {
 | 
						|
      explicit M(int value_) : value(value_) {}
 | 
						|
      int forward() {
 | 
						|
        return value;
 | 
						|
      }
 | 
						|
      int value;
 | 
						|
    };
 | 
						|
    Sequential sequential;
 | 
						|
    REQUIRE(sequential.size() == 0);
 | 
						|
    REQUIRE(sequential.is_empty());
 | 
						|
    sequential.push_back(Linear(3, 4));
 | 
						|
    REQUIRE(sequential.size() == 1);
 | 
						|
    sequential.push_back(std::make_shared<M>(1));
 | 
						|
    REQUIRE(sequential.size() == 2);
 | 
						|
    sequential.push_back(M(2));
 | 
						|
    REQUIRE(sequential.size() == 3);
 | 
						|
  }
 | 
						|
  SECTION("access") {
 | 
						|
    struct M : torch::nn::Module {
 | 
						|
      explicit M(int value_) : value(value_) {}
 | 
						|
      int forward() {
 | 
						|
        return value;
 | 
						|
      }
 | 
						|
      int value;
 | 
						|
    };
 | 
						|
    std::vector<std::shared_ptr<M>> modules = {
 | 
						|
        std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
 | 
						|
 | 
						|
    Sequential sequential;
 | 
						|
    for (auto& module : modules) {
 | 
						|
      sequential.push_back(module);
 | 
						|
    }
 | 
						|
    REQUIRE(sequential.size() == 3);
 | 
						|
 | 
						|
    SECTION("at()") {
 | 
						|
      SECTION("returns the correct module for a given index") {
 | 
						|
        for (size_t i = 0; i < modules.size(); ++i) {
 | 
						|
          REQUIRE(&sequential.at<M>(i) == modules[i].get());
 | 
						|
        }
 | 
						|
      }
 | 
						|
      SECTION("throws for a bad index") {
 | 
						|
        REQUIRE_THROWS_WITH(
 | 
						|
            sequential.at<M>(modules.size() + 1),
 | 
						|
            StartsWith("Index out of range"));
 | 
						|
        REQUIRE_THROWS_WITH(
 | 
						|
            sequential.at<M>(modules.size() + 1000000),
 | 
						|
            StartsWith("Index out of range"));
 | 
						|
      }
 | 
						|
    }
 | 
						|
 | 
						|
    SECTION("ptr()") {
 | 
						|
      SECTION("returns the correct module for a given index") {
 | 
						|
        for (size_t i = 0; i < modules.size(); ++i) {
 | 
						|
          REQUIRE(sequential.ptr(i).get() == modules[i].get());
 | 
						|
          REQUIRE(sequential[i].get() == modules[i].get());
 | 
						|
          REQUIRE(sequential.ptr<M>(i).get() == modules[i].get());
 | 
						|
        }
 | 
						|
      }
 | 
						|
      SECTION("throws for a bad index") {
 | 
						|
        REQUIRE_THROWS_WITH(
 | 
						|
            sequential.ptr(modules.size() + 1),
 | 
						|
            StartsWith("Index out of range"));
 | 
						|
        REQUIRE_THROWS_WITH(
 | 
						|
            sequential.ptr(modules.size() + 1000000),
 | 
						|
            StartsWith("Index out of range"));
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
  SECTION("forward") {
 | 
						|
    SECTION("calling forward() on an empty sequential is disallowed") {
 | 
						|
      Sequential empty;
 | 
						|
      REQUIRE_THROWS_WITH(
 | 
						|
          empty.forward<int>(),
 | 
						|
          StartsWith("Cannot call forward() on an empty Sequential"));
 | 
						|
    }
 | 
						|
 | 
						|
    SECTION("calling forward() on a non-empty sequential chains correctly") {
 | 
						|
      struct MockModule : torch::nn::Module {
 | 
						|
        explicit MockModule(int value) : expected(value) {}
 | 
						|
        int expected;
 | 
						|
        int forward(int value) {
 | 
						|
          REQUIRE(value == expected);
 | 
						|
          return value + 1;
 | 
						|
        }
 | 
						|
      };
 | 
						|
 | 
						|
      Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
 | 
						|
 | 
						|
      REQUIRE(sequential.forward<int>(1) == 4);
 | 
						|
    }
 | 
						|
 | 
						|
    SECTION("calling forward() with the wrong return type throws") {
 | 
						|
      struct M : public torch::nn::Module {
 | 
						|
        int forward() {
 | 
						|
          return 5;
 | 
						|
        }
 | 
						|
      };
 | 
						|
 | 
						|
      Sequential sequential(M{});
 | 
						|
      REQUIRE(sequential.forward<int>() == 5);
 | 
						|
      REQUIRE_THROWS_WITH(
 | 
						|
          sequential.forward<float>(),
 | 
						|
          StartsWith("The type of the return value "
 | 
						|
                     "is int, but you asked for type float"));
 | 
						|
    }
 | 
						|
 | 
						|
    SECTION("The return type of forward() defaults to Tensor") {
 | 
						|
      struct M : public torch::nn::Module {
 | 
						|
        torch::Tensor forward(torch::Tensor v) {
 | 
						|
          return v;
 | 
						|
        }
 | 
						|
      };
 | 
						|
 | 
						|
      Sequential sequential(M{});
 | 
						|
      auto variable = torch::ones({3, 3}, at::requires_grad());
 | 
						|
      REQUIRE(sequential.forward(variable).equal(variable));
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  SECTION("returns the last value") {
 | 
						|
    Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
 | 
						|
 | 
						|
    auto x = torch::randn({1000, 10}, at::requires_grad());
 | 
						|
    auto y = sequential.forward(x);
 | 
						|
    REQUIRE(y.ndimension() == 2);
 | 
						|
    REQUIRE(y.size(0) == 1000);
 | 
						|
    REQUIRE(y.size(1) == 100);
 | 
						|
  }
 | 
						|
 | 
						|
  SECTION("can hold other important modules") {
 | 
						|
    Sequential sequential(
 | 
						|
        Linear(10, 3),
 | 
						|
        Conv2d(1, 2, 3),
 | 
						|
        Dropout(0.5),
 | 
						|
        BatchNorm(5),
 | 
						|
        Embedding(4, 10),
 | 
						|
        LSTM(4, 5));
 | 
						|
  }
 | 
						|
}
 |