// Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00).
// File: test/cpp/api C++ module tests (1058 lines, 34 KiB).
// Context: PR https://github.com/pytorch/pytorch/pull/92731 (attempts to fix
// #92656; BC-breaking) changed the default of zero_grad in optim and in nn to
// set gradients to None instead of zero tensors. The default was changed
// because there are proven performance wins and existing code has typically
// not regressed due to this change. Approved by: https://github.com/ngimel
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::test;
|
|
|
|
// Empty module at global scope; used to test name() demangling and as<>() casts.
struct AGIUnit : torch::nn::Module {};
|
|
|
|
// Namespaced counterparts used to verify that Module::name() reports the
// qualified name ("test::AGIUnit") and honors an explicitly supplied name.
namespace test {
struct AGIUnit : torch::nn::Module {};
struct AGIUnit2 : torch::nn::Module {
  // Passing a name to the base constructor overrides demangled-name lookup.
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test
|
|
|
|
// Test fixture that seeds the RNG for reproducible parameter initialization.
struct ModuleTest : torch::test::SeedingFixture {};
|
|
|
|
// Modules start in training mode; eval()/train() toggle is_training().
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear module(3, 4);
  ASSERT_TRUE(module->is_training());

  module->eval();
  ASSERT_FALSE(module->is_training());

  module->train();
  ASSERT_TRUE(module->is_training());
}
|
|
|
|
// After PR #92731, Module::zero_grad() defaults to set_to_none=true, so
// gradients become undefined tensors rather than zero-filled ones.
TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  // Input batch (8 rows of 3 features); requires_grad so backward() flows
  // gradients through the module's parameters.
  auto weight = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module(weight).sum();
  loss.backward();
  for (auto& parameter : module->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto grad = parameter.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }
  module->zero_grad();
  for (auto& parameter : module->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto grad = parameter.grad();
    // Default set_to_none=true: gradient tensors are released, not zeroed.
    ASSERT_FALSE(grad.defined());
  }
}
|
|
|
|
// zero_grad() must tolerate parameters whose .grad() was never populated,
// and the set_to_none flag controls whether defined grads are zeroed or freed.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  // Only x participates in the graph, so only x receives a gradient.
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad(false); // set_to_none = false

  // With set_to_none=false, existing grads are kept (zeroed in place) and
  // absent grads stay undefined.
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);

  module.zero_grad();

  // Default (set_to_none=true): all grads are released.
  ASSERT_FALSE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
}
|
|
|
|
// register_module() validates submodule names: no dots (they are path
// separators in named_modules()) and no empty names.
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
|
|
|
|
// Registering two submodules under the same name must be rejected.
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}
|
|
|
|
// replace_module() only replaces existing registrations; unknown names throw.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  torch::nn::Module model;
  ASSERT_THROWS_WITH(
      model.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
|
|
|
|
// replace_module() swaps a registered submodule in place and returns the new
// one, so both the member pointer and the registry observe the replacement.
TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    torch::nn::Linear l1{nullptr};
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
  };
  auto model = std::make_shared<TestModel>();
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));
  // The new Linear(5, 6) weight has 6 output rows.
  ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
