Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28828

This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers, and makes their names match the Python versions. It also removes the individual accessors for Parameter, Module, Buffer, etc., replacing them with a single `attr` function that is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values, this will end up matching how an attribute is accessed on general objects.

This PR preserves the Python bindings for script::Module by emulating the old API at the binding level. A follow-up will clean up the usage to match the C++ API more directly.

Test Plan: Imported from OSS

Differential Revision: D18197611

Pulled By: zdevito

fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
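As a rough illustration of the accessor API described above (a sketch based only on this summary; the exact namespace, header, and signatures are assumptions, not taken from this diff):

    // Hypothetical usage; names follow the summary above.
    torch::jit::script::Module module = torch::jit::load("model.pt");
    auto weight = module.attr("weight");           // like `module.weight` in Python
    module.setattr("weight", torch::ones({2, 2})); // like `module.weight = v`
    for (const auto& parameter : module.named_parameters()) {
      // Names match the dotted paths produced by nn.Module in Python.
    }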
#include <gtest/gtest.h>

#include <c10/util/tempfile.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <cstdio>
#include <cstring> // for std::memcpy in the BasicViaFunc test
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::nn;
using namespace torch::serialize;

namespace {
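// Builds the small MLP (2 -> 8 -> 1 with sigmoid activations) used by the
// XOR tests below.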
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

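// Round-trips a tensor through an in-memory stream and returns the copy.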
torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace

TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  auto tempfile = c10::make_tempfile();
  torch::save(x, tempfile.name);

  torch::Tensor y;
  torch::load(y, tempfile.name);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  std::string serialized;
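  // `torch::save` also accepts a writer callback: it is called with
  // successive chunks of the serialized stream and returns the number of
  // bytes it consumed.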
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  torch::Tensor z;
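  // `torch::load` can likewise read through a pair of callbacks: a reader
  // `read(pos, buf, n)` that copies up to `n` bytes from offset `pos` into
  // `buf` and returns the count actually copied, plus a second callback
  // that returns the total size of the serialized data.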
  torch::load(z, [&](uint64_t pos, void* buf, size_t n) -> size_t {
    if (pos >= serialized.size())
      return 0;
    size_t nbytes =
        std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
    std::memcpy(buf, serialized.data() + pos, nbytes);
    return nbytes;
  },
  [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}

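// The next three tests check that resized, sliced, and otherwise
// non-contiguous tensors round-trip through serialization correctly.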
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, ErrorOnMissingKey) {
  struct B : torch::nn::Module {
    B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // Create a hierarchy of models with names differing below the top level.
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}

TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

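    // Track an exponential moving average of the loss; training stops once
    // it falls below 0.1, and the assertion fails the test if convergence
    // takes 3000 or more epochs.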
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // Make some optimizers with momentum (and thus state).
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do two steps of model1 with the same optimizer.
  step(optim1, model1);
  step(optim1, model1);

  // Do two steps of model2 without saving the optimizer: the second step
  // uses a fresh optimizer, so optim2's momentum state is lost.
  step(optim2, model2);
  step(optim2_2, model2);

  // Do two steps of model3, saving and restoring the optimizer in between.
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Models 1 and 3 should be the same; model 2 lost its momentum state.
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}

TEST(SerializeTest, SerializationShouldPreserveIteration_SGD) {
  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3})};

  torch::optim::SGD optimizer(parameters, 1.0);

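  // `iteration()` counts completed `step()` calls; this test verifies that
  // the counter is part of the optimizer's serialized state.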
  optimizer.step();
  optimizer.step();

  ASSERT_EQ(optimizer.iteration(), 2);

  auto tempfile = c10::make_tempfile();
  torch::save(optimizer, tempfile.name);

  torch::optim::SGD optimizer_out(parameters, 1.0);
  ASSERT_EQ(optimizer_out.iteration(), 0);

  torch::load(optimizer_out, tempfile.name);
  ASSERT_EQ(optimizer_out.iteration(), 2);
}

TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model,
                    uint32_t batch_size,
                    bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  model2->to(torch::kCUDA);
  loss = getLoss(model2, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  auto tempfile2 = c10::make_tempfile();
  torch::save(model2, tempfile2.name);
  torch::load(model3, tempfile2.name);

  loss = getLoss(model3, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  struct C : torch::nn::Module {
    C() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct B : torch::nn::Module {};
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("c", std::make_shared<C>());
    }
  };
  struct M : torch::nn::Module {
    M() {
      register_module("a", std::make_shared<A>());
    }
  };

  auto out = std::make_shared<M>();
  std::stringstream ss;
  torch::save(out, ss);
  auto in = std::make_shared<M>();
  torch::load(in, ss);

  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
  ASSERT_EQ(output, 5);
}

TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  for (size_t i = 0; i < x_vec.size(); i++) {
    auto& x = x_vec[i];
    auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}

TEST(SerializeTest, IValue) {
  c10::IValue ivalue(1);
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("value", ivalue);
  output_archive.save_to(tempfile.name);

  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  c10::IValue ivalue_out;
  input_archive.read("value", ivalue_out);
  ASSERT_EQ(ivalue_out.toInt(), 1);

  ASSERT_THROWS_WITH(
      input_archive.read("bad_key", ivalue_out),
      "does not have a field with the name");
}

// NOTE: If a `Module` contains unserializable submodules (e.g. `nn::Functional`),
// we expect those submodules to be skipped when the `Module` is being serialized.
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  struct A : torch::nn::Module {
    A() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  std::stringstream ss;
  torch::save(out, ss);

  torch::serialize::InputArchive archive;
  archive.load_from(ss);
  torch::serialize::InputArchive relu_archive;

  // The submodule named "relu" should not exist in the `InputArchive`,
  // because the "relu" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive.try_read("relu", relu_archive));
}

// NOTE: If a `Module` contains unserializable submodules (e.g. `nn::Functional`),
// we don't check the existence of those submodules in the `InputArchive` when
// deserializing.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  struct B : torch::nn::Module {
    B() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  // Manually change the values of "b.foo", so that we can check whether the
  // buffer contains these values after deserialization.
  out->named_buffers()["b.foo"].fill_(1);
  auto tempfile = c10::make_tempfile();
  torch::save(out, tempfile.name);

  torch::serialize::InputArchive archive;
  archive.load_from(tempfile.name);
  torch::serialize::InputArchive archive_b;
  torch::serialize::InputArchive archive_relu;
  torch::Tensor tensor_foo;

  ASSERT_TRUE(archive.try_read("b", archive_b));
  ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));

  // The submodule named "relu1" should not exist in `archive_b`, because the
  // "relu1" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));

  // The submodule named "relu2" should not exist in `archive`, because the
  // "relu2" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive.try_read("relu2", archive_relu));

  auto in = std::make_shared<A>();
  // `torch::load(...)` works without error, even though `A` contains
  // `nn::Functional` submodules while the serialized file doesn't, because
  // the `nn::Functional` submodules are not serializable and thus ignored
  // when deserializing.
  torch::load(in, tempfile.name);

  // Check that the "b.foo" buffer is correctly deserialized from the file.
  const int output = in->named_buffers()["b.foo"].sum().item<int>();
  // `output` should equal the sum of the values we manually assigned to
  // "b.foo" before serialization.
  ASSERT_EQ(output, 5);
}