mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Let's have some fun. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828 Approved by: https://github.com/ezyang
458 lines
14 KiB
C++
458 lines
14 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
|
|
using namespace torch::nn;
|
|
|
|
// Fixture shared by all AnyModule tests below. It adds no state of its own and
// only inherits from torch::test::SeedingFixture (declared outside this file;
// presumably it seeds the RNG per test -- confirm in test/cpp/api/support.h).
struct AnyModuleTest : torch::test::SeedingFixture {};
|
|
|
|
// An AnyModule wrapping a module whose forward() takes no arguments must
// return the module's plain `int` result when asked for `int`.
TEST_F(AnyModuleTest, SimpleReturnType) {
  struct M : torch::nn::Module {
    // Always yields the constant 123.
    int forward() {
      return 123;
    }
  };

  AnyModule wrapped(M{});
  const int result = wrapped.forward<int>();
  ASSERT_EQ(result, 123);
}
|
|
|
|
// A single `int` argument passed through AnyModule::forward must reach the
// wrapped module unchanged and come back as the `int` return value.
TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  struct M : torch::nn::Module {
    // Identity function on its argument.
    int forward(int x) {
      return x;
    }
  };

  AnyModule wrapped(M{});
  const int echoed = wrapped.forward<int>(5);
  ASSERT_EQ(echoed, 5);
}
|
|
|
|
// A string literal argument must be accepted by a forward() that takes and
// returns `const char*`.
TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  struct M : torch::nn::Module {
    // Hands the pointer straight back to the caller.
    const char* forward(const char* x) {
      return x;
    }
  };

  AnyModule wrapped(M{});
  const char* echoed = wrapped.forward<const char*>("hello");
  // Compared against std::string so the contents, not the addresses, are
  // checked.
  ASSERT_EQ(echoed, std::string("hello"));
}
|
|
|
|
// forward() may declare a by-value parameter as `const`; the call through
// AnyModule must still match, and a std::string return value must round-trip.
TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  struct M : torch::nn::Module {
    // Adds the two numbers, truncates to int and renders the result as text.
    std::string forward(int x, const double f) {
      const double sum = x + f;
      return std::to_string(static_cast<int>(sum));
    }
  };

  AnyModule wrapped(M{});
  int first = 4;
  // 4 + 3.14 truncates to 7.
  const std::string rendered = wrapped.forward<std::string>(first, 3.14);
  ASSERT_EQ(rendered, std::string("7"));
}
|
|
|
|
// Checks that AnyModule forwards std::string arguments to a forward() whose
// parameters are declared by value, by const reference and by rvalue
// reference, and that the untemplated forward() call yields the tensor result.
TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  struct M : torch::nn::Module {
    // Returns a 1-D tensor of ones with one element per character of
    // a + b + c.
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto s = a + b + c;
      return torch::ones({static_cast<int64_t>(s.size())});
    }
  };
  AnyModule any(M{});
  // "a" + "ab" + "abc" has six characters, so the tensor of ones sums to 6.
  // ASSERT_EQ (instead of ASSERT_TRUE on `==`) prints both values on failure.
  ASSERT_EQ(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>(),
      6);
}
|
|
|
|
// Passing a `double` where forward() expects a `float` must be rejected with
// an error naming the offending argument index and both types; no implicit
// conversion is applied.
TEST_F(AnyModuleTest, WrongArgumentType) {
  struct M : torch::nn::Module {
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };

  AnyModule wrapped(M{});
  // The literal 5.0 is a double, which does not match the float parameter.
  ASSERT_THROWS_WITH(
      wrapped.forward(5.0),
      "Expected argument #0 to be of type float, "
      "but received value of type double");
}
|
|
|
|
// Module with a two-argument forward() and no default arguments. It lives at
// namespace scope under a distinctive name because the WrongNumberOfArguments
// test matches this type name inside the expected error messages.
struct M_test_wrong_number_of_arguments : torch::nn::Module {
  // Returns the sum of both arguments.
  int forward(int a, int b) {
    const int sum = a + b;
    return sum;
  }
};
|
|
|
|
// Calling forward() with too few or too many arguments must throw with a
// message stating the expected and received counts. When too few are passed,
// the message additionally hints at the FORWARD_HAS_DEFAULT_ARGS macro.
TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule wrapped(M_test_wrong_number_of_arguments{});

  // The type name embedded in the message carries a "struct " prefix when
  // built with MSVC.
#if defined(_MSC_VER)
  std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  std::string module_name = "M_test_wrong_number_of_arguments";
