mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
This PR makes libtorch behave the same as PyTorch when loading optimizer state from an archive. In PyTorch, the options of each parameter group are restored from the archive; libtorch currently does not do this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125215 Approved by: https://github.com/janeyx99
1096 lines
37 KiB
C++
1096 lines
37 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/flat_hash_map.h>
|
|
#include <c10/util/irange.h>
|
|
#include <c10/util/tempfile.h>
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
#include <cstdio>
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
using namespace torch::test;
|
|
using namespace torch::nn;
|
|
using namespace torch::optim;
|
|
|
|
namespace {
|
|
Sequential xor_model() {
|
|
return Sequential(
|
|
Linear(2, 8),
|
|
Functional(at::sigmoid),
|
|
Linear(8, 1),
|
|
Functional(at::sigmoid));
|
|
}
|
|
|
|
torch::Tensor save_and_load(torch::Tensor input) {
|
|
std::stringstream stream;
|
|
torch::save(input, stream);
|
|
torch::Tensor tensor;
|
|
torch::load(tensor, stream);
|
|
return tensor;
|
|
}
|
|
} // namespace
|
|
|
|
template <typename DerivedOptions>
|
|
void is_optimizer_param_group_equal(
|
|
const OptimizerParamGroup& lhs,
|
|
const OptimizerParamGroup& rhs) {
|
|
const auto& lhs_params = lhs.params();
|
|
const auto& rhs_params = rhs.params();
|
|
|
|
ASSERT_TRUE(lhs_params.size() == rhs_params.size());
|
|
for (const auto j : c10::irange(lhs_params.size())) {
|
|
ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
|
|
}
|
|
ASSERT_TRUE(
|
|
static_cast<const DerivedOptions&>(lhs.options()) ==
|
|
static_cast<const DerivedOptions&>(rhs.options()));
|
|
}
|
|
|
|
template <typename DerivedOptimizerParamState>
|
|
void is_optimizer_state_equal(
|
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
|
lhs_state,
|
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
|
rhs_state) {
|
|
ASSERT_TRUE(lhs_state.size() == rhs_state.size());
|
|
for (const auto& value : lhs_state) {
|
|
auto found = rhs_state.find(value.first);
|
|
ASSERT_TRUE(found != rhs_state.end());
|
|
const DerivedOptimizerParamState& lhs_curr_state =
|
|
static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
|
|
const DerivedOptimizerParamState& rhs_curr_state =
|
|
static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
|
|
ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename OptimizerClass,
|
|
typename DerivedOptimizerOptions,
|
|
typename DerivedOptimizerParamState>
|
|
void test_serialize_optimizer(
|
|
DerivedOptimizerOptions options,
|
|
bool only_has_global_state = false) {
|
|
torch::manual_seed(0);
|
|
auto model1 = Linear(5, 2);
|
|
auto model2 = Linear(5, 2);
|
|
auto model3 = Linear(5, 2);
|
|
|
|
// Models 1, 2, 3 will have the same parameters.
|
|
auto model_tempfile = c10::make_tempfile();
|
|
torch::save(model1, model_tempfile.name);
|
|
torch::load(model2, model_tempfile.name);
|
|
torch::load(model3, model_tempfile.name);
|
|
|
|
auto param1 = model1->named_parameters();
|
|
auto param2 = model2->named_parameters();
|
|
auto param3 = model3->named_parameters();
|
|
for (const auto& p : param1) {
|
|
ASSERT_TRUE(p->allclose(param2[p.key()]));
|
|
ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
|
|
}
|
|
// Make some optimizers
|
|
auto optim1 = OptimizerClass(
|
|
{torch::optim::OptimizerParamGroup(model1->parameters())}, options);
|
|
auto optim2 = OptimizerClass(model2->parameters(), options);
|
|
auto optim2_2 = OptimizerClass(model2->parameters(), options);
|
|
auto optim3 = OptimizerClass(model3->parameters(), options);
|
|
auto optim3_2 = OptimizerClass(model3->parameters(), options);
|
|
for (auto& param_group : optim3_2.param_groups()) {
|
|
const double lr = param_group.options().get_lr();
|
|
// change the learning rate, which will be overwritten by the loading
|
|
// otherwise, test cannot check if options are saved and loaded correctly
|
|
param_group.options().set_lr(lr + 0.01);
|
|
}
|
|
|
|
auto x = torch::ones({10, 5});
|
|
|
|
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
|
|
optimizer.zero_grad();
|
|
auto y = model->forward(x).sum();
|
|
y.backward();
|
|
auto closure = []() { return torch::tensor({10}); };
|
|
optimizer.step(closure);
|
|
};
|
|
|
|
// Do 2 steps of model1
|
|
step(optim1, model1);
|
|
step(optim1, model1);
|
|
|
|
// Do 2 steps of model 2 without saving the optimizer
|
|
step(optim2, model2);
|
|
step(optim2_2, model2);
|
|
|
|
// Do 1 step of model 3
|
|
step(optim3, model3);
|
|
|
|
// save the optimizer
|
|
auto optim_tempfile = c10::make_tempfile();
|
|
torch::save(optim3, optim_tempfile.name);
|
|
torch::load(optim3_2, optim_tempfile.name);
|
|
|
|
auto& optim3_2_param_groups = optim3_2.param_groups();
|
|
auto& optim3_param_groups = optim3.param_groups();
|
|
auto& optim3_2_state = optim3_2.state();
|
|
auto& optim3_state = optim3.state();
|
|
|
|
// optim3_2 and optim1 should have param_groups and state of size 1 and
|
|
// state_size respectively
|
|
ASSERT_TRUE(optim3_2_param_groups.size() == 1);
|
|
// state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
|
|
// global state
|
|
unsigned state_size = only_has_global_state ? 1 : 2;
|
|
ASSERT_TRUE(optim3_2_state.size() == state_size);
|
|
|
|
// optim3_2 and optim1 should have param_groups and state of same size
|
|
ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
|
|
ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());
|
|
|
|
// checking correctness of serialization logic for optimizer.param_groups_ and
|
|
// optimizer.state_
|
|
for (const auto i : c10::irange(optim3_2_param_groups.size())) {
|
|
is_optimizer_param_group_equal<DerivedOptimizerOptions>(
|
|
optim3_2_param_groups[i], optim3_param_groups[i]);
|
|
is_optimizer_state_equal<DerivedOptimizerParamState>(
|
|
optim3_2_state, optim3_state);
|
|
}
|
|
|
|
// Do step2 for model 3
|
|
step(optim3_2, model3);
|
|
|
|
param1 = model1->named_parameters();
|
|
param2 = model2->named_parameters();
|
|
param3 = model3->named_parameters();
|
|
for (const auto& p : param1) {
|
|
const auto& name = p.key();
|
|
// Model 1 and 3 should be the same
|
|
ASSERT_TRUE(
|
|
param1[name].norm().item<float>() == param3[name].norm().item<float>());
|
|
ASSERT_TRUE(
|
|
param1[name].norm().item<float>() != param2[name].norm().item<float>());
|
|
}
|
|
}
|
|
|
|
/// Utility function to save a value of `int64_t` type.
|
|
void write_int_value(
|
|
torch::serialize::OutputArchive& archive,
|
|
const std::string& key,
|
|
const int64_t& value) {
|
|
archive.write(key, c10::IValue(value));
|
|
}
|
|
// Utility function to save a vector of buffers.
|
|
template <typename BufferContainer>
|
|
void write_tensors_to_archive(
|
|
torch::serialize::OutputArchive& archive,
|
|
const std::string& key,
|
|
const BufferContainer& buffers) {
|
|
archive.write(
|
|
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
|
|
for (const auto index : c10::irange(buffers.size())) {
|
|
archive.write(
|
|
key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
|
|
}
|
|
}
|
|
|
|
// Utility function to save a vector of step buffers.
|
|
void write_step_buffers(
|
|
torch::serialize::OutputArchive& archive,
|
|
const std::string& key,
|
|
const std::vector<int64_t>& steps) {
|
|
std::vector<torch::Tensor> tensors;
|
|
tensors.reserve(steps.size());
|
|
for (const auto& step : steps) {
|
|
tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
|
|
}
|
|
write_tensors_to_archive(archive, key, tensors);
|
|
}
|
|
|
|
// Calls `funcname(optimizer, filename)` (e.g. torch::load) while capturing
// warnings, and asserts that "old serialization" occurs exactly once in the
// captured warning text. The body is wrapped in do { ... } while (0) so the
// macro expands to a single statement that needs a trailing semicolon and is
// safe inside unbraced if/else.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  do {                                                                      \
    WarningCapture warnings;                                                \
    funcname(optimizer, filename);                                          \
    ASSERT_EQ(                                                              \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);  \
  } while (0)
|
|
|
|
TEST(SerializeTest, KeysFunc) {
  // Write three integer entries to a file, reload it, and check that keys()
  // lists them in the order they were written.
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    const std::string entry_key = "element/" + std::to_string(i);
    writer.write(entry_key, c10::IValue(static_cast<int64_t>(i)));
  }
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  const std::vector<std::string> names = reader.keys();
  ASSERT_EQ(names.size(), 3);
  for (const auto idx : c10::irange(names.size())) {
    ASSERT_EQ(names[idx], "element/" + std::to_string(idx));
  }
}
|
|
|
|
TEST(SerializeTest, TryReadFunc) {
  // try_read() returns false for an absent key and fills the IValue for a
  // present one.
  auto tmp = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    const std::string entry_key = "element/" + std::to_string(i);
    writer.write(entry_key, c10::IValue(static_cast<int64_t>(i)));
  }
  writer.save_to(tmp.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tmp.name);

  c10::IValue slot;
  ASSERT_FALSE(reader.try_read("1", slot));
  ASSERT_TRUE(reader.try_read("element/1", slot));
  ASSERT_EQ(slot.toInt(), 1);
}
|
|
|
|
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  // A plain random float tensor must survive an in-memory round trip.
  const auto original = torch::randn({5, 5});
  const auto restored = save_and_load(original);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
