// Mirror of https://github.com/pytorch/pytorch.git
// Synced 2025-10-21 05:34:18 +08:00
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/138987
// Approved by: https://github.com/Skylion007
// 1058 lines, 34 KiB, C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::test;
|
|
|
|
// Empty module declared at global scope; its type-derived name() is expected
// to be "AGIUnit".
struct AGIUnit : public torch::nn::Module {};
|
|
|
|
namespace test {
|
|
struct AGIUnit : torch::nn::Module {};
|
|
struct AGIUnit2 : torch::nn::Module {
|
|
AGIUnit2() : torch::nn::Module("Foo") {}
|
|
};
|
|
} // namespace test
|
|
|
|
// Fixture shared by every test in this file; all behavior comes from
// torch::test::SeedingFixture.
struct ModuleTest : public torch::test::SeedingFixture {};
|
|
|
|
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear linear(3, 4);
  // Newly constructed modules start in training mode.
  ASSERT_TRUE(linear->is_training());

  // eval() clears the flag ...
  linear->eval();
  ASSERT_FALSE(linear->is_training());

  // ... and train() sets it again.
  linear->train();
  ASSERT_TRUE(linear->is_training());
}
|
|
|
|
TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  auto weight = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module(weight).sum();
  loss.backward();

  // After backward() every parameter carries a defined, non-zero gradient.
  for (auto& parameter : module->parameters()) {
    // Bind by const reference. clang-tidy flagged the previous by-value copy
    // (performance-unnecessary-copy-initialization), and nothing here needs
    // an owned copy of the gradient handle.
    const auto& grad = parameter.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }

  // With its default arguments, zero_grad() leaves every gradient undefined.
  module->zero_grad();
  for (auto& parameter : module->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}
