InstanceNorm{1,2,3}d (#28790)

Summary:
Hi yf225,

I have a few doubts related to the implementation:
1) What tests do I have to write?
2) What does `_load_state_from_dict` do?
3) Do I need to override the reset() function? I cannot see its utility.
4) InstanceNormOptions could be replaced with BatchNormOptions, but I find that
`track_running_stats` is not defined there; instead, `stateful` is defined.

InstanceNorm{1,2,3}d https://github.com/pytorch/pytorch/issues/25883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28790

Differential Revision: D18588666

Pulled By: yf225

fbshipit-source-id: bb9b81f01f62c3fc8765fa0ba0716768087ee155
Authored by Divyansh Singhvi on 2019-11-19 16:53:37 -08:00; committed by Facebook GitHub Bot.
parent 8e3486de81
commit ec52d911bd
15 changed files with 880 additions and 115 deletions

@@ -566,6 +566,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/distance.cpp
@@ -583,11 +584,11 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp

@@ -1579,6 +1579,236 @@ TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
ASSERT_TRUE(output.allclose(expected));
}
TEST_F(FunctionalTest, InstanceNorm1d) {
int num_features = 5;
double eps = 1e-05;
double momentum = 0.1;
auto input = torch::arange(40.).view({2, 5, 4});
auto mean = torch::arange(5.);
auto variance = torch::arange(5.);
auto weight = torch::arange((double)num_features);
auto bias = torch::arange((double)num_features);
auto output = F::instance_norm(
input,
F::InstanceNormFuncOptions()
.running_mean(mean)
.running_var(variance)
.weight(weight)
.bias(bias)
.momentum(momentum)
.eps(eps));
auto expected = torch::tensor({{{ 0.0000, 0.0000, 0.0000, 0.0000},
{-0.3416, 0.5528, 1.4472, 2.3416},
{-0.6833, 1.1056, 2.8944, 4.6833},
{-1.0249, 1.6584, 4.3416, 7.0249},
{-1.3665, 2.2112, 5.7888, 9.3665}},
{{ 0.0000, 0.0000, 0.0000, 0.0000},
{-0.3416, 0.5528, 1.4472, 2.3416},
{-0.6833, 1.1056, 2.8944, 4.6833},
{-1.0249, 1.6584, 4.3416, 7.0249},
{-1.3665, 2.2112, 5.7888, 9.3665}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
auto input = torch::arange(40.).view({2, 5, 4});
auto output = F::instance_norm(input);
auto expected = torch::tensor({{{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416}},
{{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416},
{-1.3416, -0.4472, 0.4472, 1.3416}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, InstanceNorm2d) {
int num_features = 5;
double eps = 1e-05;
double momentum = 0.1;
auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
auto mean = torch::arange((double)num_features);
auto variance = torch::arange((double)num_features);
auto weight = torch::arange((double)num_features);
auto bias = torch::arange((double)num_features);
auto output = F::instance_norm(
input,
F::InstanceNormFuncOptions()
.running_mean(mean)
.running_var(variance)
.weight(weight)
.bias(bias)
.momentum(momentum)
.eps(eps));
auto expected = torch::tensor({{{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}},
{{-0.3416, 0.5528},
{ 1.4472, 2.3416}},
{{-0.6833, 1.1056},
{ 2.8944, 4.6833}},
{{-1.0249, 1.6584},
{ 4.3416, 7.0249}},
{{-1.3665, 2.2112},
{ 5.7888, 9.3665}}},
{{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}},
{{-0.3416, 0.5528},
{ 1.4472, 2.3416}},
{{-0.6833, 1.1056},
{ 2.8944, 4.6833}},
{{-1.0249, 1.6584},
{ 4.3416, 7.0249}},
{{-1.3665, 2.2112},
{ 5.7888, 9.3665}}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
int num_features = 5;
double eps = 1e-05;
auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
auto output = F::instance_norm(input);
auto expected = torch::tensor({{{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}}},
{{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, InstanceNorm3d) {
int num_features = 5;
double eps = 1e-05;
double momentum = 0.1;
auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
auto mean = torch::arange((double)num_features);
auto variance = torch::arange((double)num_features);
auto weight = torch::arange((double)num_features);
auto bias = torch::arange((double)num_features);
auto output = F::instance_norm(
input,
F::InstanceNormFuncOptions()
.running_mean(mean)
.running_var(variance)
.weight(weight)
.bias(bias)
.momentum(momentum)
.eps(eps));
auto expected = torch::tensor({{{{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}},
{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}}},
{{{-0.5275, -0.0911},
{ 0.3453, 0.7818}},
{{ 1.2182, 1.6547},
{ 2.0911, 2.5275}}},
{{{-1.0550, -0.1822},
{ 0.6907, 1.5636}},
{{ 2.4364, 3.3093},
{ 4.1822, 5.0550}}},
{{{-1.5826, -0.2733},
{ 1.0360, 2.3453}},
{{ 3.6547, 4.9640},
{ 6.2733, 7.5826}}},
{{{-2.1101, -0.3644},
{ 1.3814, 3.1271}},
{{ 4.8729, 6.6186},
{ 8.3644, 10.1101}}}},
{{{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}},
{{ 0.0000, 0.0000},
{ 0.0000, 0.0000}}},
{{{-0.5275, -0.0911},
{ 0.3453, 0.7818}},
{{ 1.2182, 1.6547},
{ 2.0911, 2.5275}}},
{{{-1.0550, -0.1822},
{ 0.6907, 1.5636}},
{{ 2.4364, 3.3093},
{ 4.1822, 5.0550}}},
{{{-1.5826, -0.2733},
{ 1.0360, 2.3453}},
{{ 3.6547, 4.9640},
{ 6.2733, 7.5826}}},
{{{-2.1101, -0.3644},
{ 1.3814, 3.1271}},
{{ 4.8729, 6.6186},
{ 8.3644, 10.1101}}}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
int num_features = 5;
double eps = 1e-05;
auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
auto output = F::instance_norm(input);
auto expected = torch::tensor({{{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}}},
{{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}}}});
ASSERT_TRUE(output.allclose(expected, 2e-04));
}
TEST_F(FunctionalTest, Interpolate) {
{
// 1D interpolation

@@ -1427,7 +1427,7 @@ TEST_F(ModulesTest, BatchNormLegacyWarning) {
}
TEST_F(ModulesTest, BatchNorm1dStateful) {
BatchNorm1d bn(BatchNorm1dOptions(5));
BatchNorm1d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1464,20 +1464,30 @@ TEST_F(ModulesTest, BatchNorm1dStateless) {
}
TEST_F(ModulesTest, BatchNorm1d) {
BatchNorm1d bn(BatchNorm1dOptions(5));
BatchNorm1d bn(5);
bn->eval();
auto input = torch::randn({2, 5}, torch::requires_grad());
auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
auto output = bn->forward(input);
auto expected = torch::tensor({{{ 0.0000, 1.0000},
{ 2.0000, 3.0000},
{ 4.0000, 5.0000},
{ 6.0000, 7.0000},
{ 8.0000, 9.0000}},
{{10.0000, 10.9999},
{11.9999, 12.9999},
{13.9999, 14.9999},
{15.9999, 16.9999},
{17.9999, 18.9999}}});
ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
}
TEST_F(ModulesTest, BatchNorm2dStateful) {
BatchNorm2d bn(BatchNorm2dOptions(5));
BatchNorm2d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1514,20 +1524,40 @@ TEST_F(ModulesTest, BatchNorm2dStateless) {
}
TEST_F(ModulesTest, BatchNorm2d) {
BatchNorm2d bn(BatchNorm2dOptions(5));
BatchNorm2d bn(5);
bn->eval();
auto input = torch::randn({2, 5, 4, 4}, torch::requires_grad());
auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
auto output = bn->forward(input);
auto expected = torch::tensor({{{{ 0.0000, 1.0000},
{ 2.0000, 3.0000}},
{{ 4.0000, 5.0000},
{ 6.0000, 7.0000}},
{{ 8.0000, 9.0000},
{10.0000, 10.9999}},
{{11.9999, 12.9999},
{13.9999, 14.9999}},
{{15.9999, 16.9999},
{17.9999, 18.9999}}},
{{{19.9999, 20.9999},
{21.9999, 22.9999}},
{{23.9999, 24.9999},
{25.9999, 26.9999}},
{{27.9999, 28.9999},
{29.9998, 30.9998}},
{{31.9998, 32.9998},
{33.9998, 34.9998}},
{{35.9998, 36.9998},
{37.9998, 38.9998}}}});
ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4})));
}
TEST_F(ModulesTest, BatchNorm3dStateful) {
BatchNorm3d bn(BatchNorm3dOptions(5));
BatchNorm3d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1564,16 +1594,276 @@ TEST_F(ModulesTest, BatchNorm3dStateless) {
}
TEST_F(ModulesTest, BatchNorm3d) {
BatchNorm3d bn(BatchNorm3dOptions(5));
BatchNorm3d bn(5);
bn->eval();
auto input = torch::randn({2, 5, 4, 4, 4}, torch::requires_grad());
auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
auto output = bn->forward(input);
auto expected = torch::tensor({{{{{ 0.0000, 1.0000},
{ 2.0000, 3.0000}},
{{ 4.0000, 5.0000},
{ 6.0000, 7.0000}}},
{{{ 8.0000, 9.0000},
{10.0000, 10.9999}},
{{11.9999, 12.9999},
{13.9999, 14.9999}}},
{{{15.9999, 16.9999},
{17.9999, 18.9999}},
{{19.9999, 20.9999},
{21.9999, 22.9999}}},
{{{23.9999, 24.9999},
{25.9999, 26.9999}},
{{27.9999, 28.9999},
{29.9998, 30.9998}}},
{{{31.9998, 32.9998},
{33.9998, 34.9998}},
{{35.9998, 36.9998},
{37.9998, 38.9998}}}},
{{{{39.9998, 40.9998},
{41.9998, 42.9998}},
{{43.9998, 44.9998},
{45.9998, 46.9998}}},
{{{47.9998, 48.9998},
{49.9997, 50.9997}},
{{51.9997, 52.9997},
{53.9997, 54.9997}}},
{{{55.9997, 56.9997},
{57.9997, 58.9997}},
{{59.9997, 60.9997},
{61.9997, 62.9997}}},
{{{63.9997, 64.9997},
{65.9997, 66.9997}},
{{67.9997, 68.9997},
{69.9996, 70.9996}}},
{{{71.9996, 72.9996},
{73.9996, 74.9996}},
{{75.9996, 76.9996},
{77.9996, 78.9996}}}}});
ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4, 4})));
}
TEST_F(ModulesTest, InstanceNorm1dStateful) {
InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
ASSERT_TRUE(instance_norm->options.track_running_stats());
ASSERT_TRUE(instance_norm->running_mean.defined());
ASSERT_EQ(instance_norm->running_mean.dim(), 1);
ASSERT_EQ(instance_norm->running_mean.size(0), 5);
ASSERT_TRUE(instance_norm->running_var.defined());
ASSERT_EQ(instance_norm->running_var.dim(), 1);
ASSERT_EQ(instance_norm->running_var.size(0), 5);
ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
ASSERT_TRUE(instance_norm->options.affine());
ASSERT_TRUE(instance_norm->weight.defined());
ASSERT_EQ(instance_norm->weight.dim(), 1);
ASSERT_EQ(instance_norm->weight.size(0), 5);
ASSERT_TRUE(instance_norm->bias.defined());
ASSERT_EQ(instance_norm->bias.dim(), 1);
ASSERT_EQ(instance_norm->bias.size(0), 5);
}
TEST_F(ModulesTest, InstanceNorm1dStateless) {
InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
ASSERT_FALSE(instance_norm->running_mean.defined());
ASSERT_FALSE(instance_norm->running_var.defined());
ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
ASSERT_FALSE(instance_norm->weight.defined());
ASSERT_FALSE(instance_norm->bias.defined());
}
TEST_F(ModulesTest, InstanceNorm1d) {
InstanceNorm1d instance_norm(5);
instance_norm->eval();
auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
auto output = instance_norm->forward(input);
auto expected = torch::tensor({{{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000}},
{{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000},
{-1.0000, 1.0000}}});
ASSERT_TRUE(output.allclose(expected, 1e-3));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
}
TEST_F(ModulesTest, InstanceNorm2dStateful) {
InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
ASSERT_TRUE(instance_norm->options.track_running_stats());
ASSERT_TRUE(instance_norm->running_mean.defined());
ASSERT_EQ(instance_norm->running_mean.dim(), 1);
ASSERT_EQ(instance_norm->running_mean.size(0), 5);
ASSERT_TRUE(instance_norm->running_var.defined());
ASSERT_EQ(instance_norm->running_var.dim(), 1);
ASSERT_EQ(instance_norm->running_var.size(0), 5);
ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
ASSERT_TRUE(instance_norm->options.affine());
ASSERT_TRUE(instance_norm->weight.defined());
ASSERT_EQ(instance_norm->weight.dim(), 1);
ASSERT_EQ(instance_norm->weight.size(0), 5);
ASSERT_TRUE(instance_norm->bias.defined());
ASSERT_EQ(instance_norm->bias.dim(), 1);
ASSERT_EQ(instance_norm->bias.size(0), 5);
}
TEST_F(ModulesTest, InstanceNorm2dStateless) {
InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
ASSERT_FALSE(instance_norm->running_mean.defined());
ASSERT_FALSE(instance_norm->running_var.defined());
ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
ASSERT_FALSE(instance_norm->weight.defined());
ASSERT_FALSE(instance_norm->bias.defined());
}
TEST_F(ModulesTest, InstanceNorm2d) {
InstanceNorm2d instance_norm(5);
instance_norm->eval();
auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
auto output = instance_norm->forward(input);
auto expected = torch::tensor({{{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}}},
{{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}},
{{-1.3416, -0.4472},
{ 0.4472, 1.3416}}}});
ASSERT_TRUE(output.allclose(expected, 1e-3));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
}
TEST_F(ModulesTest, InstanceNorm3dStateful) {
InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
ASSERT_TRUE(instance_norm->options.track_running_stats());
ASSERT_TRUE(instance_norm->running_mean.defined());
ASSERT_EQ(instance_norm->running_mean.dim(), 1);
ASSERT_EQ(instance_norm->running_mean.size(0), 5);
ASSERT_TRUE(instance_norm->running_var.defined());
ASSERT_EQ(instance_norm->running_var.dim(), 1);
ASSERT_EQ(instance_norm->running_var.size(0), 5);
ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
ASSERT_TRUE(instance_norm->options.affine());
ASSERT_TRUE(instance_norm->weight.defined());
ASSERT_EQ(instance_norm->weight.dim(), 1);
ASSERT_EQ(instance_norm->weight.size(0), 5);
ASSERT_TRUE(instance_norm->bias.defined());
ASSERT_EQ(instance_norm->bias.dim(), 1);
ASSERT_EQ(instance_norm->bias.size(0), 5);
}
TEST_F(ModulesTest, InstanceNorm3dStateless) {
InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
ASSERT_FALSE(instance_norm->running_mean.defined());
ASSERT_FALSE(instance_norm->running_var.defined());
ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
ASSERT_FALSE(instance_norm->weight.defined());
ASSERT_FALSE(instance_norm->bias.defined());
}
TEST_F(ModulesTest, InstanceNorm3d) {
InstanceNorm3d instance_norm(5);
instance_norm->eval();
auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
auto output = instance_norm->forward(input);
auto expected = torch::tensor({{{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}}},
{{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}},
{{{-1.5275, -1.0911},
{-0.6547, -0.2182}},
{{ 0.2182, 0.6547},
{ 1.0911, 1.5275}}}}});
ASSERT_TRUE(output.allclose(expected, 1e-3));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
}
TEST_F(ModulesTest, Linear_CUDA) {
@@ -3178,6 +3468,30 @@ TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
"torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
ASSERT_EQ(
c10::str(InstanceNorm1d(
InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
.track_running_stats(true))),
"torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
ASSERT_EQ(
c10::str(InstanceNorm2d(
InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false)
.track_running_stats(true))),
"torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
ASSERT_EQ(
c10::str(InstanceNorm3d(
InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false)
.track_running_stats(true))),
"torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintLayerNorm) {
ASSERT_EQ(
c10::str(LayerNorm(LayerNormOptions({2, 2}))),

@@ -75,9 +75,9 @@ torch.nn.BatchNorm2d|Yes|No
torch.nn.BatchNorm3d|Yes|No
torch.nn.GroupNorm|Yes|No
torch.nn.SyncBatchNorm|No|No
torch.nn.InstanceNorm1d|No|No
torch.nn.InstanceNorm2d|No|No
torch.nn.InstanceNorm3d|No|No
torch.nn.InstanceNorm1d|Yes|No
torch.nn.InstanceNorm2d|Yes|No
torch.nn.InstanceNorm3d|Yes|No
torch.nn.LayerNorm|Yes|No
torch.nn.LocalResponseNorm|Yes|No
torch.nn.CrossMapLRN2d|Yes|No

@@ -221,6 +221,7 @@ def add_torch_libs():
"torch/csrc/api/src/nn/modules/activation.cpp",
"torch/csrc/api/src/nn/modules/batchnorm.cpp",
"torch/csrc/api/src/nn/modules/normalization.cpp",
"torch/csrc/api/src/nn/modules/instancenorm.cpp",
"torch/csrc/api/src/nn/modules/conv.cpp",
"torch/csrc/api/src/nn/modules/dropout.cpp",
"torch/csrc/api/src/nn/modules/distance.cpp",
@@ -239,6 +240,7 @@ def add_torch_libs():
"torch/csrc/api/src/nn/options/batchnorm.cpp",
"torch/csrc/api/src/nn/options/conv.cpp",
"torch/csrc/api/src/nn/options/dropout.cpp",
"torch/csrc/api/src/nn/options/instancenorm.cpp",
"torch/csrc/api/src/nn/options/linear.cpp",
"torch/csrc/api/src/nn/options/normalization.cpp",
"torch/csrc/api/src/nn/options/embedding.cpp",

@@ -14,3 +14,4 @@
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/upsampling.h>
#include <torch/nn/functional/vision.h>
#include <torch/nn/functional/instancenorm.h>

@@ -0,0 +1,30 @@
#pragma once
#include <torch/nn/options/instancenorm.h>
namespace torch {
namespace nn {
namespace functional {
namespace detail {
inline Tensor instance_norm(const Tensor& input, const Tensor& running_mean,
const Tensor& running_var, const Tensor& weight, const Tensor& bias,
bool use_input_stats, double momentum, double eps) {
return torch::instance_norm(
input, weight, bias, running_mean, running_var,
use_input_stats, momentum, eps, at::globalContext().userEnabledCuDNN()
);
}
} // namespace detail
inline Tensor instance_norm(const Tensor& input, const InstanceNormFuncOptions& options = {}) {
return detail::instance_norm(
input, options.running_mean(),
options.running_var(), options.weight(), options.bias(),
options.use_input_stats(), options.momentum(), options.eps());
}
} // namespace functional
} // namespace nn
} // namespace torch
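
A minimal usage sketch of the functional above, assuming only the API added in this diff (shapes mirror the FunctionalTest cases):

#include <torch/torch.h>

namespace F = torch::nn::functional;

void instance_norm_functional_example() {
  // (N, C, L) input; with no options, each (instance, channel) slice is
  // normalized with its own statistics (use_input_stats=true, momentum=0.1,
  // eps=1e-5 are the declared defaults).
  auto x = torch::randn({2, 5, 4});
  auto y = F::instance_norm(x);

  // Affine parameters can be passed explicitly:
  auto y2 = F::instance_norm(
      x,
      F::InstanceNormFuncOptions()
          .weight(torch::ones({5}))
          .bias(torch::zeros({5}))
          .momentum(0.1)
          .eps(1e-5));
}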

@@ -9,6 +9,7 @@
// Layers
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/instancenorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/distance.h>

@@ -1,12 +1,16 @@
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/nn/init.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstdint>
namespace F = torch::nn::functional;
namespace torch {
namespace nn {
@@ -77,28 +81,55 @@ TORCH_MODULE(BatchNorm);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Base class for all (dimension-specialized) batchnorm modules.
template <size_t D, typename Derived>
class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
/// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
template <size_t D, typename Derived, typename DerivedOptions>
class NormImplBase : public torch::nn::Cloneable<Derived> {
protected:
virtual void _check_input_dim(const Tensor& input) = 0;
public:
explicit BatchNormImplBase(const BatchNormOptions& options_);
NormImplBase(const DerivedOptions& options_) : options(options_) {
reset();
}
Tensor forward(const Tensor& input);
void reset() override {
if (options.affine()) {
weight = this->register_parameter("weight", torch::empty({options.num_features()}));
bias = this->register_parameter("bias", torch::empty({options.num_features()}));
} else {
weight = this->register_parameter("weight", Tensor());
bias = this->register_parameter("bias", Tensor());
}
if (options.track_running_stats()) {
running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
} else {
running_mean = this->register_buffer("running_mean", Tensor());
running_var = this->register_buffer("running_var", Tensor());
num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
}
reset_parameters();
}
void reset() override;
void reset_running_stats() {
if (options.track_running_stats()) {
running_mean.zero_();
running_var.fill_(1);
num_batches_tracked.zero_();
}
}
void reset_running_stats();
void reset_parameters();
/// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
void reset_parameters() {
reset_running_stats();
if (options.affine()) {
torch::nn::init::ones_(weight);
torch::nn::init::zeros_(bias);
}
}
/// The options with which this module was constructed.
BatchNormOptions options;
DerivedOptions options;
/// The learned weight.
/// Only defined if the `affine` option was `true` upon construction.
@@ -121,6 +152,47 @@ class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
Tensor num_batches_tracked;
};
/// Base class for all (dimension-specialized) batchnorm modules.
template <size_t D, typename Derived>
class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
public:
using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;
Tensor forward(const Tensor& input) {
this->_check_input_dim(input);
double exponential_average_factor;
if (this->options.momentum() == c10::nullopt) {
exponential_average_factor = 0.0;
} else {
exponential_average_factor = this->options.momentum().value();
}
if (this->is_training() && this->options.track_running_stats()) {
if (this->num_batches_tracked.defined()) {
this->num_batches_tracked += 1;
if (this->options.momentum() == c10::nullopt) { // use cumulative moving average
exponential_average_factor = 1.0 / this->num_batches_tracked.template item<double>();
} else { // use exponential moving average
exponential_average_factor = this->options.momentum().value();
}
}
}
return F::detail::batch_norm(
input,
this->running_mean,
this->running_var,
this->weight,
this->bias,
this->is_training() || !this->options.track_running_stats(),
/*momentum=*/exponential_average_factor,
this->options.eps());
}
/// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
};
/// Applies the BatchNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
/// about the exact behavior of this module.
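
Note the averaging behavior encoded in forward() above; a sketch of the two modes, assuming momentum is a c10::optional<double> as the nullopt comparison implies:

#include <torch/torch.h>

using namespace torch::nn;

void batch_norm_momentum_modes() {
  // Default momentum (0.1): running statistics are updated with an
  // exponential moving average,
  //   running = (1 - momentum) * running + momentum * batch_stat.
  BatchNorm1d ema(BatchNorm1dOptions(5).momentum(0.1));

  // momentum cleared: forward() falls back to a cumulative moving average,
  // with exponential_average_factor = 1 / num_batches_tracked.
  BatchNorm1d cma(BatchNorm1dOptions(5).momentum(c10::nullopt));
}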

@@ -0,0 +1,66 @@
#pragma once
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/options/instancenorm.h>
namespace torch {
namespace nn {
/// Base class for all (dimension-specialized) instance norm modules
template <size_t D, typename Derived>
class InstanceNormImpl : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
public:
using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
Tensor forward(const Tensor& input) {
this->_check_input_dim(input);
return F::detail::instance_norm(
input, this->running_mean, this->running_var, this->weight, this->bias,
this->is_training() || !this->options.track_running_stats(), this->options.momentum(), this->options.eps());
}
/// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
};
/// Applies the InstanceNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm1d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {
protected:
virtual void _check_input_dim(const Tensor& input) override;
public:
using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
};
TORCH_MODULE(InstanceNorm1d);
/// Applies the InstanceNorm2d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm2d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm2dImpl : public InstanceNormImpl<2, InstanceNorm2dImpl> {
protected:
virtual void _check_input_dim(const Tensor& input) override;
public:
using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
};
TORCH_MODULE(InstanceNorm2d);
/// Applies the InstanceNorm3d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm3d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm3dImpl : public InstanceNormImpl<3, InstanceNorm3dImpl> {
protected:
virtual void _check_input_dim(const Tensor& input) override;
public:
using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
};
TORCH_MODULE(InstanceNorm3d);
} // namespace nn
} // namespace torch
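
A short usage sketch of the new modules (option values are illustrative; shapes mirror the ModulesTest cases):

#include <torch/torch.h>

using namespace torch::nn;

void instance_norm_module_example() {
  // Stateless by default: no affine parameters, no running statistics.
  InstanceNorm2d norm(5);
  auto y = norm->forward(torch::randn({2, 5, 4, 4}));  // output shape == input shape

  // Opt in to learnable affine parameters and running statistics, as the
  // *Stateful tests exercise.
  InstanceNorm2d stateful(
      InstanceNorm2dOptions(5).affine(true).track_running_stats(true));
}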

@@ -43,7 +43,7 @@ using BatchNorm3dOptions = BatchNormOptions;
namespace functional {
/// Options for the `BatchNorm` module.
/// Options for the `BatchNorm` functional.
struct TORCH_API BatchNormFuncOptions {
TORCH_ARG(Tensor, weight) = Tensor();

@@ -0,0 +1,59 @@
#pragma once
#include <torch/arg.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/types.h>
namespace torch {
namespace nn {
/// Options for the `InstanceNorm` module.
struct TORCH_API InstanceNormOptions {
/* implicit */ InstanceNormOptions(int64_t num_features);
/// The number of features of the input tensor.
TORCH_ARG(int64_t, num_features);
/// The epsilon value added for numerical stability.
TORCH_ARG(double, eps) = 1e-5;
/// A momentum multiplier for the mean and variance.
TORCH_ARG(double, momentum) = 0.1;
/// Whether to learn a scale and bias that are applied in an affine
/// transformation on the input.
TORCH_ARG(bool, affine) = false;
/// Whether to store and update batch statistics (mean and variance) in the
/// module.
TORCH_ARG(bool, track_running_stats) = false;
};
using InstanceNorm1dOptions = InstanceNormOptions;
using InstanceNorm2dOptions = InstanceNormOptions;
using InstanceNorm3dOptions = InstanceNormOptions;
namespace functional {
/// Options for the `InstanceNorm` functional.
struct TORCH_API InstanceNormFuncOptions {
TORCH_ARG(Tensor, running_mean) = Tensor();
TORCH_ARG(Tensor, running_var) = Tensor();
TORCH_ARG(Tensor, weight) = Tensor();
TORCH_ARG(Tensor, bias) = Tensor();
TORCH_ARG(bool, use_input_stats) = true;
TORCH_ARG(double, momentum) = 0.1;
TORCH_ARG(double, eps) = 1e-5;
};
} // namespace functional
} // namespace nn
} // namespace torch
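
The defaults above intentionally differ from BatchNormOptions (where affine and track_running_stats default to true), matching Python's torch.nn.InstanceNorm{1,2,3}d; a construction sketch:

#include <torch/torch.h>

using namespace torch::nn;

void instance_norm_options_example() {
  // eps=1e-5, momentum=0.1, affine=false, track_running_stats=false.
  InstanceNorm1dOptions defaults(5);

  // TORCH_ARG generates chainable setters:
  auto stateful =
      InstanceNorm1dOptions(5).affine(true).track_running_stats(true);
}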

@@ -3,7 +3,6 @@
#include <torch/cuda.h>
#include <torch/types.h>
#include <torch/nn/init.h>
#include <c10/util/Exception.h>
@@ -17,7 +16,7 @@ namespace F = torch::nn::functional;
namespace torch {
namespace nn {
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
"Use BatchNorm{1,2,3}d instead.");
reset();
@@ -78,93 +77,17 @@ Tensor BatchNormImpl::pure_forward(
torch::cuda::cudnn_is_available());
}
template <size_t D, typename Derived>
BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
: options(options_) {
reset();
}
// ===========================================================================
template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset() {
if (options.affine()) {
weight = this->register_parameter("weight", torch::empty({options.num_features()}));
bias = this->register_parameter("bias", torch::empty({options.num_features()}));
} else {
weight = this->register_parameter("weight", Tensor());
bias = this->register_parameter("bias", Tensor());
}
if (options.track_running_stats()) {
running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
} else {
running_mean = this->register_buffer("running_mean", Tensor());
running_var = this->register_buffer("running_var", Tensor());
num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
}
reset_parameters();
}
template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset_running_stats() {
if (options.track_running_stats()) {
running_mean.zero_();
running_var.fill_(1);
num_batches_tracked.zero_();
}
}
template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset_parameters() {
reset_running_stats();
if (options.affine()) {
torch::nn::init::ones_(weight);
torch::nn::init::zeros_(bias);
}
}
template <size_t D, typename Derived>
template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::BatchNorm" << D << "d("
<< options.num_features() << ", "
<< "eps=" << options.eps() << ", "
<< "momentum=" << options.momentum().value() << ", "
<< "affine=" << options.affine() << ", "
<< "track_running_stats=" << options.track_running_stats() << ")";
}
template <size_t D, typename Derived>
Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
_check_input_dim(input);
double exponential_average_factor;
if (options.momentum() == c10::nullopt) {
exponential_average_factor = 0.0;
} else {
exponential_average_factor = options.momentum().value();
}
if (this->is_training() && options.track_running_stats()) {
if (num_batches_tracked.defined()) {
num_batches_tracked += 1;
if (options.momentum() == c10::nullopt) { // use cumulative moving average
exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
} else { // use exponential moving average
exponential_average_factor = options.momentum().value();
}
}
}
return F::detail::batch_norm(
input,
running_mean,
running_var,
weight,
bias,
this->is_training() || !options.track_running_stats(),
/*momentum=*/exponential_average_factor,
options.eps());
<< this->options.num_features() << ", "
<< "eps=" << this->options.eps() << ", "
<< "momentum=" << this->options.momentum().value() << ", "
<< "affine=" << this->options.affine() << ", "
<< "track_running_stats=" << this->options.track_running_stats() << ")";
}
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {

@@ -0,0 +1,57 @@
#include <torch/nn/functional/instancenorm.h>
#include <torch/nn/modules/instancenorm.h>
namespace F = torch::nn::functional;
namespace torch {
namespace nn {
template <size_t D, typename Derived>
void InstanceNormImpl<D, Derived>::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::InstanceNorm" << D << "d("
<< this->options.num_features() << ", "
<< "eps=" << this->options.eps() << ", "
<< "momentum=" << this->options.momentum() << ", "
<< "affine=" << this->options.affine() << ", "
<< "track_running_stats=" << this->options.track_running_stats() << ")";
}
void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
if (input.dim() == 2) {
TORCH_CHECK(
false,
"InstanceNorm1d returns 0-filled tensor to 2D tensor.",
"This is because InstanceNorm1d reshapes inputs to",
"(1, N * C, ...) from (N, C,...) and this makes",
"variances 0.");
}
if (input.dim() != 3) {
TORCH_CHECK(
false,
"expected 3D input (got ", input.dim(), "D input)");
}
}
void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
if (input.dim() != 4) {
TORCH_CHECK(
false,
"expected 4D input (got ", input.dim(), "D input)");
}
}
void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
if (input.dim() != 5) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
TORCH_CHECK(
false,
"expected 5D input (got ", input.dim(), "D input)");
}
}
template class InstanceNormImpl<1, InstanceNorm1dImpl>;
template class InstanceNormImpl<2, InstanceNorm2dImpl>;
template class InstanceNormImpl<3, InstanceNorm3dImpl>;
} // namespace nn
} // namespace torch
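
The 2D check in InstanceNorm1dImpl::_check_input_dim exists because a (N, C) input would be reshaped to (1, N * C, ...), leaving one element per instance-channel slice and therefore zero variance; a sketch of what a caller sees:

#include <torch/torch.h>

using namespace torch::nn;

void instance_norm_input_dim_example() {
  InstanceNorm1d norm(5);

  // (N, C, L): each slice has L elements to normalize over.
  auto ok = norm->forward(torch::arange(40.).view({2, 5, 4}));

  // (N, C): every slice would have a single element (variance 0), so the
  // module throws instead of silently returning a 0-filled tensor.
  // norm->forward(torch::randn({2, 5}));  // raises c10::Error
}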

@@ -0,0 +1,9 @@
#include <torch/nn/options/instancenorm.h>
namespace torch {
namespace nn {
InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {}
} // namespace nn
} // namespace torch