#endif

  // Trailing hint shared by both "too few arguments" messages.
  const std::string default_args_hint = "If " + module_name +
      "'s forward() method has default arguments, "
      "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

  ASSERT_THROWS_WITH(
      wrapped.forward(),
      module_name +
          "'s forward() method expects 2 argument(s), but received 0. " +
          default_args_hint);
  ASSERT_THROWS_WITH(
      wrapped.forward(5),
      module_name +
          "'s forward() method expects 2 argument(s), but received 1. " +
          default_args_hint);
  // Too many arguments: only the count mismatch is checked here.
  ASSERT_THROWS_WITH(
      wrapped.forward(1, 2, 3),
      module_name +
          "'s forward() method expects 2 argument(s), but received 3.");
}
|
|
|
|
struct M_default_arg_with_macro : torch::nn::Module {
|
|
double forward(int a, int b = 2, double c = 3.0) {
|
|
return a + b + c;
|
|
}
|
|
|
|
protected:
|
|
FORWARD_HAS_DEFAULT_ARGS(
|
|
{1, torch::nn::AnyValue(2)},
|
|
{2, torch::nn::AnyValue(3.0)})
|
|
};
|
|
|
|
// Same forward() signature as M_default_arg_with_macro, but WITHOUT the
// FORWARD_HAS_DEFAULT_ARGS declaration. Used to show that the C++ default
// arguments alone are not honoured when calling through AnyModule.
struct M_default_arg_without_macro : torch::nn::Module {
  // Returns a + b + c, with b defaulting to 2 and c to 3.0.
  double forward(int a, int b = 2, double c = 3.0) {
    const double total = a + b + c;
    return total;
  }
};
|
|
|
|
// Exercises AnyModule::forward against modules whose forward() has default
// arguments, once with and once without the FORWARD_HAS_DEFAULT_ARGS
// declaration.
TEST_F(
    AnyModuleTest,
    PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
  // With the macro: trailing arguments may be omitted and the declared
  // defaults (b = 2, c = 3.0) are filled in.
  {
    AnyModule with_macro(M_default_arg_with_macro{});

    ASSERT_EQ(with_macro.forward<double>(1), 6.0); // 1 + 2 + 3.0
    ASSERT_EQ(with_macro.forward<double>(1, 3), 7.0); // 1 + 3 + 3.0
    ASSERT_EQ(with_macro.forward<double>(1, 3, 5.0), 9.0); // 1 + 3 + 5.0

    // Argument counts outside [1, 3] are rejected.
    ASSERT_THROWS_WITH(
        with_macro.forward(),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
    ASSERT_THROWS_WITH(
        with_macro.forward(1, 2, 3.0, 4),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
  }

  // Without the macro: all three arguments must always be supplied.
  {
    AnyModule without_macro(M_default_arg_without_macro{});

    ASSERT_EQ(without_macro.forward<double>(1, 3, 5.0), 9.0);

    // The type name embedded in the message carries a "struct " prefix when
    // built with MSVC.
#if defined(_MSC_VER)
    std::string module_name = "struct M_default_arg_without_macro";
#else
    std::string module_name = "M_default_arg_without_macro";
#endif

    // Trailing hint shared by every "too few arguments" message below.
    const std::string default_args_hint = "If " + module_name +
        "'s forward() method has default arguments, "
        "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

    ASSERT_THROWS_WITH(
        without_macro.forward(),
        module_name +
            "'s forward() method expects 3 argument(s), but received 0. " +
            default_args_hint);
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1),
        module_name +
            "'s forward() method expects 3 argument(s), but received 1. " +
            default_args_hint);
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1, 3),
        module_name +
            "'s forward() method expects 3 argument(s), but received 2. " +
            default_args_hint);
    // Too many arguments: only the count mismatch is checked here.
    ASSERT_THROWS_WITH(
        without_macro.forward(1, 2, 3.0, 4),
        module_name +
            "'s forward() method expects 3 argument(s), but received 4.");
  }
}
|
|
|
|
// Namespace-scope helper module used by the get()/ptr() tests below. It
// registers itself under the module name "M" and stores one integer so the
// tests can verify they got the same instance back.
struct M : torch::nn::Module {
  explicit M(int value_) : torch::nn::Module("M"), value(value_) {}

  // Payload inspected by the tests.
  int value;