|
|
|
|
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  // Two trainable parameters. Only `x` takes part in the backward pass below,
  // so `y` never receives a gradient.
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  (module.x * 2).sum().backward();
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  // With set_to_none = false, the existing gradient is zeroed but kept
  // defined, and the missing gradient stays missing.
  module.zero_grad(/*set_to_none=*/false);
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);

  // The default call leaves both gradients undefined.
  module.zero_grad();
  ASSERT_FALSE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
}
|
|
|
|
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  // A dotted name would collide with hierarchical names such as "a.b".
  TestModel dotted;
  ASSERT_THROWS_WITH(
      dotted.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");

  // An empty name is rejected as well; a fresh module is used for this check.
  TestModel unnamed;
  ASSERT_THROWS_WITH(
      unnamed.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
|
|
|
|
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const std::string child_name = "linear";

  TestModel model;
  // The first registration claims the name ...
  model.register_module(child_name, torch::nn::Linear(3, 4));
  // ... so registering a second submodule under it must fail.
  ASSERT_THROWS_WITH(
      model.register_module(child_name, torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}
|
|
|
|
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  // replace_module() only swaps out an existing submodule, so calling it on
  // a module with no children throws.
  torch::nn::Module empty_model;
  ASSERT_THROWS_WITH(
      empty_model.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
|
|
|
|
TEST_F(ModuleTest, ReplaceModule) {
  // Holds one Linear(3, 4) registered as "l1" and mirrored in a member.
  struct TestModel : public torch::nn::Module {
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
    torch::nn::Linear l1{nullptr};
  };

  auto model = std::make_shared<TestModel>();
  // replace_module() returns the new holder, so the member can be updated
  // to point at it.
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));

  // The registry now reports the 5->6 layer's weight (6 rows) ...
  const auto replaced_weight = model->named_parameters()["l1.weight"];
  ASSERT_EQ(replaced_weight.size(0), 6);
  // ... and the member and registry refer to the same module object.
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
|
|
|
|
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : public torch::nn::Module {};
  const std::string child_name = "linear";
  TestModel model;

  // Unregistering a name that was never registered is an error.
  ASSERT_THROWS_WITH(
      model.unregister_module(child_name),
      "No Module with name `linear` is registered");

  // A register/unregister round trip leaves the module without children.
  model.register_module(child_name, torch::nn::Linear(3, 4));
  model.unregister_module(child_name);
  ASSERT_TRUE(model.children().empty());
}
|
|
|
|
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  // Parameter names may not contain a dot ...
  TestModel dotted;
  ASSERT_THROWS_WITH(
      dotted.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");

  // ... nor be empty (checked on a separate, fresh module).
  TestModel unnamed;
  ASSERT_THROWS_WITH(
      unnamed.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
|
|
|
|
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const std::string param_name = "p";

  TestModel model;
  model.register_parameter(param_name, torch::ones(5));
  // A second parameter under the same name is rejected.
  ASSERT_THROWS_WITH(
      model.register_parameter(param_name, torch::ones(5)),
      "Parameter 'p' already defined");
}
|
|
|
|
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};

  {
    // With requires_grad=false, registering an undefined tensor succeeds and
    // it is not listed among the parameters.
    TestModel quiet_model;
    quiet_model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(quiet_model.parameters().size(), 0);
  }

  {
    // With the default arguments, exactly one warning about ignoring
    // `requires_grad=true` is emitted, and the tensor is still not listed.
    WarningCapture captured;
    TestModel noisy_model;
    noisy_model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(noisy_model.parameters().size(), 0);

    const auto occurrences = count_substr_occurrences(
        captured.str(), "Ignoring the `requires_grad=true` function parameter");
    ASSERT_EQ(occurrences, 1);
  }
}
|
|
|
|
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  // Buffer names may not contain a dot ...
  TestModel dotted;
  ASSERT_THROWS_WITH(
      dotted.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");

  // ... nor be empty (checked on a separate, fresh module).
  TestModel unnamed;
  ASSERT_THROWS_WITH(
      unnamed.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
|
|
|
|
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const std::string buffer_name = "p";

  TestModel model;
  model.register_buffer(buffer_name, torch::ones(5));
  // A second buffer under the same name is rejected.
  ASSERT_THROWS_WITH(
      model.register_buffer(buffer_name, torch::ones(5)),
      "Buffer 'p' already defined");
}
|
|
|
|
TEST_F(ModuleTest, CanGetName) {
  // Use non-fatal EXPECT_* rather than ASSERT_* so that every name is checked
  // even if one comparison fails; demangling of type names may fail.
  AGIUnit agi;
  // Call name() twice to make sure there are no bugs in its lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  // A type inside a namespace reports its qualified name.
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  // A name passed explicitly to the Module constructor is returned verbatim.
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
|
|
|
|
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear linear(3, 4);
  auto* const expected = linear.get();

  // Through the ModuleHolder: the holder type, the impl type and the Module
  // base all resolve to the impl, while an unrelated type yields nullptr.
  ASSERT_EQ(linear->as<Linear>(), expected);
  ASSERT_EQ(linear->as<LinearImpl>(), expected);
  ASSERT_EQ(linear->as<Module>(), expected);
  ASSERT_EQ(linear->as<AGIUnit>(), nullptr);

  // The same checks through a type-erased shared_ptr<Module> ...
  std::shared_ptr<Module> base_ptr = linear.ptr();
  ASSERT_EQ(base_ptr->as<Linear>(), expected);
  ASSERT_EQ(base_ptr->as<LinearImpl>(), expected);
  ASSERT_EQ(base_ptr->as<Module>(), expected);
  ASSERT_EQ(base_ptr->as<AGIUnit>(), nullptr);

  // ... and through a plain Module reference.
  Module& base_ref = *base_ptr;
  ASSERT_EQ(base_ref.as<Linear>(), expected);
  ASSERT_EQ(base_ref.as<LinearImpl>(), expected);
  ASSERT_EQ(base_ref.as<Module>(), expected);
  ASSERT_EQ(base_ref.as<AGIUnit>(), nullptr);
  if (auto* as_linear = base_ref.as<Linear>()) {
    ASSERT_EQ(as_linear->weight.ndimension(), 2);
  }

  // A module unrelated to Linear only casts to its own type.
  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
|
|
|
|
// Checks that Module::to(device) and Module::to(dtype) convert defined
// tensors and leave undefined ones untouched, for both an undefined
// parameter (Linear without bias) and an undefined buffer (BatchNorm1d
// without running stats).
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
    torch::Device to_device,
    torch::Dtype to_dtype) {
  // Case 1: the undefined tensor is a parameter.
  {
    Linear linear(LinearOptions(10, 20).bias(false));
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_FALSE(linear->bias.defined());

    linear->to(to_device); // weight moves, bias stays undefined
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_EQ(linear->weight.device().type(), to_device.type());
    ASSERT_FALSE(linear->bias.defined());

    linear->to(to_dtype); // weight is cast, bias stays undefined
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_EQ(linear->weight.dtype(), to_dtype);
    ASSERT_FALSE(linear->bias.defined());
  }

  // Case 2: the undefined tensor is a buffer.
  {
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(false).affine(true));
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_FALSE(batch_norm->running_mean.defined());

    batch_norm->to(to_device);
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_EQ(batch_norm->weight.device().type(), to_device.type());
    ASSERT_FALSE(batch_norm->running_mean.defined());

    batch_norm->to(to_dtype);
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_EQ(batch_norm->weight.dtype(), to_dtype);
    ASSERT_FALSE(batch_norm->running_mean.defined());
  }
}
|
|
|
|
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  // CPU variant: run the shared checks with a CPU target and a double dtype.
  const torch::Device cpu(torch::kCPU);
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(cpu, torch::kDouble);
}
|
|
|
|
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  // CUDA variant of the shared checks.
  const torch::Device cuda(torch::kCUDA);
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(cuda, torch::kDouble);
}
|
|
|
|
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  // A bias-free Linear exposes only its weight.
  {
    Linear linear(LinearOptions(10, 20).bias(false));
    auto flat = linear->parameters();
    auto named = linear->named_parameters();
    ASSERT_EQ(flat.size(), 1);
    ASSERT_EQ(named.size(), 1);
    ASSERT_TRUE(pointer_equal(flat[0], named["weight"]));
    ASSERT_TRUE(pointer_equal(named["weight"], linear->weight));
  }

  // Without running stats, BatchNorm1d reports no buffers at all.
  {
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));
    auto flat = batch_norm->buffers();
    ASSERT_EQ(flat.size(), 0);
    auto named = batch_norm->named_buffers();
    ASSERT_EQ(named.size(), 0);
  }

  // With running stats, all three buffers are reported in declaration order.
  {
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));
    auto flat = batch_norm->buffers();
    ASSERT_EQ(flat.size(), 3);
    auto named = batch_norm->named_buffers();
    ASSERT_EQ(named.size(), 3);

    ASSERT_TRUE(pointer_equal(flat[0], named["running_mean"]));
    ASSERT_TRUE(pointer_equal(named["running_mean"], batch_norm->running_mean));
    ASSERT_TRUE(pointer_equal(flat[1], named["running_var"]));
    ASSERT_TRUE(pointer_equal(named["running_var"], batch_norm->running_var));
    ASSERT_TRUE(pointer_equal(flat[2], named["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named["num_batches_tracked"], batch_norm->num_batches_tracked));
  }
}
|
|
|
|
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear linear(128, 64);

  // Freshly built parameters live on the CPU as float32.
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(param.dtype(), torch::kFloat32);
  }

  // Move to CUDA device 0, then hop to device 1.
  linear->to({torch::kCUDA, 0});
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 0);
  }
  linear->to({torch::kCUDA, 1});
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }

  // Back to the CPU.
  linear->to(torch::Device(torch::kCPU));
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CPU);
  }

  // A dtype-only conversion.
  linear->to(torch::kFloat64);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kFloat64);
  }
}
|
|
|
|
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear linear(128, 64);
  // Disable gradients so the module can be cast to integral dtypes.
  for (auto& param : linear->parameters()) {
    param.requires_grad_(false);
  }

  // Dtype-only conversion to int32.
  linear->to(torch::kInt32);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kInt32);
  }

  // Combined device + dtype conversion; device is verified before dtype.
  linear->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kUInt8);
  }
}
|
|
|
|
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  // A plain Module subclass inherits the base clone(), which throws.
  struct UnCloneable : Module {};
  UnCloneable uncloneable;
  ASSERT_THROWS_WITH(uncloneable.clone(), "clone() has not been implemented");
}
|
|
|
|
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  // Any override of clone() replaces the throwing base implementation; this
  // one just returns nullptr.
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        const std::optional<torch::Device>& /*device*/ =
            std::nullopt) const override {
      return nullptr;
    }
  };

  Cloneable cloneable;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ cloneable.clone(); });
}
|
|
|
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
|
struct TestDistinctParametersModule
|
|
: public Cloneable<TestDistinctParametersModule> {
|
|
TestDistinctParametersModule() {
|
|
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
|
reset();
|
|
}
|
|
void reset() override {
|
|
l1 = register_module("l1", Linear(10, 3));
|
|
l2 = register_module("l2", Linear(3, 5));
|
|
l3 = register_module("l3", Linear(5, 100));
|
|
buffer = register_buffer("buf", torch::ones({2, 2}));
|
|
}
|
|
|
|
Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
|
|
torch::Tensor buffer;
|
|
};
|
|
|
|
void testDistinctParameters(
|
|
std::shared_ptr<Module> m1,
|
|
std::shared_ptr<Module> m2) {
|
|
auto params1 = m1->named_parameters();
|
|
auto params2 = m2->named_parameters();
|
|
ASSERT_EQ(params1.size(), 6);
|
|
ASSERT_EQ(params2.size(), 6);
|
|
for (auto& param : params1) {
|
|
ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
|
|
ASSERT_TRUE(param->allclose(params2[param.key()]));
|
|
param->add_(2);
|
|
}
|
|
for (auto& param : params1) {
|
|
ASSERT_FALSE(param->allclose(params2[param.key()]));
|
|
}
|
|
|
|
auto buffers1 = m1->named_buffers();
|
|
auto buffers2 = m2->named_buffers();
|
|
ASSERT_EQ(buffers1.size(), 1);
|
|
ASSERT_EQ(buffers2.size(), 1);
|
|
for (auto& buffer : buffers1) {
|
|
ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
|
|
ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
|
|
buffer->add_(2);
|
|
}
|
|
for (auto& buffer : buffers1) {
|
|
ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
|
|
}
|
|
}
|
|
|
|
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  // Autograd stays disabled for the clone and for the in-place updates done
  // by testDistinctParameters.
  torch::NoGradGuard no_grad;
  auto copy = original->clone();
  testDistinctParameters(original, copy);
}
|
|
|
|
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device cuda0(torch::kCUDA, 0);
  // Move to CUDA:0 and clone onto that same device.
  original->to(cuda0);
  auto copy = original->clone(cuda0);
  testDistinctParameters(original, copy);
}
|
|
|
|
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device source_device(torch::kCUDA, 0);
  const torch::Device target_device(torch::kCUDA, 1);

  // Clone from CUDA:0 onto CUDA:1.
  original->to(source_device);
  auto copy = original->clone(target_device);

  // The original stays where it was; the clone lands on the target device.
  for (auto& param : original->parameters()) {
    ASSERT_EQ(param.device(), source_device);
  }
  for (auto& param : copy->parameters()) {
    ASSERT_EQ(param.device(), target_device);
  }

  // allclose() needs both tensors on one device, so move the clone back to
  // CUDA:0 before comparing.
  copy->to(source_device);
  testDistinctParameters(original, copy);
}
|
|
|
|
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // Cloneable module whose `weight` member mirrors a registered parameter.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };

  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  // dynamic_pointer_cast accepts clone()'s shared_ptr directly, so the extra
  // shared_ptr<Module> wrapper is unnecessary. Check that the cast
  // succeeded, so a failure reports cleanly instead of dereferencing null.
  auto module2 = std::dynamic_pointer_cast<TestModule>(module->clone());
  ASSERT_TRUE(module2 != nullptr);

  // The clone's member references its own registered parameter, not the
  // original's, while holding the same values.
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
|
|
|
|
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // Leaf module with a registered parameter and a plain (non-tensor) field.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    int value = 0;
  };

  // Parent module that registers a TestModule as its only child.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  // Check the downcast and the child pointer before dereferencing, so a
  // failure reports cleanly instead of crashing on null.
  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
  ASSERT_TRUE(b != nullptr);
  ASSERT_TRUE(b->module != nullptr);

  // The child's tensor is a distinct copy with equal values, and its plain
  // field is copied too.
  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