|
|
|
|
// unregister_module() throws for unknown names and removes known children.
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  ASSERT_THROWS_WITH(
      model.unregister_module("linear"),
      "No Module with name `linear` is registered");
  model.register_module("linear", torch::nn::Linear(3, 4));
  model.unregister_module("linear");
  ASSERT_TRUE(model.children().empty());
}
|
|
|
|
// register_parameter() applies the same name validation as register_module().
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
|
|
|
|
// Duplicate parameter names are rejected. (NOTE(review): the test name says
// "ModuleName" but this exercises duplicate *parameter* names.)
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
|
|
|
|
// Registering an undefined tensor as a parameter is allowed but the parameter
// is excluded from parameters(); passing requires_grad for an undefined
// tensor only emits a warning (requires_grad is meaningless there).
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};
  {
    TestModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);
  }
  {
    WarningCapture warnings;

    TestModel model;
    // Default requires_grad=true triggers the warning for undefined tensors.
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        1);
  }
}
|
|
|
|
// register_buffer() applies the same name validation as parameters/modules.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
|
|
|
|
// Duplicate buffer names are rejected.
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}
|
|
|
|
// Module::name() is derived by demangling the concrete type, lazily cached,
// and overridable via the base-class constructor argument.
TEST_F(ModuleTest, CanGetName) {
  // CHECK instead of REQUIRE because demangling may fail.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  // Namespaces appear in the demangled name; explicit names win over it.
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
|
|
|
|
// Module::as<T>() works through a ModuleHolder, a shared_ptr<Module>, and a
// Module&; it returns the implementation pointer on a type match and nullptr
// otherwise. as<Linear> (holder type) and as<LinearImpl> are equivalent.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  // The returned pointer is usable as the concrete implementation type.
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  // A non-Linear module only casts to itself.
  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
|
|
|
|
// Shared body for the CPU and CUDA variants below: Module::to() must leave
// undefined tensors (disabled bias / disabled running stats) untouched while
// converting the defined ones.
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
    torch::Device to_device,
    torch::Dtype to_dtype) {
  {
    // Case 1: Undefined tensors as parameters
    Linear module(LinearOptions(10, 20).bias(false));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->bias.defined());
  }
  {
    // Case 2: Undefined tensors as buffers
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(true));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->running_mean.defined());
  }
}
|
|
|
|
// CPU variant of the shared conversion test above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble);
}
|
|
|
|
// CUDA variant; the _CUDA suffix gates this test on CUDA availability.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      torch::kCUDA, torch::kDouble);
}
|
|
|
|
// parameters()/buffers() and their named variants must not report undefined
// tensors (disabled bias / disabled running stats), and the tensors they do
// report alias the module members (checked via pointer_equal).
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // bias(false) leaves module->bias undefined, so only weight is listed.
    Linear module(LinearOptions(10, 20).bias(false));

    auto params = module->parameters();
    ASSERT_EQ(params.size(), 1);
    auto named_params = module->named_parameters();
    ASSERT_EQ(named_params.size(), 1);

    ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
    ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
  }
  {
    // No running stats and no affine params: nothing to report.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 0);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 0);
  }
  {
    // Running stats enabled: mean, var, and batch counter buffers exist.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 3);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 3);

    ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_mean"], module->running_mean));
    ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_var"], module->running_var));
    ASSERT_TRUE(
        pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named_buffers["num_batches_tracked"], module->num_batches_tracked));
  }
}
|
|
|
|
// Module::to() moves all parameters between CPU and specific CUDA device
// indices, and converts dtypes; requires two CUDA devices (_MultiCUDA).
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  // Default state: float32 parameters on CPU.
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter.dtype(), torch::kFloat32);
  }
  {
    // Device index must be honored, including a device-to-device move.
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
  }
  {
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
    }
  }
  {
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kFloat64);
    }
  }
}
|
|
|
|
// to() also works on parameters with requires_grad=false, including
// conversion to integer dtypes (which cannot carry gradients) and a combined
// device+dtype conversion.
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    parameter.requires_grad_(false);
  }
  {
    module->to(torch::kInt32);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kInt32);
    }
  }
  {
    // Single call converting device and dtype together.
    module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kUInt8);
    }
  }
}
|
|
|
|
// The base Module::clone() is not implemented; subclasses must override it
// (typically via the Cloneable CRTP base).
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}
|
|
|
|
// A module that overrides clone() can be cloned without throwing.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ module.clone(); });
}
|
|
|
|
// Cloneable module with three Linear children and a buffer, used by the
// CloneCreatesDistinctParameters* tests below.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TestDistinctParametersModule
    : public Cloneable<TestDistinctParametersModule> {
  TestDistinctParametersModule() {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }
  // reset() is the Cloneable hook that (re)creates all state; clone() calls
  // it on the copy before copying values over.
  void reset() override {
    l1 = register_module("l1", Linear(10, 3));
    l2 = register_module("l2", Linear(3, 5));
    l3 = register_module("l3", Linear(5, 100));
    buffer = register_buffer("buf", torch::ones({2, 2}));
  }

  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
  torch::Tensor buffer;
};
|
|
|
|
// Asserts that m2 is a deep copy of m1: parameters and buffers have equal
// values but distinct storage, so mutating m1 does not affect m2.
void testDistinctParameters(
    std::shared_ptr<Module> m1,
    std::shared_ptr<Module> m2) {
  auto params1 = m1->named_parameters();
  auto params2 = m2->named_parameters();
  // 3 Linear children x (weight + bias) = 6 parameters.
  ASSERT_EQ(params1.size(), 6);
  ASSERT_EQ(params2.size(), 6);
  for (auto& param : params1) {
    // Distinct storage, equal values.
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    param->add_(2);
  }
  // After mutating m1 in place, the copies must have diverged.
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }

  // Same checks for the single registered buffer.
  auto buffers1 = m1->named_buffers();
  auto buffers2 = m2->named_buffers();
  ASSERT_EQ(buffers1.size(), 1);
  ASSERT_EQ(buffers2.size(), 1);
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
    ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
    buffer->add_(2);
  }
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
  }
}
|
|
|
|
// clone() on CPU produces value-equal but storage-distinct state.
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  auto module2 = module->clone();
  testDistinctParameters(module, module2);
}
|
|
|
|
// clone(device) with source and target on the same CUDA device still deep-copies.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device device(torch::kCUDA, 0);
  module->to(device);
  auto module2 = module->clone(device);
  testDistinctParameters(module, module2);
}
|
|
|
|
// clone(d1) of a module living on d0 leaves the original on d0 and places the
// copy on d1; the copy is still a deep copy.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device d0(torch::kCUDA, 0);
  torch::Device d1(torch::kCUDA, 1);
  module->to(d0);
  auto module2 = module->clone(d1);

  for (auto& param : module->parameters()) {
    ASSERT_EQ(param.device(), d0);
  }

  for (auto& param : module2->parameters()) {
    ASSERT_EQ(param.device(), d1);
  }

  // need to move the module back to d0 as allclose expects two tensors on
  // the same device.
  module2->to(d0);
  testDistinctParameters(module, module2);
}
|
|
|
|
// After clone(), the member tensor of the copy aliases the copy's registered
// parameter (not the original's), so member references stay consistent.
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  // The member aliases the registered parameter in the original.
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  // Copy has its own storage...
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  // ...its member aliases its own registered parameter...
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  // ...with equal values, and no aliasing back into the original.
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
|
|
|
|
// clone() recurses into submodules: their tensors are deep-copied and plain
// (non-tensor) data members are copied by value as well.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    // Non-tensor state; copied by the compiler-generated copy in clone().
    int value = 0;
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  // Submodule tensor: distinct storage, equal values, member aliases the
  // clone's own registered parameter.
  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  // Plain data member carried over too.
  ASSERT_EQ(b->module->value, a->module->value);
}
|
|
|
|
// clone() without an explicit device keeps parameters and buffers on the
// device the original currently lives on.
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
|
|
|
|
// clone(device) moves every parameter and buffer of the copy to the requested
// device, even when the original lives on a different device (CPU here).
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1);
  // everything is on CPU here
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
|
|
|
|
// Module with three named parameters, used by the parameter accessor tests.
struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
|
|
|
|
// parameters() and named_parameters() both report all three registrations.
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  ASSERT_EQ(module.parameters().size(), 3);
  ASSERT_EQ(module.named_parameters().size(), 3);
}
|
|
|
|
// named_parameters() is keyed by the names passed to register_parameter().
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.named_parameters();
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}
|
|
|
|
// Module with three named buffers, used by the buffer accessor tests.
struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
|
|
|
|
// buffers() and named_buffers() both report all three registrations.
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  ASSERT_EQ(module.buffers().size(), 3);
  ASSERT_EQ(module.named_buffers().size(), 3);
}
|
|
|
|
// named_buffers() is keyed by the names passed to register_buffer().
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.named_buffers();
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}
|
|
|
|
struct AImpl : torch::nn::Module {
|
|
AImpl() : x_(123) {}
|
|
AImpl(int x) : x_(x) {}
|
|
int x_;
|
|
};
|
|
TORCH_MODULE(A);
|
|
|
|
// A default-constructed holder default-constructs the Impl (x_ == 123).
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}
|
|
|
|
// Holder constructor arguments are forwarded to the matching Impl constructor.
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}
|
|
|
|
// A nullptr-initialized holder is empty, falsy, and throws on dereference.
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}
|
|
|
|
// Identity module with two parameters and two buffers of the given size;
// used as the building block for the flat-model traversal tests below.
struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  // Identity forward; the tests only inspect structure, not computation.
  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1, p2, b1, b2;
};
|
|
|
|
// modules() includes the root itself followed by its children, in order.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
|
|
|
|
// modules(include_self=false) omits the root and yields only descendants.
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
|
|
|
|
// named_modules() keys the root with the empty string and children with
// their Sequential index ("0", "1", "2").
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality. Entry 0 is the root (empty key); entry i > 0
    // is child i-1, keyed by its index.
    ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