|
|
|
|
TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  const auto complex_opts = torch::TensorOptions{}.dtype(torch::kComplexFloat);
  const auto base = torch::randn({5, 5}, complex_opts);

  // Conjugated and negative views (and their combination) must round-trip
  // with their logical values intact.
  const std::vector<torch::Tensor> cases = {
      torch::conj(base),
      torch::_neg_view(base),
      torch::conj(torch::_neg_view(base))};
  for (const auto& expected : cases) {
    const auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public facing yet.
    // If in future, `ZeroTensor` serialization is supported, this test should
    // start failing!
    const auto zeros = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(save_and_load(zeros), "ZeroTensor is not serializable,");
  }
}
|
|
|
|
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  // Same round trip as Basic, but through a temporary file on disk.
  auto file = c10::make_tempfile();
  torch::save(x, file.name);

  torch::Tensor loaded;
  torch::load(loaded, file.name);

  ASSERT_TRUE(loaded.defined());
  ASSERT_EQ(x.sizes().vec(), loaded.sizes().vec());
  ASSERT_TRUE(x.allclose(loaded));
}
|
|
|
|
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  // Save through a writer callback that appends every chunk to a string.
  std::string serialized;
  auto append_chunk = [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  };
  torch::save(x, append_chunk);

  // Load back from a raw (pointer, size) view of the serialized bytes.
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  // Load again through positional-read and total-size callbacks. A read at or
  // beyond the end yields 0 bytes; otherwise the read is clamped to the end.
  auto read_chunk = [&](uint64_t pos, void* buf, size_t n) -> size_t {
    if (pos >= serialized.size()) {
      return 0;
    }
    const size_t available =
        std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
    memcpy(buf, serialized.data() + pos, available);
    return available;
  };
  auto total_size = [&]() -> size_t { return serialized.size(); };

  torch::Tensor z;
  torch::load(z, read_chunk, total_size);
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
|
|
|
|
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  // Shrink an 11x5 tensor in place to 5x5; only the resized view should be
  // saved and restored.
  auto shrunk = torch::randn({11, 5});
  shrunk.resize_({5, 5});
  auto restored = save_and_load(shrunk);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(shrunk.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(shrunk.allclose(restored));
}
|
|
|
|
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  // Rows 1..4 of an 11x5 tensor: a view with a non-zero storage offset.
  auto full = torch::randn({11, 5});
  auto rows = full.slice(0, 1, 5);
  auto restored = save_and_load(rows);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(rows.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(rows.allclose(restored));
}
|
|
|
|
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  // Columns 1..3 of an 11x5 tensor: a non-contiguous view.
  auto full = torch::randn({11, 5});
  auto columns = full.slice(1, 1, 4);
  auto restored = save_and_load(columns);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(columns.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(columns.allclose(restored));
}
|
|
|
|
TEST(SerializeTest, ErrorOnMissingKey) {
  // Three-level hierarchy: Root registers a Middle submodule, Middle registers
  // a Leaf submodule, and Leaf registers a single 5-element buffer.
  struct Leaf : torch::nn::Module {
    explicit Leaf(const std::string& buffer_name) {
      register_buffer(buffer_name, torch::ones(5, torch::kFloat));
    }
  };
  struct Middle : torch::nn::Module {
    Middle(const std::string& leaf_name, const std::string& buffer_name) {
      register_module(leaf_name, std::make_shared<Leaf>(buffer_name));
    }
  };
  struct Root : torch::nn::Module {
    Root(const std::string& middle_name,
         const std::string& leaf_name,
         const std::string& buffer_name) {
      register_module(
          middle_name, std::make_shared<Middle>(leaf_name, buffer_name));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto saved = std::make_shared<Root>("a", "b", "c");
  auto wrong_buffer = std::make_shared<Root>("a", "b", "x");
  auto wrong_submodule = std::make_shared<Root>("a", "x", "c");

  std::stringstream stream;
  torch::save(saved, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(wrong_buffer, stream), "No such serialized tensor 'a.b.x'");
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(wrong_submodule, stream),
      "No such serialized submodule: 'a.x'");
}
|
|
|
|
TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  // Draws a random batch of 2-bit inputs, labels each row with the XOR of its
  // two bits, and returns the model's binary cross-entropy on that batch.
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto row : c10::irange(batch_size)) {
      inputs[row] = torch::randint(2, {2}, torch::kInt64);
      const int64_t first_bit = inputs[row][0].item<int64_t>();
      const int64_t second_bit = inputs[row][1].item<int64_t>();
      labels[row] = first_bit ^ second_bit;
    }
    auto predictions = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(predictions, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  // NOTE(review): model3 is never used below; it is kept because constructing
  // it may advance the global RNG that the training loop depends on.
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train until the exponentially smoothed loss drops below 0.1, failing if
  // that takes 3000 or more iterations.
  float running_loss = 1;
  for (int epoch = 0; running_loss > 0.1; ++epoch) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
  }

  // The reloaded copy must solve XOR just as well as the trained model.
  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto reloaded_loss = getLoss(model2, 100);
  ASSERT_LT(reloaded_loss.item<float>(), 0.1);
}
|
|
|
|
TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Copy model1's weights into model2 and model3 through a file so that all
  // three models start from identical parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& key = p.key();
    ASSERT_TRUE(p->allclose(param2[key]));
    ASSERT_TRUE(param2[key].allclose(param3[key]));
  }

  // SGD with momentum keeps per-parameter state, which the save/load below
  // has to preserve.
  const auto sgd_options = torch::optim::SGDOptions(1e-1).momentum(0.9);
  auto optim1 = torch::optim::SGD(model1->parameters(), sgd_options);
  auto optim2 = torch::optim::SGD(model2->parameters(), sgd_options);
  auto optim2_2 = torch::optim::SGD(model2->parameters(), sgd_options);
  auto optim3 = torch::optim::SGD(model3->parameters(), sgd_options);
  auto optim3_2 = torch::optim::SGD(model3->parameters(), sgd_options);

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // model1: two steps with a single optimizer (reference).
  step(optim1, model1);
  step(optim1, model1);

  // model2: two steps with two different optimizers, so momentum is lost.
  step(optim2, model2);
  step(optim2_2, model2);

  // model3: one step, hand the optimizer state over via a file, one more step.
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  auto after1 = model1->named_parameters();
  auto after2 = model2->named_parameters();
  auto after3 = model3->named_parameters();
  for (const auto& p : after1) {
    const auto& name = p.key();
    // model3 must match model1 exactly, while model2 must differ.
    ASSERT_TRUE(
        after1[name].norm().item<float>() == after3[name].norm().item<float>());
    ASSERT_TRUE(
        after1[name].norm().item<float>() != after2[name].norm().item<float>());
  }
}
|
|
|
|
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // bc compatibility check: hand-write an archive in the old flat layout
  // (sum_buffers / step_buffers lists), load it with a warning, and compare the
  // reconstructed state with the live optimizer's state.
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // Collect optim1's per-parameter sum and step values, in parameter order.
  std::vector<torch::Tensor> sum_buffers;
  std::vector<int64_t> step_buffers;
  const auto& group_params = optim1.param_groups()[0].params();
  const auto& live_state = optim1.state();
  for (const auto& param : group_params) {
    const auto& adagrad_state = static_cast<const AdagradParamState&>(
        *live_state.at(param.unsafeGetTensorImpl()));
    sum_buffers.emplace_back(adagrad_state.sum());
    step_buffers.emplace_back(adagrad_state.step());
  }

  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);

  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // bc compatibility check: emulate the old archive layout (a flat list of
  // momentum buffers plus an `iteration_` counter) and load it back.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor outside model1's forward pass, so it ends up without a
  // momentum buffer entry (lazy init check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  const int64_t iteration_ = 0;
  const auto& group_params = optim1.param_groups()[0].params();
  const auto& live_state = optim1.state();
  for (const auto i : c10::irange(group_params.size())) {
    if (i + 1 == group_params.size()) {
      continue; // the trailing extra tensor has no state entry
    }
    const auto& sgd_state = static_cast<const SGDParamState&>(
        *live_state.at(group_params[i].unsafeGetTensorImpl()));
    momentum_buffers.emplace_back(sgd_state.momentum_buffer());
  }
  ASSERT_TRUE(momentum_buffers.size() == (group_params.size() - 1));

  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // bc compatibility check: rebuild the old flat-buffer archive layout from a
  // live Adam optimizer and verify that loading it reproduces the state.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor outside model1's forward pass, so it has no entry in the
  // buffers (lazy init check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& group_params = optim1.param_groups()[0].params();
  const auto& live_state = optim1.state();
  for (const auto i : c10::irange(group_params.size())) {
    if (i + 1 == group_params.size()) {
      continue; // the trailing extra tensor has no state entry
    }
    const auto& adam_state = static_cast<const AdamParamState&>(
        *live_state.at(group_params[i].unsafeGetTensorImpl()));
    step_buffers.emplace_back(adam_state.step());
    exp_average_buffers.emplace_back(adam_state.exp_avg());
    exp_average_sq_buffers.emplace_back(adam_state.exp_avg_sq());
    // max_exp_avg_sq may be undefined; only defined ones are written.
    if (adam_state.max_exp_avg_sq().defined()) {
      max_exp_average_sq_buffers.emplace_back(adam_state.max_exp_avg_sq());
    }
  }

  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // bc compatibility check: rebuild the old flat-buffer archive layout from a
  // live AdamW optimizer and verify that loading it reproduces the state.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Extra tensor outside model1's forward pass, so it has no entry in the
  // buffers (lazy init check).
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& group_params = optim1.param_groups()[0].params();
  const auto& live_state = optim1.state();
  for (const auto i : c10::irange(group_params.size())) {
    if (i + 1 == group_params.size()) {
      continue; // the trailing extra tensor has no state entry
    }
    const auto& adamw_state = static_cast<const AdamWParamState&>(
        *live_state.at(group_params[i].unsafeGetTensorImpl()));
    step_buffers.emplace_back(adamw_state.step());
    exp_average_buffers.emplace_back(adamw_state.exp_avg());
    exp_average_sq_buffers.emplace_back(adamw_state.exp_avg_sq());
    // max_exp_avg_sq may be undefined; only defined ones are written.
    if (adamw_state.max_exp_avg_sq().defined()) {
      max_exp_average_sq_buffers.emplace_back(adamw_state.max_exp_avg_sq());
    }
  }

  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, Optim_RMSprop) {
  auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
  test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);

  // bc compatibility check: write the old flat-buffer layout from a live
  // RMSprop optimizer, load it into a fresh one, and compare states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();

  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::RMSprop(model1_params, options);

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> square_average_buffers;
  std::vector<at::Tensor> momentum_buffers;
  std::vector<at::Tensor> grad_average_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      square_average_buffers.emplace_back(curr_state_.square_avg());
      if (curr_state_.momentum_buffer().defined()) {
        momentum_buffers.emplace_back(curr_state_.momentum_buffer());
      }
      if (curr_state_.grad_avg().defined()) {
        grad_average_buffers.emplace_back(curr_state_.grad_avg());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "square_average_buffers", square_average_buffers);
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_tensors_to_archive(
      output_archive, "grad_average_buffers", grad_average_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = RMSprop(model1_params, options);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto& optim1_2_state = optim1_2.state();
  // old RMSprop didn't track step value; copy it over from optim1 so that the
  // full states can be compared below.
  for (const auto i : c10::irange(params1_2_.size())) {
    if (i != (params1_2_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      auto key1_2_ = params1_2_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      // Look up optim1_2's state with optim1_2's own parameter key (the
      // original code used optim1's key here and left key1_2_ unused).
      RMSpropParamState& curr_state1_2_ =
          static_cast<RMSpropParamState&>(*(optim1_2_state.at(key1_2_).get()));
      curr_state1_2_.step(curr_state_.step());
    }
  }
  is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, Optim_LBFGS) {
  test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
      LBFGSOptions(), true);
  // bc compatibility check: write the old archive layout (individual tensors
  // plus old_dirs / old_stps lists) from a live LBFGS optimizer, load it, and
  // compare states.
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 =
      torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  step(optim1, model1);

  // LBFGS keeps a single global state entry, looked up here under the first
  // parameter's key.
  const auto& params_ = optim1.param_groups()[0].params();
  auto key_ = params_[0].unsafeGetTensorImpl();
  const auto& optim1_state =
      static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));

  at::Tensor d = optim1_state.d();
  at::Tensor t = at::tensor(optim1_state.t());
  at::Tensor H_diag = optim1_state.H_diag();
  at::Tensor prev_flat_grad = optim1_state.prev_flat_grad();
  at::Tensor prev_loss = at::tensor(optim1_state.prev_loss());
  std::deque<at::Tensor> old_dirs = optim1_state.old_dirs();
  // Read old_stps from the state as well; previously it was left empty, so the
  // archive never reflected optim1's step history.
  std::deque<at::Tensor> old_stps = optim1_state.old_stps();

  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("d", d, /*is_buffer=*/true);
  output_archive.write("t", t, /*is_buffer=*/true);
  output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
  output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
  output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
  write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
  write_tensors_to_archive(output_archive, "old_stps", old_stps);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);

  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto param_key = params1_2_[0].unsafeGetTensorImpl();
  auto& optim1_2_state =
      static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));

  // old LBFGS didn't track func_evals, n_iter, ro, al values
  optim1_2_state.func_evals(optim1_state.func_evals());
  optim1_2_state.n_iter(optim1_state.n_iter());
  optim1_2_state.ro(optim1_state.ro());
  optim1_2_state.al(optim1_state.al());

  is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
}
|
|
|
|
TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!

  // Draws `batch_size` random 2-bit rows, labels each row with the XOR of its
  // two bits, and returns the binary cross-entropy of `net` on that batch.
  // When `is_cuda` is true, the batch and its labels are moved to CUDA first.
  auto compute_loss = [](Sequential net,
                         uint32_t batch_size,
                         bool is_cuda = false) {
    auto batch = torch::empty({batch_size, 2});
    auto targets = torch::empty({batch_size});
    if (is_cuda) {
      batch = batch.cuda();
      targets = targets.cuda();
    }
    for (const auto row : c10::irange(batch_size)) {
      batch[row] = torch::randint(2, {2}, torch::kInt64);
      const auto first_bit = batch[row][0].item<int64_t>();
      const auto second_bit = batch[row][1].item<int64_t>();
      targets[row] = first_bit ^ second_bit;
    }
    auto predictions = net->forward<torch::Tensor>(batch);
    return torch::binary_cross_entropy(predictions, targets);
  };

  // Construction order matters: each xor_model() draws from the seeded RNG.
  auto trained = xor_model();
  auto reloaded = xor_model();
  auto reloaded_again = xor_model();
  auto optimizer = torch::optim::SGD(
      trained->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train until an exponential moving average of the loss drops below 0.1,
  // failing the test if that takes too many iterations.
  float smoothed_loss = 1;
  for (int epoch = 0; smoothed_loss > 0.1; ++epoch) {
    torch::Tensor batch_loss = compute_loss(trained, 4);
    optimizer.zero_grad();
    batch_loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    smoothed_loss = smoothed_loss * 0.99 + batch_loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
  }

  // First round trip: save the trained model, load it into a freshly
  // constructed one, and check the loaded weights still solve XOR.
  auto first_file = c10::make_tempfile();
  torch::save(trained, first_file.name);
  torch::load(reloaded, first_file.name);

  auto loss = compute_loss(reloaded, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  // The loaded model must still work after being moved to CUDA.
  reloaded->to(torch::kCUDA);
  loss = compute_loss(reloaded, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  // Second round trip: save the CUDA-resident model and load it into another
  // fresh model, which is then evaluated on CUDA inputs.
  auto second_file = c10::make_tempfile();
  torch::save(reloaded, second_file.name);
  torch::load(reloaded_again, second_file.name);

  loss = compute_loss(reloaded_again, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}
|
|
|
|
TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  // Leaf module carrying the only state in the hierarchy: an int32 buffer
  // "foo" filled with five ones.
  struct WithBuffer : torch::nn::Module {
    WithBuffer() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  // Leaf module with no parameters, buffers, or children.
  struct Stateless : torch::nn::Module {};
  // Intermediate module: has children ("b", "c") but no state of its own.
  struct Middle : torch::nn::Module {
    Middle() {
      register_module("b", std::make_shared<Stateless>());
      register_module("c", std::make_shared<WithBuffer>());
    }
  };
  // Root module, reaching the buffer only through "a" -> "c".
  struct Root : torch::nn::Module {
    Root() {
      register_module("a", std::make_shared<Middle>());
    }
  };

  auto saved = std::make_shared<Root>();
  std::stringstream archive_stream;
  torch::save(saved, archive_stream);

  auto loaded = std::make_shared<Root>();
  torch::load(loaded, archive_stream);

  // The nested buffer must survive the round trip despite the stateless
  // intermediate modules: five ones sum to 5.
  auto loaded_buffers = loaded->named_buffers();
  const int buffer_sum = loaded_buffers["a.c.foo"].sum().item<int>();
  ASSERT_EQ(buffer_sum, 5);
}
|
|
|
|
TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  // Two tensors of different shapes, so a shape mix-up on load is detectable.
  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  // The element-wise loop below indexes y_vec with x_vec's indices; check the
  // lengths first so a short result fails the test instead of reading out of
  // bounds (undefined behavior).
  ASSERT_EQ(y_vec.size(), x_vec.size());

  for (const auto i : c10::irange(x_vec.size())) {
    const auto& x = x_vec[i];
    const auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}
|
|
|
|
TEST(SerializeTest, IValue) {
  // Write a single integer IValue under the key "value" to a temporary file.
  const c10::IValue original(1);
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive writer;
  writer.write("value", original);
  writer.save_to(tempfile.name);

  // Read it back from the same file and check the integer is preserved.
  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  c10::IValue restored;
  reader.read("value", restored);
  ASSERT_EQ(restored.toInt(), 1);

  // Reading a key that was never written must throw a descriptive error.
  ASSERT_THROWS_WITH(
      reader.read("bad_key", restored), "does not have a field with name");
}
|
|
|
|
// NOTE: if a `Module` contains unserializable submodules (e.g.
|
|
// `nn::Functional`), we expect those submodules to be skipped when the `Module`
|
|
// is being serialized.
|
|
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  // Module whose only child, "relu", is an nn::Functional.
  struct WithFunctionalChild : torch::nn::Module {
    WithFunctionalChild() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  auto saved = std::make_shared<WithFunctionalChild>();
  std::stringstream archive_stream;
  torch::save(saved, archive_stream);

  torch::serialize::InputArchive root_archive;
  root_archive.load_from(archive_stream);

  // nn::Functional is not serializable, so saving must not have produced a
  // submodule archive under the name "relu".
  torch::serialize::InputArchive functional_archive;
  ASSERT_FALSE(root_archive.try_read("relu", functional_archive));
}
|
|
|
|
// NOTE: If a `Module` contains unserializable submodules (e.g.
|
|
// `nn::Functional`), we don't check the existence of those submodules in the
|
|
// `InputArchive` when deserializing.
|
|
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  // Inner module: a Functional child "relu1" plus an int32 buffer "foo" that
  // starts out as five zeros.
  struct Inner : torch::nn::Module {
    Inner() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  // Outer module: the serializable Inner under "b" and a Functional "relu2".
  struct Outer : torch::nn::Module {
    Outer() {
      register_module("b", std::make_shared<Inner>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto saved = std::make_shared<Outer>();
  // Overwrite "b.foo" with ones so the reloaded buffer can be told apart from
  // the zeros a freshly constructed module starts with.
  saved->named_buffers()["b.foo"].fill_(1);
  auto tempfile = c10::make_tempfile();
  torch::save(saved, tempfile.name);

  // Inspect the raw archive: "b" and its "foo" buffer are present...
  torch::serialize::InputArchive root_archive;
  root_archive.load_from(tempfile.name);
  torch::serialize::InputArchive inner_archive;
  torch::serialize::InputArchive functional_archive;
  torch::Tensor foo_tensor;

  ASSERT_TRUE(root_archive.try_read("b", inner_archive));
  ASSERT_TRUE(inner_archive.try_read("foo", foo_tensor, /*is_buffer=*/true));

  // ...but neither nn::Functional child was written, at either nesting level.
  ASSERT_FALSE(inner_archive.try_read("relu1", functional_archive));
  ASSERT_FALSE(root_archive.try_read("relu2", functional_archive));

  // Loading into a fresh Outer must succeed even though the module declares
  // Functional children that are absent from the file.
  auto loaded = std::make_shared<Outer>();
  torch::load(loaded, tempfile.name);

  // The buffer must hold the five ones written before saving, so it sums to 5.
  const int buffer_sum = loaded->named_buffers()["b.foo"].sum().item<int>();
  ASSERT_EQ(buffer_sum, 5);
}
|