InstanceNorm{1,2,3}d (#28790)
Summary: Hi yf225, I have a few doubts related to the implementation: 1) What tests do I have to write? 2) What does _load_state_from_dict do? 3) Do I need to override the reset() function, as I cannot see its utility? 4) InstanceNormOptions could be replaced with BatchNormOptions, but I find that `track_running_stats` is not defined there; instead `stateful` is defined. InstanceNorm{1,2,3}d https://github.com/pytorch/pytorch/issues/25883 Pull Request resolved: https://github.com/pytorch/pytorch/pull/28790 Differential Revision: D18588666 Pulled By: yf225 fbshipit-source-id: bb9b81f01f62c3fc8765fa0ba0716768087ee155
Committed by Facebook Github Bot (parent 8e3486de81, commit ec52d911bd).
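
For orientation, a minimal usage sketch of the module API this commit adds (editorial, not part of the diff; it assumes a libtorch build containing this commit, and the shapes and option values are illustrative):

    #include <torch/torch.h>

    int main() {
      // InstanceNorm defaults differ from BatchNorm: affine and
      // track_running_stats are both false unless enabled explicitly.
      torch::nn::InstanceNorm2d norm(
          torch::nn::InstanceNorm2dOptions(16).affine(true).track_running_stats(true));
      auto x = torch::randn({8, 16, 32, 32});  // (N, C, H, W)
      auto y = norm(x);  // each (n, c) plane is normalized independently
    }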
@@ -566,6 +566,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/distance.cpp
@@ -583,11 +584,11 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp
@@ -1579,6 +1579,236 @@ TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
  ASSERT_TRUE(output.allclose(expected));
}

TEST_F(FunctionalTest, InstanceNorm1d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(40.).view({2, 5, 4});
  auto mean = torch::arange(5.);
  auto variance = torch::arange(5.);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
    input,
    F::InstanceNormFuncOptions()
      .running_mean(mean)
      .running_var(variance)
      .weight(weight)
      .bias(bias)
      .momentum(momentum)
      .eps(eps));
  auto expected = torch::tensor({{{ 0.0000, 0.0000, 0.0000, 0.0000},
                                  {-0.3416, 0.5528, 1.4472, 2.3416},
                                  {-0.6833, 1.1056, 2.8944, 4.6833},
                                  {-1.0249, 1.6584, 4.3416, 7.0249},
                                  {-1.3665, 2.2112, 5.7888, 9.3665}},
                                 {{ 0.0000, 0.0000, 0.0000, 0.0000},
                                  {-0.3416, 0.5528, 1.4472, 2.3416},
                                  {-0.6833, 1.1056, 2.8944, 4.6833},
                                  {-1.0249, 1.6584, 4.3416, 7.0249},
                                  {-1.3665, 2.2112, 5.7888, 9.3665}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
  auto input = torch::arange(40.).view({2, 5, 4});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor({{{-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416}},
                                 {{-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416},
                                  {-1.3416, -0.4472, 0.4472, 1.3416}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, InstanceNorm2d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto mean = torch::arange((double)num_features);
  auto variance = torch::arange((double)num_features);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
    input,
    F::InstanceNormFuncOptions()
      .running_mean(mean)
      .running_var(variance)
      .weight(weight)
      .bias(bias)
      .momentum(momentum)
      .eps(eps));
  auto expected = torch::tensor({{{{ 0.0000, 0.0000},
                                   { 0.0000, 0.0000}},
                                  {{-0.3416, 0.5528},
                                   { 1.4472, 2.3416}},
                                  {{-0.6833, 1.1056},
                                   { 2.8944, 4.6833}},
                                  {{-1.0249, 1.6584},
                                   { 4.3416, 7.0249}},
                                  {{-1.3665, 2.2112},
                                   { 5.7888, 9.3665}}},
                                 {{{ 0.0000, 0.0000},
                                   { 0.0000, 0.0000}},
                                  {{-0.3416, 0.5528},
                                   { 1.4472, 2.3416}},
                                  {{-0.6833, 1.1056},
                                   { 2.8944, 4.6833}},
                                  {{-1.0249, 1.6584},
                                   { 4.3416, 7.0249}},
                                  {{-1.3665, 2.2112},
                                   { 5.7888, 9.3665}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
  int num_features = 5;
  double eps = 1e-05;

  auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor({{{{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}}},
                                 {{{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, InstanceNorm3d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
  auto mean = torch::arange((double)num_features);
  auto variance = torch::arange((double)num_features);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
    input,
    F::InstanceNormFuncOptions()
      .running_mean(mean)
      .running_var(variance)
      .weight(weight)
      .bias(bias)
      .momentum(momentum)
      .eps(eps));
  auto expected = torch::tensor({{{{{ 0.0000, 0.0000},
                                    { 0.0000, 0.0000}},
                                   {{ 0.0000, 0.0000},
                                    { 0.0000, 0.0000}}},
                                  {{{-0.5275, -0.0911},
                                    { 0.3453, 0.7818}},
                                   {{ 1.2182, 1.6547},
                                    { 2.0911, 2.5275}}},
                                  {{{-1.0550, -0.1822},
                                    { 0.6907, 1.5636}},
                                   {{ 2.4364, 3.3093},
                                    { 4.1822, 5.0550}}},
                                  {{{-1.5826, -0.2733},
                                    { 1.0360, 2.3453}},
                                   {{ 3.6547, 4.9640},
                                    { 6.2733, 7.5826}}},
                                  {{{-2.1101, -0.3644},
                                    { 1.3814, 3.1271}},
                                   {{ 4.8729, 6.6186},
                                    { 8.3644, 10.1101}}}},
                                 {{{{ 0.0000, 0.0000},
                                    { 0.0000, 0.0000}},
                                   {{ 0.0000, 0.0000},
                                    { 0.0000, 0.0000}}},
                                  {{{-0.5275, -0.0911},
                                    { 0.3453, 0.7818}},
                                   {{ 1.2182, 1.6547},
                                    { 2.0911, 2.5275}}},
                                  {{{-1.0550, -0.1822},
                                    { 0.6907, 1.5636}},
                                   {{ 2.4364, 3.3093},
                                    { 4.1822, 5.0550}}},
                                  {{{-1.5826, -0.2733},
                                    { 1.0360, 2.3453}},
                                   {{ 3.6547, 4.9640},
                                    { 6.2733, 7.5826}}},
                                  {{{-2.1101, -0.3644},
                                    { 1.3814, 3.1271}},
                                   {{ 4.8729, 6.6186},
                                    { 8.3644, 10.1101}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
  int num_features = 5;
  double eps = 1e-05;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor({{{{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}}},
                                 {{{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}

TEST_F(FunctionalTest, Interpolate) {
  {
    // 1D interpolation
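Aside (editorial, not part of the diff): the repeated rows in the default-options expectations above fall straight out of the definition. Instance norm standardizes each (sample, channel) slice independently, so every length-4 arange slice x = {0, 1, 2, 3} maps to the same values: mean(x) = 1.5, biased var(x) = 1.25, and (x - 1.5) / sqrt(1.25 + 1e-5) = {-1.3416, -0.4472, 0.4472, 1.3416}.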
@@ -1427,7 +1427,7 @@ TEST_F(ModulesTest, BatchNormLegacyWarning) {
}

TEST_F(ModulesTest, BatchNorm1dStateful) {
  BatchNorm1d bn(BatchNorm1dOptions(5));
  BatchNorm1d bn(5);

  ASSERT_TRUE(bn->options.track_running_stats());

@@ -1464,20 +1464,30 @@ TEST_F(ModulesTest, BatchNorm1dStateless) {
}

TEST_F(ModulesTest, BatchNorm1d) {
  BatchNorm1d bn(BatchNorm1dOptions(5));
  BatchNorm1d bn(5);
  bn->eval();

  auto input = torch::randn({2, 5}, torch::requires_grad());
  auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
  auto output = bn->forward(input);
  auto expected = torch::tensor({{{ 0.0000, 1.0000},
                                  { 2.0000, 3.0000},
                                  { 4.0000, 5.0000},
                                  { 6.0000, 7.0000},
                                  { 8.0000, 9.0000}},
                                 {{10.0000, 10.9999},
                                  {11.9999, 12.9999},
                                  {13.9999, 14.9999},
                                  {15.9999, 16.9999},
                                  {17.9999, 18.9999}}});
  ASSERT_TRUE(output.allclose(expected));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
}

TEST_F(ModulesTest, BatchNorm2dStateful) {
  BatchNorm2d bn(BatchNorm2dOptions(5));
  BatchNorm2d bn(5);

  ASSERT_TRUE(bn->options.track_running_stats());

@@ -1514,20 +1524,40 @@ TEST_F(ModulesTest, BatchNorm2dStateless) {
}

TEST_F(ModulesTest, BatchNorm2d) {
  BatchNorm2d bn(BatchNorm2dOptions(5));
  BatchNorm2d bn(5);
  bn->eval();

  auto input = torch::randn({2, 5, 4, 4}, torch::requires_grad());
  auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
  auto output = bn->forward(input);
  auto expected = torch::tensor({{{{ 0.0000, 1.0000},
                                   { 2.0000, 3.0000}},
                                  {{ 4.0000, 5.0000},
                                   { 6.0000, 7.0000}},
                                  {{ 8.0000, 9.0000},
                                   {10.0000, 10.9999}},
                                  {{11.9999, 12.9999},
                                   {13.9999, 14.9999}},
                                  {{15.9999, 16.9999},
                                   {17.9999, 18.9999}}},
                                 {{{19.9999, 20.9999},
                                   {21.9999, 22.9999}},
                                  {{23.9999, 24.9999},
                                   {25.9999, 26.9999}},
                                  {{27.9999, 28.9999},
                                   {29.9998, 30.9998}},
                                  {{31.9998, 32.9998},
                                   {33.9998, 34.9998}},
                                  {{35.9998, 36.9998},
                                   {37.9998, 38.9998}}}});
  ASSERT_TRUE(output.allclose(expected));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4})));
}

TEST_F(ModulesTest, BatchNorm3dStateful) {
  BatchNorm3d bn(BatchNorm3dOptions(5));
  BatchNorm3d bn(5);

  ASSERT_TRUE(bn->options.track_running_stats());

@@ -1564,16 +1594,276 @@ TEST_F(ModulesTest, BatchNorm3dStateless) {
}

TEST_F(ModulesTest, BatchNorm3d) {
  BatchNorm3d bn(BatchNorm3dOptions(5));
  BatchNorm3d bn(5);
  bn->eval();

  auto input = torch::randn({2, 5, 4, 4, 4}, torch::requires_grad());
  auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
  auto output = bn->forward(input);
  auto expected = torch::tensor({{{{{ 0.0000, 1.0000},
                                    { 2.0000, 3.0000}},
                                   {{ 4.0000, 5.0000},
                                    { 6.0000, 7.0000}}},
                                  {{{ 8.0000, 9.0000},
                                    {10.0000, 10.9999}},
                                   {{11.9999, 12.9999},
                                    {13.9999, 14.9999}}},
                                  {{{15.9999, 16.9999},
                                    {17.9999, 18.9999}},
                                   {{19.9999, 20.9999},
                                    {21.9999, 22.9999}}},
                                  {{{23.9999, 24.9999},
                                    {25.9999, 26.9999}},
                                   {{27.9999, 28.9999},
                                    {29.9998, 30.9998}}},
                                  {{{31.9998, 32.9998},
                                    {33.9998, 34.9998}},
                                   {{35.9998, 36.9998},
                                    {37.9998, 38.9998}}}},
                                 {{{{39.9998, 40.9998},
                                    {41.9998, 42.9998}},
                                   {{43.9998, 44.9998},
                                    {45.9998, 46.9998}}},
                                  {{{47.9998, 48.9998},
                                    {49.9997, 50.9997}},
                                   {{51.9997, 52.9997},
                                    {53.9997, 54.9997}}},
                                  {{{55.9997, 56.9997},
                                    {57.9997, 58.9997}},
                                   {{59.9997, 60.9997},
                                    {61.9997, 62.9997}}},
                                  {{{63.9997, 64.9997},
                                    {65.9997, 66.9997}},
                                   {{67.9997, 68.9997},
                                    {69.9996, 70.9996}}},
                                  {{{71.9996, 72.9996},
                                    {73.9996, 74.9996}},
                                   {{75.9996, 76.9996},
                                    {77.9996, 78.9996}}}}});
  ASSERT_TRUE(output.allclose(expected));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4, 4})));
}

TEST_F(ModulesTest, InstanceNorm1dStateful) {
  InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(true).affine(true));

  ASSERT_TRUE(instance_norm->options.track_running_stats());

  ASSERT_TRUE(instance_norm->running_mean.defined());
  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
  ASSERT_EQ(instance_norm->running_mean.size(0), 5);

  ASSERT_TRUE(instance_norm->running_var.defined());
  ASSERT_EQ(instance_norm->running_var.dim(), 1);
  ASSERT_EQ(instance_norm->running_var.size(0), 5);

  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);

  ASSERT_TRUE(instance_norm->options.affine());

  ASSERT_TRUE(instance_norm->weight.defined());
  ASSERT_EQ(instance_norm->weight.dim(), 1);
  ASSERT_EQ(instance_norm->weight.size(0), 5);

  ASSERT_TRUE(instance_norm->bias.defined());
  ASSERT_EQ(instance_norm->bias.dim(), 1);
  ASSERT_EQ(instance_norm->bias.size(0), 5);
}

TEST_F(ModulesTest, InstanceNorm1dStateless) {
  InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(false).affine(false));

  ASSERT_FALSE(instance_norm->running_mean.defined());
  ASSERT_FALSE(instance_norm->running_var.defined());
  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
  ASSERT_FALSE(instance_norm->weight.defined());
  ASSERT_FALSE(instance_norm->bias.defined());
}

TEST_F(ModulesTest, InstanceNorm1d) {
  InstanceNorm1d instance_norm(5);
  instance_norm->eval();

  auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
  auto output = instance_norm->forward(input);
  auto expected = torch::tensor({{{-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000}},
                                 {{-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000},
                                  {-1.0000, 1.0000}}});
  ASSERT_TRUE(output.allclose(expected, 1e-3));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
}

TEST_F(ModulesTest, InstanceNorm2dStateful) {
  InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(true).affine(true));

  ASSERT_TRUE(instance_norm->options.track_running_stats());

  ASSERT_TRUE(instance_norm->running_mean.defined());
  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
  ASSERT_EQ(instance_norm->running_mean.size(0), 5);

  ASSERT_TRUE(instance_norm->running_var.defined());
  ASSERT_EQ(instance_norm->running_var.dim(), 1);
  ASSERT_EQ(instance_norm->running_var.size(0), 5);

  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);

  ASSERT_TRUE(instance_norm->options.affine());

  ASSERT_TRUE(instance_norm->weight.defined());
  ASSERT_EQ(instance_norm->weight.dim(), 1);
  ASSERT_EQ(instance_norm->weight.size(0), 5);

  ASSERT_TRUE(instance_norm->bias.defined());
  ASSERT_EQ(instance_norm->bias.dim(), 1);
  ASSERT_EQ(instance_norm->bias.size(0), 5);
}

TEST_F(ModulesTest, InstanceNorm2dStateless) {
  InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(false).affine(false));

  ASSERT_FALSE(instance_norm->running_mean.defined());
  ASSERT_FALSE(instance_norm->running_var.defined());
  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
  ASSERT_FALSE(instance_norm->weight.defined());
  ASSERT_FALSE(instance_norm->bias.defined());
}

TEST_F(ModulesTest, InstanceNorm2d) {
  InstanceNorm2d instance_norm(5);
  instance_norm->eval();

  auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
  auto output = instance_norm->forward(input);
  auto expected = torch::tensor({{{{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}}},
                                 {{{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}},
                                  {{-1.3416, -0.4472},
                                   { 0.4472, 1.3416}}}});
  ASSERT_TRUE(output.allclose(expected, 1e-3));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
}

TEST_F(ModulesTest, InstanceNorm3dStateful) {
  InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(true).affine(true));

  ASSERT_TRUE(instance_norm->options.track_running_stats());

  ASSERT_TRUE(instance_norm->running_mean.defined());
  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
  ASSERT_EQ(instance_norm->running_mean.size(0), 5);

  ASSERT_TRUE(instance_norm->running_var.defined());
  ASSERT_EQ(instance_norm->running_var.dim(), 1);
  ASSERT_EQ(instance_norm->running_var.size(0), 5);

  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);

  ASSERT_TRUE(instance_norm->options.affine());

  ASSERT_TRUE(instance_norm->weight.defined());
  ASSERT_EQ(instance_norm->weight.dim(), 1);
  ASSERT_EQ(instance_norm->weight.size(0), 5);

  ASSERT_TRUE(instance_norm->bias.defined());
  ASSERT_EQ(instance_norm->bias.dim(), 1);
  ASSERT_EQ(instance_norm->bias.size(0), 5);
}

TEST_F(ModulesTest, InstanceNorm3dStateless) {
  InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(false).affine(false));

  ASSERT_FALSE(instance_norm->running_mean.defined());
  ASSERT_FALSE(instance_norm->running_var.defined());
  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
  ASSERT_FALSE(instance_norm->weight.defined());
  ASSERT_FALSE(instance_norm->bias.defined());
}

TEST_F(ModulesTest, InstanceNorm3d) {
  InstanceNorm3d instance_norm(5);
  instance_norm->eval();

  auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
  auto output = instance_norm->forward(input);
  auto expected = torch::tensor({{{{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}}},
                                 {{{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}},
                                  {{{-1.5275, -1.0911},
                                    {-0.6547, -0.2182}},
                                   {{ 0.2182, 0.6547},
                                    { 1.0911, 1.5275}}}}});
  ASSERT_TRUE(output.allclose(expected, 1e-3));
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(input.sizes(), input.grad().sizes());
}

TEST_F(ModulesTest, Linear_CUDA) {
@@ -3178,6 +3468,30 @@ TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
      "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
  ASSERT_EQ(
      c10::str(InstanceNorm1d(
          InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
              .track_running_stats(true))),
      "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
  ASSERT_EQ(
      c10::str(InstanceNorm2d(
          InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false)
              .track_running_stats(true))),
      "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
  ASSERT_EQ(
      c10::str(InstanceNorm3d(
          InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false)
              .track_running_stats(true))),
      "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

TEST_F(ModulesTest, PrettyPrintLayerNorm) {
  ASSERT_EQ(
      c10::str(LayerNorm(LayerNormOptions({2, 2}))),
@@ -75,9 +75,9 @@ torch.nn.BatchNorm2d|Yes|No
torch.nn.BatchNorm3d|Yes|No
torch.nn.GroupNorm|Yes|No
torch.nn.SyncBatchNorm|No|No
torch.nn.InstanceNorm1d|No|No
torch.nn.InstanceNorm2d|No|No
torch.nn.InstanceNorm3d|No|No
torch.nn.InstanceNorm1d|Yes|No
torch.nn.InstanceNorm2d|Yes|No
torch.nn.InstanceNorm3d|Yes|No
torch.nn.LayerNorm|Yes|No
torch.nn.LocalResponseNorm|Yes|No
torch.nn.CrossMapLRN2d|Yes|No
@@ -221,6 +221,7 @@ def add_torch_libs():
        "torch/csrc/api/src/nn/modules/activation.cpp",
        "torch/csrc/api/src/nn/modules/batchnorm.cpp",
        "torch/csrc/api/src/nn/modules/normalization.cpp",
        "torch/csrc/api/src/nn/modules/instancenorm.cpp",
        "torch/csrc/api/src/nn/modules/conv.cpp",
        "torch/csrc/api/src/nn/modules/dropout.cpp",
        "torch/csrc/api/src/nn/modules/distance.cpp",
@@ -239,6 +240,7 @@ def add_torch_libs():
        "torch/csrc/api/src/nn/options/batchnorm.cpp",
        "torch/csrc/api/src/nn/options/conv.cpp",
        "torch/csrc/api/src/nn/options/dropout.cpp",
        "torch/csrc/api/src/nn/options/instancenorm.cpp",
        "torch/csrc/api/src/nn/options/linear.cpp",
        "torch/csrc/api/src/nn/options/normalization.cpp",
        "torch/csrc/api/src/nn/options/embedding.cpp",
@@ -14,3 +14,4 @@
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/upsampling.h>
#include <torch/nn/functional/vision.h>
#include <torch/nn/functional/instancenorm.h>
torch/csrc/api/include/torch/nn/functional/instancenorm.h (new file, 30 lines)
@@ -0,0 +1,30 @@
#pragma once

#include <torch/nn/options/instancenorm.h>

namespace torch {
namespace nn {
namespace functional {

namespace detail {
inline Tensor instance_norm(const Tensor& input, const Tensor& running_mean,
    const Tensor& running_var, const Tensor& weight, const Tensor& bias,
    bool use_input_stats, double momentum, double eps) {
  return torch::instance_norm(
    input, weight, bias, running_mean, running_var,
    use_input_stats, momentum, eps, at::globalContext().userEnabledCuDNN()
  );
}
} // namespace detail

inline Tensor instance_norm(const Tensor& input, const InstanceNormFuncOptions& options = {}) {
  return detail::instance_norm(
    input, options.running_mean(),
    options.running_var(), options.weight(), options.bias(),
    options.use_input_stats(), options.momentum(), options.eps());
}

} // namespace functional
} // namespace nn
} // namespace torch
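A short usage sketch of the functional form defined above (editorial; it mirrors what the new FunctionalTest cases exercise, and the input shape is illustrative):

    namespace F = torch::nn::functional;

    auto x = torch::randn({4, 3, 10});  // (N, C, L)
    // With default options: use_input_stats=true, momentum=0.1, eps=1e-5,
    // and the running_mean/running_var tensors are left undefined.
    auto y = F::instance_norm(x, F::InstanceNormFuncOptions().eps(1e-5));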
@@ -9,6 +9,7 @@

// Layers
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/instancenorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/distance.h>
@@ -1,12 +1,16 @@
#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/nn/init.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstdint>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {
@@ -77,28 +81,55 @@ TORCH_MODULE(BatchNorm);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Base class for all (dimension-specialized) batchnorm modules.
template <size_t D, typename Derived>
class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
/// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
template <size_t D, typename Derived, typename DerivedOptions>
class NormImplBase : public torch::nn::Cloneable<Derived> {
 protected:
  virtual void _check_input_dim(const Tensor& input) = 0;

 public:
  explicit BatchNormImplBase(const BatchNormOptions& options_);
  NormImplBase(const DerivedOptions& options_) : options(options_) {
    reset();
  }

  Tensor forward(const Tensor& input);
  void reset() override {
    if (options.affine()) {
      weight = this->register_parameter("weight", torch::empty({options.num_features()}));
      bias = this->register_parameter("bias", torch::empty({options.num_features()}));
    } else {
      weight = this->register_parameter("weight", Tensor());
      bias = this->register_parameter("bias", Tensor());
    }
    if (options.track_running_stats()) {
      running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
      running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
      num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
    } else {
      running_mean = this->register_buffer("running_mean", Tensor());
      running_var = this->register_buffer("running_var", Tensor());
      num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
    }
    reset_parameters();
  }

  void reset() override;
  void reset_running_stats() {
    if (options.track_running_stats()) {
      running_mean.zero_();
      running_var.fill_(1);
      num_batches_tracked.zero_();
    }
  }

  void reset_running_stats();

  void reset_parameters();

  /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
  void reset_parameters() {
    reset_running_stats();
    if (options.affine()) {
      torch::nn::init::ones_(weight);
      torch::nn::init::zeros_(bias);
    }
  }

  /// The options with which this module was constructed.
  BatchNormOptions options;
  DerivedOptions options;

  /// The learned weight.
  /// Only defined if the `affine` option was `true` upon construction.
@@ -121,6 +152,47 @@ class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
  Tensor num_batches_tracked;
};

/// Base class for all (dimension-specialized) batchnorm modules.
template <size_t D, typename Derived>
class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
 public:
  using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;

  Tensor forward(const Tensor& input) {
    this->_check_input_dim(input);
    double exponential_average_factor;
    if (this->options.momentum() == c10::nullopt) {
      exponential_average_factor = 0.0;
    } else {
      exponential_average_factor = this->options.momentum().value();
    }

    if (this->is_training() && this->options.track_running_stats()) {
      if (this->num_batches_tracked.defined()) {
        this->num_batches_tracked += 1;
        if (this->options.momentum() == c10::nullopt) { // use cumulative moving average
          exponential_average_factor = 1.0 / this->num_batches_tracked.template item<double>();
        } else { // use exponential moving average
          exponential_average_factor = this->options.momentum().value();
        }
      }
    }

    return F::detail::batch_norm(
      input,
      this->running_mean,
      this->running_var,
      this->weight,
      this->bias,
      this->is_training() || !this->options.track_running_stats(),
      /*momentum=*/exponential_average_factor,
      this->options.eps());
  }

  /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// Applies the BatchNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
/// about the exact behavior of this module.
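Design note (editorial): the refactor above hoists the shared parameter/buffer registration out of BatchNormImplBase into NormImplBase<D, Derived, DerivedOptions>, a CRTP base parameterized on the concrete options type; BatchNorm and InstanceNorm then differ only in forward() and in which options struct they carry. Because reset(), reset_running_stats(), and reset_parameters() now live inline in the header, InstanceNormImpl gets them instantiated for InstanceNormOptions automatically, with no extra explicit instantiations in batchnorm.cpp. A minimal sketch of how a derived module plugs into this base (hypothetical names, for illustration only):

    // Hypothetical 1-D module reusing the shared registration logic.
    class MyNorm1dImpl
        : public torch::nn::NormImplBase<1, MyNorm1dImpl, torch::nn::BatchNormOptions> {
     public:
      using torch::nn::NormImplBase<1, MyNorm1dImpl, torch::nn::BatchNormOptions>::NormImplBase;

      torch::Tensor forward(const torch::Tensor& input) {
        this->_check_input_dim(input);
        return (input - input.mean()) / input.std();  // toy normalization
      }

     protected:
      void _check_input_dim(const torch::Tensor& input) override {
        TORCH_CHECK(input.dim() == 3, "expected 3D input (got ", input.dim(), "D input)");
      }
    };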
torch/csrc/api/include/torch/nn/modules/instancenorm.h (new file, 66 lines)
@@ -0,0 +1,66 @@
#pragma once

#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/options/instancenorm.h>

namespace torch {
namespace nn {

/// Base class for all (dimension-specialized) instance norm modules.
template <size_t D, typename Derived>
class InstanceNormImpl : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
 public:
  using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;

  Tensor forward(const Tensor& input) {
    this->_check_input_dim(input);
    return F::detail::instance_norm(
      input, this->running_mean, this->running_var, this->weight, this->bias,
      this->is_training() || !this->options.track_running_stats(), this->options.momentum(), this->options.eps());
  }

  /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// Applies the InstanceNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm1d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {
 protected:
  virtual void _check_input_dim(const Tensor& input) override;

 public:
  using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
};

TORCH_MODULE(InstanceNorm1d);

/// Applies the InstanceNorm2d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm2d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm2dImpl : public InstanceNormImpl<2, InstanceNorm2dImpl> {
 protected:
  virtual void _check_input_dim(const Tensor& input) override;

 public:
  using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
};

TORCH_MODULE(InstanceNorm2d);

/// Applies the InstanceNorm3d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm3d to learn
/// about the exact behavior of this module.
class TORCH_API InstanceNorm3dImpl : public InstanceNormImpl<3, InstanceNorm3dImpl> {
 protected:
  virtual void _check_input_dim(const Tensor& input) override;

 public:
  using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
};

TORCH_MODULE(InstanceNorm3d);

} // namespace nn
} // namespace torch
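Note (editorial): each TORCH_MODULE(InstanceNorm{1,2,3}d) line above generates the public holder class of the same name, a ModuleHolder that owns the corresponding Impl in a shared_ptr. That is the standard C++ frontend pattern, and it is what lets user code write InstanceNorm1d norm(5); and then call norm(input) or norm->forward(input).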
@@ -43,7 +43,7 @@ using BatchNorm3dOptions = BatchNormOptions;

namespace functional {

/// Options for the `BatchNorm` module.
/// Options for the `BatchNorm` functional.
struct TORCH_API BatchNormFuncOptions {
  TORCH_ARG(Tensor, weight) = Tensor();
torch/csrc/api/include/torch/nn/options/instancenorm.h (new file, 59 lines)
@@ -0,0 +1,59 @@
#pragma once

#include <torch/arg.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Options for the `InstanceNorm` module.
struct TORCH_API InstanceNormOptions {
  /* implicit */ InstanceNormOptions(int64_t num_features);

  /// The number of features of the input tensor.
  TORCH_ARG(int64_t, num_features);

  /// The epsilon value added for numerical stability.
  TORCH_ARG(double, eps) = 1e-5;

  /// A momentum multiplier for the mean and variance.
  TORCH_ARG(double, momentum) = 0.1;

  /// Whether to learn a scale and bias that are applied in an affine
  /// transformation on the input.
  TORCH_ARG(bool, affine) = false;

  /// Whether to store and update batch statistics (mean and variance) in the
  /// module.
  TORCH_ARG(bool, track_running_stats) = false;
};

using InstanceNorm1dOptions = InstanceNormOptions;
using InstanceNorm2dOptions = InstanceNormOptions;
using InstanceNorm3dOptions = InstanceNormOptions;

namespace functional {

/// Options for the `InstanceNorm` functional.
struct TORCH_API InstanceNormFuncOptions {
  TORCH_ARG(Tensor, running_mean) = Tensor();

  TORCH_ARG(Tensor, running_var) = Tensor();

  TORCH_ARG(Tensor, weight) = Tensor();

  TORCH_ARG(Tensor, bias) = Tensor();

  TORCH_ARG(bool, use_input_stats) = true;

  TORCH_ARG(double, momentum) = 0.1;

  TORCH_ARG(double, eps) = 1e-5;
};

} // namespace functional

} // namespace nn
} // namespace torch
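For illustration, constructing the two option types defined above (editorial; the values are arbitrary):

    // Module options: note the InstanceNorm-specific defaults
    // (affine=false, track_running_stats=false), matching Python's torch.nn.InstanceNorm*.
    auto mod_opts = torch::nn::InstanceNorm3dOptions(4).eps(0.5).momentum(0.1);

    // Functional options mirror the parameters of detail::instance_norm.
    namespace F = torch::nn::functional;
    auto fn_opts = F::InstanceNormFuncOptions().use_input_stats(true).momentum(0.1).eps(1e-5);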
@@ -3,7 +3,6 @@

#include <torch/cuda.h>
#include <torch/types.h>
#include <torch/nn/init.h>

#include <c10/util/Exception.h>

@@ -17,7 +16,7 @@ namespace F = torch::nn::functional;
namespace torch {
namespace nn {

BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
  TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
             "Use BatchNorm{1,2,3}d instead.");
  reset();
@@ -78,93 +77,17 @@ Tensor BatchNormImpl::pure_forward(
      torch::cuda::cudnn_is_available());
}

template <size_t D, typename Derived>
BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
    : options(options_) {
  reset();
}
// ===========================================================================

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset() {
  if (options.affine()) {
    weight = this->register_parameter("weight", torch::empty({options.num_features()}));
    bias = this->register_parameter("bias", torch::empty({options.num_features()}));
  } else {
    weight = this->register_parameter("weight", Tensor());
    bias = this->register_parameter("bias", Tensor());
  }
  if (options.track_running_stats()) {
    running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
    running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
    num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
  } else {
    running_mean = this->register_buffer("running_mean", Tensor());
    running_var = this->register_buffer("running_var", Tensor());
    num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
  }
  reset_parameters();
}

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset_running_stats() {
  if (options.track_running_stats()) {
    running_mean.zero_();
    running_var.fill_(1);
    num_batches_tracked.zero_();
  }
}

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset_parameters() {
  reset_running_stats();
  if (options.affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::BatchNorm" << D << "d("
         << options.num_features() << ", "
         << "eps=" << options.eps() << ", "
         << "momentum=" << options.momentum().value() << ", "
         << "affine=" << options.affine() << ", "
         << "track_running_stats=" << options.track_running_stats() << ")";
}

template <size_t D, typename Derived>
Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
  _check_input_dim(input);

  double exponential_average_factor;
  if (options.momentum() == c10::nullopt) {
    exponential_average_factor = 0.0;
  } else {
    exponential_average_factor = options.momentum().value();
  }

  if (this->is_training() && options.track_running_stats()) {
    if (num_batches_tracked.defined()) {
      num_batches_tracked += 1;
      if (options.momentum() == c10::nullopt) { // use cumulative moving average
        exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
      } else { // use exponential moving average
        exponential_average_factor = options.momentum().value();
      }
    }
  }

  return F::detail::batch_norm(
    input,
    running_mean,
    running_var,
    weight,
    bias,
    this->is_training() || !options.track_running_stats(),
    /*momentum=*/exponential_average_factor,
    options.eps());
         << this->options.num_features() << ", "
         << "eps=" << this->options.eps() << ", "
         << "momentum=" << this->options.momentum().value() << ", "
         << "affine=" << this->options.affine() << ", "
         << "track_running_stats=" << this->options.track_running_stats() << ")";
}

void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
torch/csrc/api/src/nn/modules/instancenorm.cpp (new file, 57 lines)
@@ -0,0 +1,57 @@
#include <torch/nn/functional/instancenorm.h>
#include <torch/nn/modules/instancenorm.h>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

template <size_t D, typename Derived>
void InstanceNormImpl<D, Derived>::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::InstanceNorm" << D << "d("
         << this->options.num_features() << ", "
         << "eps=" << this->options.eps() << ", "
         << "momentum=" << this->options.momentum() << ", "
         << "affine=" << this->options.affine() << ", "
         << "track_running_stats=" << this->options.track_running_stats() << ")";
}

void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
  if (input.dim() == 2) {
    TORCH_CHECK(
        false,
        "InstanceNorm1d returns 0-filled tensor to 2D tensor. ",
        "This is because InstanceNorm1d reshapes inputs to ",
        "(1, N * C, ...) from (N, C, ...) and this makes ",
        "variances 0.");
  }
  if (input.dim() != 3) {
    TORCH_CHECK(
        false,
        "expected 3D input (got ", input.dim(), "D input)");
  }
}

void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
  if (input.dim() != 4) {
    TORCH_CHECK(
        false,
        "expected 4D input (got ", input.dim(), "D input)");
  }
}

void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
  if (input.dim() != 5) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    TORCH_CHECK(
        false,
        "expected 5D input (got ", input.dim(), "D input)");
  }
}

template class InstanceNormImpl<1, InstanceNorm1dImpl>;
template class InstanceNormImpl<2, InstanceNorm2dImpl>;
template class InstanceNormImpl<3, InstanceNorm3dImpl>;

} // namespace nn
} // namespace torch
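Why the 2D check above is a hard error (editorial restatement of the message in the code): InstanceNorm1d flattens (N, C, ...) into (1, N * C, ...) and batch-normalizes the result, so a plain (N, C) input leaves each normalization group with a single element. Its variance is 0, and the output would be an all-zero tensor, which is almost never what the caller intended, hence the TORCH_CHECK instead of a silent zero result.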
torch/csrc/api/src/nn/options/instancenorm.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
#include <torch/nn/options/instancenorm.h>

namespace torch {
namespace nn {

InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {}

} // namespace nn
} // namespace torch