// named_modules(include_self=false) yields only children, keyed by index.
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
// children() returns direct submodules only; for a flat (depth-1) model this
// coincides with modules(include_self=false).
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }

  // For this flat model, this should be true.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
|
|
|
|
// named_children() returns direct submodules keyed by their Sequential index.
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
// parameters() aliases the registered member tensors (same storage).
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> parameters = module.parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>());
}
|
|
|
|
// named_parameters() preserves registration order, names, and storage aliasing.
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> parameters =
      module.named_parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].key(), "p1");
  ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].key(), "p2");
  ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>());
}
|
|
|
|
// buffers() aliases the registered member tensors (same storage).
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> buffers = module.buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>());
}
|
|
|
|
// named_buffers() preserves registration order, names, and storage aliasing.
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> buffers =
      module.named_buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].key(), "b1");
  ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].key(), "b2");
  ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>());
}
|
|
|
|
struct TestContainer : torch::nn::Module {
|
|
TestContainer(int64_t number, std::vector<TestContainer> modules = {})
|
|
: tensor(torch::tensor(number)) {
|
|
for (const auto i : c10::irange(modules.size())) {
|
|
register_module(
|
|
std::to_string(i),
|
|
std::make_shared<TestContainer>(std::move(modules[i])));
|
|
}
|
|
}
|
|
torch::Tensor tensor;
|
|
};
|
|
|
|
int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
|
|
return std::dynamic_pointer_cast<TestContainer>(module)
|
|
->tensor.item<int64_t>();
|
|
}
|
|
|
|
// Builds a fixed 10-node tree whose pre-order traversal yields values 0..9:
//   0 -> {1 -> {2, 3}, 4, 5 -> {6, 7 -> {8, 9}}}
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  return std::make_shared<TestContainer>(TestContainer(
      0,
      {TestContainer(1, {TestContainer(2), TestContainer(3)}),
       TestContainer(4),
       TestContainer(
           5,
           {TestContainer(6),
            TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}
|
|
|
|
// Expected (qualified name, stored value) pairs for the tree built by
// make_deeply_nested_test_container(), traversed pre-order with the name
// prefix "test_prefix". The stored values are simply 0..9 in visit order.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  static const char* const kNames[] = {
      "test_prefix",
      "test_prefix.0",
      "test_prefix.0.0",
      "test_prefix.0.1",
      "test_prefix.1",
      "test_prefix.2",
      "test_prefix.2.0",
      "test_prefix.2.1",
      "test_prefix.2.1.0",
      "test_prefix.2.1.1"};
  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.reserve(std::size(kNames));
  int64_t value = 0;
  for (const char* name : kNames) {
    pairs.emplace_back(name, value++);
  }
  return pairs;
}
|
|
|
|
// modules() performs a pre-order traversal of the whole tree (values 0..9).
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();

  ASSERT_EQ(modules.size(), 10);
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(get_test_container_item(modules[i]), i);
  }
}
|
|
|
|
// named_modules(name_prefix) produces dot-joined hierarchical keys rooted at
// the prefix, matching the expected pre-order key/value table.
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(/*name_prefix=*/"test_prefix");
  auto expected = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(modules.size(), expected.size());

  for (const auto i : c10::irange(expected.size())) {
    ASSERT_EQ(modules[i].key(), expected[i].first);
    ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
  }
}
|
|
|
|
// children() stops at depth 1: only the root's direct children (1, 4, 5).
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();

  ASSERT_EQ(modules.size(), 3);
  ASSERT_EQ(get_test_container_item(modules[0]), 1);
  ASSERT_EQ(get_test_container_item(modules[1]), 4);
  ASSERT_EQ(get_test_container_item(modules[2]), 5);
}
|
|
|
|
// named_children() stops at depth 1 and uses the single-level registration
// names ("0", "1", "2") without any dotting.
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();

  ASSERT_EQ(modules.size(), 3);

  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
  ASSERT_EQ(modules[0].key(), "0");

  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
  ASSERT_EQ(modules[1].key(), "1");

  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
  ASSERT_EQ(modules[2].key(), "2");
}
|
|
|
|
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  // apply() visits the root first and then descends in pre-order; each
  // container's payload value equals its visitation index.
  int64_t visit_count = 0;
  model->apply([&visit_count](torch::nn::Module& module) {
    ASSERT_EQ(
        module.as<TestContainer>()->tensor.item<int64_t>(), visit_count++);
  });
  ASSERT_EQ(visit_count, 10);
}
|
|
|
|
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  // Holding the model through a const pointer selects the const overload of
  // apply(), which hands the visitor a `const Module&`.
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t visit_count = 0;
  model->apply([&visit_count](const torch::nn::Module& module) {
    ASSERT_EQ(
        module.as<TestContainer>()->tensor.item<int64_t>(), visit_count++);
  });
  ASSERT_EQ(visit_count, 10);
}
|
|
|
|
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      // Capture `expected` by reference: the lambda never outlives this scope,
      // and the previous by-value capture copied the whole vector on every
      // call to apply() (and was inconsistent with the const variant of this
      // test, which already captures by reference).
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        // Names and payload values must follow pre-order traversal.
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
|
|
|
|
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  // A const model selects the const overload of apply(); the visitor receives
  // the qualified name plus a `const Module&`.
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t visit_count = 0;
  model->apply(
      [&visit_count, &expected](
          const std::string& name, const torch::nn::Module& module) {
        ASSERT_EQ(name, expected[visit_count].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[visit_count++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(visit_count, 10);
}
|
|
|
|
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  // The shared_ptr overload of apply() hands the visitor owning pointers,
  // again in pre-order.
  int64_t visit_count = 0;
  model->apply(
      [&visit_count](const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(get_test_container_item(module), visit_count++);
      });
  ASSERT_EQ(visit_count, 10);
}
|
|
|
|
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  // Named shared_ptr overload: visitor sees both the qualified name and an
  // owning pointer for every module, root included.
  int64_t visit_count = 0;
  model->apply(
      [&visit_count, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(name, expected[visit_count].first);
        ASSERT_EQ(
            get_test_container_item(module), expected[visit_count++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(visit_count, 10);
}
|
|
|
|
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  {
    // A stack-allocated module cannot hand out a shared_ptr to itself, so
    // asking modules() to include the top-level module must throw.
    TestModule stack_module(1);
    ASSERT_THROWS_WITH(
        stack_module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr")
  }
  {
    // Excluding the top-level module sidesteps the shared_ptr requirement.
    TestModule stack_module(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(stack_module.modules(/*include_self=*/false));
  }
  {
    // A shared_ptr-owned module can include itself without throwing.
    auto owned_module = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(owned_module->modules());
  }
}
|
|
|
|
// Module with no members, submodules, or pretty_print() override; used by the
// PrettyPrint test below to check the default (class-name-only) string form.
struct EmptyModule : torch::nn::Module {};
|
|
|
|
TEST_F(ModuleTest, PrettyPrint) {
  // A module may override pretty_print() to customize its printed form; a
  // module without the override (EmptyModule) prints only its class name.
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_value_(x), y_value_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_value_ << ", y=" << y_value_ << ")";
    }

    int x_value_;
    float y_value_;
  };

  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
|
|
|
|
// Module whose forward() returns a non-Tensor value (the element count of its
// input); used to verify that the module-holder call operator forwards
// arbitrary return types unchanged.
struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
  int64_t forward(torch::Tensor x) {
    return x.numel();
  }
};
// Generates the `ModuleWithNonTensorForward` holder (pimpl) wrapper around
// the Impl class above.
TORCH_MODULE(ModuleWithNonTensorForward);
|
|
|
|
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  // Invoking the holder routes through Impl::forward and preserves the
  // non-Tensor (int64_t) return value.
  ModuleWithNonTensorForward module;
  const auto element_count = module(torch::ones(123));
  ASSERT_EQ(element_count, 123);
}
|