|
|
|
|
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // Three Linear layers plus one buffer, all registered in reset().
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;

    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
  };

  const torch::Device cuda0(torch::kCUDA, 0);
  TestModule source;
  source.to(cuda0);

  // Cloning with no explicit device keeps everything on CUDA:0.
  auto copy = source.clone();
  for (const auto& param : copy->parameters()) {
    ASSERT_EQ(param.device().type(), cuda0.type());
    ASSERT_EQ(param.device().index(), cuda0.index());
  }
  for (const auto& buf : copy->buffers()) {
    ASSERT_EQ(buf.device().type(), cuda0.type());
    ASSERT_EQ(buf.device().index(), cuda0.index());
  }
}
|
|
|
|
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // Three Linear layers plus one buffer, all registered in reset().
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;

    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
  };

  TestModule source; // everything starts out on the CPU
  const torch::Device cuda1(torch::kCUDA, 1);

  // Cloning with an explicit target moves every tensor onto CUDA:1.
  auto copy = source.clone(cuda1);
  for (const auto& param : copy->parameters()) {
    ASSERT_EQ(param.device().type(), cuda1.type());
    ASSERT_EQ(param.device().index(), cuda1.index());
  }
  for (const auto& buf : copy->buffers()) {
    ASSERT_EQ(buf.device().type(), cuda1.type());
    ASSERT_EQ(buf.device().index(), cuda1.index());
  }
}
|
|
|
|
struct ParameterTestModule : Module {
|
|
ParameterTestModule() {
|
|
a = register_parameter("a", torch::zeros({2, 2}));
|
|
b = register_parameter("b", torch::ones({2, 2}));
|
|
c = register_parameter("c", torch::ones({2, 2}) * 2);
|
|
}
|
|
|
|
torch::Tensor a, b, c;
|
|
};
|
|
|
|
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule three_params;
  // Both the flat and the named accessor see all three parameters.
  ASSERT_EQ(three_params.parameters().size(), 3);
  ASSERT_EQ(three_params.named_parameters().size(), 3);
}
|
|
|
|
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule three_params;
  const auto by_name = three_params.named_parameters();
  // Every name used at registration time is a key of named_parameters().
  ASSERT_TRUE(by_name.contains("a"));
  ASSERT_TRUE(by_name.contains("b"));
  ASSERT_TRUE(by_name.contains("c"));
}
|
|
|
|
struct BufferTestModule : Module {
|
|
BufferTestModule() {
|
|
a = register_buffer("a", torch::zeros({2, 2}));
|
|
b = register_buffer("b", torch::ones({2, 2}));
|
|
c = register_buffer("c", torch::ones({2, 2}) * 2);
|
|
}
|
|
|
|
torch::Tensor a, b, c;
|
|
};
|
|
|
|
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule three_buffers;
  // Both the flat and the named accessor see all three buffers.
  ASSERT_EQ(three_buffers.buffers().size(), 3);
  ASSERT_EQ(three_buffers.named_buffers().size(), 3);
}
|
|
|
|
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule three_buffers;
  const auto by_name = three_buffers.named_buffers();
  // Every name used at registration time is a key of named_buffers().
  ASSERT_TRUE(by_name.contains("a"));
  ASSERT_TRUE(by_name.contains("b"));
  ASSERT_TRUE(by_name.contains("c"));
}
|
|
|
|
struct AImpl : torch::nn::Module {
|
|
AImpl() : x_(123) {}
|
|
AImpl(int x) : x_(x) {}
|
|
int x_;
|
|
};
|
|
TORCH_MODULE(A);
|
|
|
|
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  // A default-constructed holder owns an AImpl built by AImpl().
  A holder;
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 123);
}
|
|
|
|
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  // Constructor arguments are forwarded to AImpl(int).
  A holder(5);
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 5);
}
|
|
|
|
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  // Initializing from nullptr yields an empty holder; dereferencing it throws.
  A empty_holder = nullptr;
  ASSERT_FALSE(empty_holder);
  ASSERT_TRUE(empty_holder.is_empty());
  ASSERT_THROWS_WITH(empty_holder->x_, "Accessing empty ModuleHolder");
}
|
|
|
|
struct TestModule : public torch::nn::Module {
|
|
TestModule(int64_t size) {
|
|
p1 = register_parameter("p1", torch::randn({size}));
|
|
p2 = register_parameter("p2", torch::randn({size}));
|
|
b1 = register_buffer("b1", torch::randn({size}));
|
|
b2 = register_buffer("b2", torch::randn({size}));
|
|
}
|
|
|
|
torch::Tensor forward(torch::Tensor input) {
|
|
return input;
|
|
}
|
|
|
|
torch::Tensor p1, p2, b1, b2;
|
|
};
|
|
|
|
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  // modules() lists the container itself, then each child in order.
  const std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