  // Converts the float argument to int via the implicit narrowing conversion.
  int forward(float x) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return x;
  }
};
|
|
|
|
// get<T>() with the stored module's real type gives access to that module.
TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  AnyModule wrapped(M{5});
  auto& held = wrapped.get<M>();
  ASSERT_EQ(held.value, 5);
}
|
|
|
|
// get<T>() with a type other than the stored module's type must throw.
TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  // Unrelated module type that is never stored in the AnyModule.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  AnyModule wrapped(M{5});
  ASSERT_THROWS_WITH(wrapped.get<N>(), "Attempted to cast module");
}
|
|
|
|
// The untemplated ptr() returns a non-null pointer through which the stored
// module's name can be queried.
TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  AnyModule wrapped(M{5});
  const auto base = wrapped.ptr();
  ASSERT_NE(base, nullptr);
  // M's constructor registers the name "M".
  ASSERT_EQ(base->name(), "M");
}
|
|
|
|
// ptr<T>() with the stored module's real type returns a non-null pointer to
// that module.
TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  AnyModule wrapped(M{5});
  const auto typed = wrapped.ptr<M>();
  ASSERT_NE(typed, nullptr);
  ASSERT_EQ(typed->value, 5);
}
|
|
|
|
// ptr<T>() with a type other than the stored module's type must throw.
TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  // Unrelated module type that is never stored in the AnyModule.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  AnyModule wrapped(M{5});
  ASSERT_THROWS_WITH(wrapped.ptr<N>(), "Attempted to cast module");
}
|
|
|
|
// A default-constructed AnyModule reports itself as empty; assigning a
// shared_ptr to a module makes it non-empty and the module retrievable.
TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  // Local module type; shadows the namespace-scope M inside this test.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}

    int value;

    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };

  AnyModule holder;
  ASSERT_TRUE(holder.is_empty());

  holder = std::make_shared<M>(5);
  ASSERT_FALSE(holder.is_empty());

  auto& held = holder.get<M>();
  ASSERT_EQ(held.value, 5);
}
|
|
|
|
// Every accessor of an empty AnyModule must throw an error that names the
// method that was called.
TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  // Only used as a template argument; no instance is ever stored.
  struct M : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };

  AnyModule empty;
  ASSERT_TRUE(empty.is_empty());

  ASSERT_THROWS_WITH(empty.get<M>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.type_info(), "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.forward<int>(5), "Cannot call forward() on an empty AnyModule");
}
|
|
|
|
// The same AnyModule object can be re-assigned to modules of unrelated types
// with different forward() signatures, and dispatches to whichever one it
// currently holds.
TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  // int -> std::string
  struct M : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  // float -> int
  struct N : torch::nn::Module {
    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return 3 + x;
    }
  };

  AnyModule holder;
  ASSERT_TRUE(holder.is_empty());

  // First assignment: holds an M.
  holder = std::make_shared<M>();
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder.forward<std::string>(5), "5");

  // Second assignment: the M is replaced by an N.
  holder = std::make_shared<N>();
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder.forward<int>(5.0f), 8);
}
|
|
|
|
// AnyModule can be built from a ModuleHolder wrapper; afterwards get<>() works
// with both the implementation type and the holder type.
TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  // Implementation class, following the `FooImpl` naming used with holders.
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}

    int value;

    int forward(float x) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return x;
    }
  };

  // Holder type that inherits ModuleHolder's constructors and get().
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  AnyModule from_holder(M{5});
  // By implementation type: member access with `.`.
  ASSERT_EQ(from_holder.get<MImpl>().value, 5);
  // By holder type: member access with `->`.
  ASSERT_EQ(from_holder.get<M>()->value, 5);

  // The same must work for a built-in holder type. These statements carry no
  // assertions; they only need to compile and run without throwing.
  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  Linear linear(module.get<Linear>());
}
|
|
|
|
// Verifies that both torch::autograd::Variable and at::Tensor arguments are
// accepted by a forward() that is declared to take torch::Tensor.
TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    // Identity on its input tensor.
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  AnyModule any(M{});
  // ones(5) sums to 5. ASSERT_EQ is used (instead of ASSERT_TRUE on `==`) so
  // that a failure prints the actual value, matching the assertion below.
  ASSERT_EQ(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>(),
      5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
|
|
|
|
namespace torch {
|
|
namespace nn {
|
|
struct TestAnyValue {
|
|
template <typename T>
|
|
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
|
|
explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
|
|
AnyValue operator()() {
|
|
return std::move(value_);
|
|
}
|
|
AnyValue value_;
|
|
};
|
|
template <typename T>
|
|
AnyValue make_value(T&& value) {
|
|
return TestAnyValue(std::forward<T>(value))();
|
|
}
|
|
} // namespace nn
|
|
} // namespace torch
|
|
|
|
// Fixture shared by all AnyValue tests below. It adds no state of its own and
// only inherits from torch::test::SeedingFixture (declared outside this file;
// presumably it seeds the RNG per test -- confirm in test/cpp/api/support.h).
struct AnyValueTest : torch::test::SeedingFixture {};
|
|
|
|
// An AnyValue holding an int can be read back as int through both try_get()
// and get().
TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  auto boxed = make_value<int>(5);
  ASSERT_NE(boxed.try_get<int>(), nullptr);
  // const and non-const types have the same typeid(),
  // but casting Holder<int> to Holder<const int> is undefined
  // behavior according to UBSAN:
  // https://github.com/pytorch/pytorch/issues/26964
  // ASSERT_NE(boxed.try_get<const int>(), nullptr);
  const int unboxed = boxed.get<int>();
  ASSERT_EQ(unboxed, 5);
}
|
|
// NOTE: The test below is disabled. It does not work at all, apparently
// because make_value decays `const int` into `int`.
// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
//   auto value = make_value<const int>(5);
//   ASSERT_NE(value.try_get<const int>(), nullptr);
//   // ASSERT_NE(value.try_get<int>(), nullptr);
//   ASSERT_EQ(value.get<const int>(), 5);
// }
|
|
// A string literal is stored as `const char*` and can be read back as such.
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  auto boxed = make_value("hello");
  ASSERT_NE(boxed.try_get<const char*>(), nullptr);
  // Compared against std::string so the contents, not the addresses, are
  // checked.
  const char* unboxed = boxed.get<const char*>();
  ASSERT_EQ(unboxed, std::string("hello"));
}
|
|
// A std::string temporary is stored as std::string and can be read back.
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  auto boxed = make_value(std::string("hello"));
  ASSERT_NE(boxed.try_get<std::string>(), nullptr);
  const auto& unboxed = boxed.get<std::string>();
  ASSERT_EQ(unboxed, "hello");
}
|
|
// A raw pointer is stored as that pointer type; dereferencing the retrieved
// pointer yields the original string.
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  std::string text("hello");
  std::string* text_ptr = &text;
  auto boxed = make_value(text_ptr);
  ASSERT_NE(boxed.try_get<std::string*>(), nullptr);
  std::string* unboxed = boxed.get<std::string*>();
  ASSERT_EQ(*unboxed, "hello");
}
|
|
// A value passed through a const reference is accessible under the plain
// (non-reference, non-const) type std::string.
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  std::string text("hello");
  const std::string& text_ref = text;
  auto boxed = make_value(text_ref);
  ASSERT_NE(boxed.try_get<std::string>(), nullptr);
  const auto& unboxed = boxed.get<std::string>();
  ASSERT_EQ(unboxed, "hello");
}
|
|
|
|
// try_get<T>() signals a type mismatch by returning nullptr rather than
// throwing; even other arithmetic types do not match a stored int.
TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  auto boxed = make_value(5);
  // Matching type: non-null.
  ASSERT_NE(boxed.try_get<int>(), nullptr);
  // Mismatching types: null.
  ASSERT_EQ(boxed.try_get<float>(), nullptr);
  ASSERT_EQ(boxed.try_get<long>(), nullptr);
  ASSERT_EQ(boxed.try_get<std::string>(), nullptr);
}
|
|
|
|
// get<T>() signals a type mismatch by throwing an error that names both the
// requested type and the stored type.
TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  auto boxed = make_value(5);
  ASSERT_NE(boxed.try_get<int>(), nullptr);

  ASSERT_THROWS_WITH(
      boxed.get<float>(),
      "Attempted to cast AnyValue to float, "
      "but its actual type is int");
  ASSERT_THROWS_WITH(
      boxed.get<long>(),
      "Attempted to cast AnyValue to long, "
      "but its actual type is int");
}
|
|
|
|
// An AnyValue built from a moved-from AnyValue holds the original payload.
TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  auto source = make_value(5);
  auto moved_to = make_value(std::move(source));
  ASSERT_NE(moved_to.try_get<int>(), nullptr);
  const int unboxed = moved_to.get<int>();
  ASSERT_EQ(unboxed, 5);
}
|
|
|
|
// Move-assigning one AnyValue onto another replaces the target's payload with
// the source's payload.
TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  auto source = make_value(5);
  auto target = make_value(10);
  target = std::move(source);
  ASSERT_NE(target.try_get<int>(), nullptr);
  // The previous payload (10) has been replaced by 5.
  const int unboxed = target.get<int>();
  ASSERT_EQ(unboxed, 5);
}
|
|
|
|
// type_info() of an AnyValue holding an int hashes the same as typeid(int).
TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  auto boxed = make_value(5);
  const auto actual_hash = boxed.type_info().hash_code();
  const auto expected_hash = typeid(int).hash_code();
  ASSERT_EQ(actual_hash, expected_hash);
}
|
|
|
|
// type_info() of an AnyValue built from a string literal hashes the same as
// typeid(const char*), i.e. the array type is not what gets recorded.
TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  auto boxed = make_value("hello");
  const auto actual_hash = boxed.type_info().hash_code();
  const auto expected_hash = typeid(const char*).hash_code();
  ASSERT_EQ(actual_hash, expected_hash);
}
|
|
|
|
// type_info() of an AnyValue holding a std::string hashes the same as
// typeid(std::string).
TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  auto boxed = make_value(std::string("hello"));
  const auto actual_hash = boxed.type_info().hash_code();
  const auto expected_hash = typeid(std::string).hash_code();
  ASSERT_EQ(actual_hash, expected_hash);
}
|