|
|
|
|
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  // With include_self=false, only the three children are listed.
  const std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
|
|
|
|
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // The root is keyed by the empty string; children are keyed "0", "1", ...
    const std::string expected_key =
        (i == 0) ? std::string() : std::to_string(i - 1);
    ASSERT_EQ(modules[i].key(), expected_key);
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  // No prefix and no root entry: just the children keyed "0", "1", "2".
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  const std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->children();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }

  // With no grandchildren, children() equals modules() without the root.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
|
|
|
|
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Sequential children are keyed by their position.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    // Compare raw pointers: the very same objects must be returned.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
|
|
|
|
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  const std::vector<torch::Tensor> params = flat.parameters();
  // Only the two parameters (not the buffers), in registration order and
  // sharing storage with the members.
  ASSERT_EQ(params.size(), 2);
  ASSERT_EQ(params[0].data_ptr<float>(), flat.p1.data_ptr<float>());
  ASSERT_EQ(params[1].data_ptr<float>(), flat.p2.data_ptr<float>());
}
|
|
|
|
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  torch::OrderedDict<std::string, torch::Tensor> params =
      flat.named_parameters();
  ASSERT_EQ(params.size(), 2);

  // First entry: "p1", sharing storage with the member.
  ASSERT_EQ(params[0].key(), "p1");
  ASSERT_EQ(params[0]->data_ptr<float>(), flat.p1.data_ptr<float>());
  // Second entry: "p2".
  ASSERT_EQ(params[1].key(), "p2");
  ASSERT_EQ(params[1]->data_ptr<float>(), flat.p2.data_ptr<float>());
}
|
|
|
|
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  const std::vector<torch::Tensor> bufs = flat.buffers();
  // Only the two buffers (not the parameters), in registration order and
  // sharing storage with the members.
  ASSERT_EQ(bufs.size(), 2);
  ASSERT_EQ(bufs[0].data_ptr<float>(), flat.b1.data_ptr<float>());
  ASSERT_EQ(bufs[1].data_ptr<float>(), flat.b2.data_ptr<float>());
}
|
|
|
|
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  torch::OrderedDict<std::string, torch::Tensor> bufs = flat.named_buffers();
  ASSERT_EQ(bufs.size(), 2);

  // First entry: "b1", sharing storage with the member.
  ASSERT_EQ(bufs[0].key(), "b1");
  ASSERT_EQ(bufs[0]->data_ptr<float>(), flat.b1.data_ptr<float>());
  // Second entry: "b2".
  ASSERT_EQ(bufs[1].key(), "b2");
  ASSERT_EQ(bufs[1]->data_ptr<float>(), flat.b2.data_ptr<float>());
}
|
|
|
|
struct TestContainer : torch::nn::Module {
|
|
TestContainer(int64_t number, std::vector<TestContainer> modules = {})
|
|
: tensor(torch::tensor(number)) {
|
|
for (const auto i : c10::irange(modules.size())) {
|
|
register_module(
|
|
std::to_string(i),
|
|
std::make_shared<TestContainer>(std::move(modules[i])));
|
|
}
|
|
}
|
|
torch::Tensor tensor;
|
|
};
|
|
|
|
int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
|
|
return std::dynamic_pointer_cast<TestContainer>(module)
|
|
->tensor.item<int64_t>();
|
|
}
|
|
|
|
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
|
|
return std::make_shared<TestContainer>(TestContainer(
|
|
0,
|
|
{TestContainer(1, {TestContainer(2), TestContainer(3)}),
|
|
TestContainer(4),
|
|
TestContainer(
|
|
5,
|
|
{TestContainer(6),
|
|
TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
|
|
}
|
|
|
|
// Expected (qualified name, tag) pairs for make_deeply_nested_test_container()
// walked in pre-order under the name prefix "test_prefix".
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  const std::string root = "test_prefix";
  return {
      {root, 0},
      {root + ".0", 1},
      {root + ".0.0", 2},
      {root + ".0.1", 3},
      {root + ".1", 4},
      {root + ".2", 5},
      {root + ".2.0", 6},
      {root + ".2.1", 7},
      {root + ".2.1.0", 8},
      {root + ".2.1.1", 9}};
}
|
|
|
|
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto tree = make_deeply_nested_test_container();
  const std::vector<std::shared_ptr<torch::nn::Module>> modules =
      tree->modules();

  // All ten nodes, root included, in pre-order: the i-th module carries tag i.
  ASSERT_EQ(modules.size(), 10);
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(get_test_container_item(modules[i]), static_cast<int64_t>(i));
  }
}
|
|
|
|
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  auto tree = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      tree->named_modules(/*name_prefix=*/"test_prefix");

  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    const auto& want = expected[i];
    ASSERT_EQ(modules[i].key(), want.first);
    ASSERT_EQ(get_test_container_item(modules[i].value()), want.second);
  }
}
|
|
|
|
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  auto tree = make_deeply_nested_test_container();
  const std::vector<std::shared_ptr<torch::nn::Module>> direct =
      tree->children();

  // Only the root's immediate children (tags 1, 4, 5), not their descendants.
  ASSERT_EQ(direct.size(), 3);
  ASSERT_EQ(get_test_container_item(direct[0]), 1);
  ASSERT_EQ(get_test_container_item(direct[1]), 4);
  ASSERT_EQ(get_test_container_item(direct[2]), 5);
}
|
|
|
|
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  auto tree = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> direct =
      tree->named_children();
  ASSERT_EQ(direct.size(), 3);

  // Immediate children are keyed by position, without any prefix.
  ASSERT_EQ(get_test_container_item(direct[0].value()), 1);
  ASSERT_EQ(direct[0].key(), "0");
  ASSERT_EQ(get_test_container_item(direct[1].value()), 4);
  ASSERT_EQ(direct[1].key(), "1");
  ASSERT_EQ(get_test_container_item(direct[2].value()), 5);
  ASSERT_EQ(direct[2].key(), "2");
}
|
|
|
|
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  // apply() should visit ten modules whose tags run 0, 1, ..., 9.
  int64_t visited = 0;
  model->apply([&visited](torch::nn::Module& module) {
    const int64_t expected_tag = visited++;
    const auto actual_tag = module.as<TestContainer>()->tensor.item<int64_t>();
    ASSERT_EQ(actual_tag, expected_tag);
  });
  ASSERT_EQ(visited, 10);
}
|
|
|
|
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  // Same traversal check as above, but through a pointer-to-const so the
  // const-reference overload of apply() is exercised.
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t visited = 0;
  model->apply([&visited](const torch::nn::Module& module) {
    const int64_t expected_tag = visited++;
    const auto* container = module.as<TestContainer>();
    ASSERT_EQ(container->tensor.item<int64_t>(), expected_tag);
  });
  ASSERT_EQ(visited, 10);
}
|
|
|
|
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  // The named overload of apply() must hand each module to the callback
  // together with its qualified name, in the order listed by
  // make_key_value_pairs_for_deeply_nested_container().
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      // Capture `expected` by reference; copying the vector into the closure
      // is unnecessary because the closure does not outlive this scope.
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        // Fail cleanly instead of reading past the end of `expected` if
        // apply() visits more modules than anticipated.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        // Advance the cursor before asserting, so a mismatch on one module
        // does not shift the comparison for every module visited after it.
        const auto& entry = expected[index++];
        ASSERT_EQ(name, entry.first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(), entry.second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
|
|
|
|
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  // Named traversal through a pointer-to-const. This exercises the
  // (const std::string&, const Module&) overload of apply().
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name, const torch::nn::Module& module) {
        // Fail cleanly instead of reading past the end of `expected` if
        // apply() visits more modules than anticipated.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        // Advance the cursor before asserting, so a mismatch on one module
        // does not shift the comparison for every module visited after it.
        const auto& entry = expected[index++];
        ASSERT_EQ(name, entry.first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            entry.second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
|
|
|
|
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  // The shared_ptr-taking overload of apply() should see tags 0 through 9
  // in order, ten modules in total.
  int64_t visited = 0;
  model->apply([&visited](const std::shared_ptr<torch::nn::Module>& module) {
    const int64_t expected_tag = visited++;
    ASSERT_EQ(get_test_container_item(module), expected_tag);
  });
  ASSERT_EQ(visited, 10);
}
|
|
|
|
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  // Named traversal via the (name, shared_ptr<Module>) overload of apply().
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        // Fail cleanly instead of reading past the end of `expected` if
        // apply() visits more modules than anticipated.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        // Advance the cursor before asserting, so a mismatch on one module
        // does not shift the comparison for every module visited after it.
        const auto& entry = expected[index++];
        ASSERT_EQ(name, entry.first);
        ASSERT_EQ(get_test_container_item(module), entry.second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
|
|
|
|
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  // Case 1: a module living on the stack is not managed by a shared_ptr.
  // Listing its modules while including itself must throw.
  {
    TestModule on_stack(1);
    ASSERT_THROWS_WITH(
        on_stack.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr")
  }
  // Case 2: the same stack module succeeds once it excludes itself.
  {
    TestModule on_stack(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(on_stack.modules(/*include_self=*/false));
  }
  // Case 3: a module created via make_shared may include itself.
  {
    auto heap_owned = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(heap_owned->modules());
  }
}
|
|
|
|
// Module with no members and no pretty_print() override (used by PrettyPrint).
struct EmptyModule : public torch::nn::Module {};
|
|
|
|
TEST_F(ModuleTest, PrettyPrint) {
  // Local module that overrides pretty_print() to render its two fields.
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_(x), y_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  // Without a pretty_print() override, a module prints just its type name.
  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  // `y` is a float, so pass a float literal. A double literal would be
  // implicitly narrowed at the call site.
  ASSERT_EQ(c10::str(TestModule(1, 3.14f)), "TestModule(x=1, y=3.14)");
}
|
|
|
|
struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
|
|
int64_t forward(torch::Tensor x) {
|
|
return x.numel();
|
|
}
|
|
};
|
|
TORCH_MODULE(ModuleWithNonTensorForward);
|
|
|
|
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  // Calling the wrapper should dispatch to forward() and return its integer
  // result: the number of elements in the input.
  ModuleWithNonTensorForward wrapper;
  const auto input = torch::ones(123);
  ASSERT_EQ(wrapper(input), 123);
}
|