make InstanceNorm1d raise an error if the input is 2D (#11992)

Summary:
Resolves #11991 .

Any comment is welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11992

Differential Revision: D14680974

Pulled By: soumith

fbshipit-source-id: 8e287a9c32bf43b35edc9d127f16ed6b72c61d91
This commit is contained in:
crcrpar
2019-03-29 06:41:49 -07:00
committed by Facebook Github Bot
parent c189eba3e1
commit cf444f3544

View File

@ -57,7 +57,7 @@ class _InstanceNorm(_BatchNorm):
@weak_module
class InstanceNorm1d(_InstanceNorm):
r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper
`Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
@ -126,8 +126,15 @@ class InstanceNorm1d(_InstanceNorm):
@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
if input.dim() == 2:
raise ValueError(
'InstanceNorm1d returns 0-filled tensor to 2D tensor.'
'This is because InstanceNorm1d reshapes inputs to'
'(1, N * C, ...) from (N, C,...) and this makes'
'variances 0.'
)
if input.dim() != 3:
raise ValueError('expected 3D input (got {}D input)'
.format(input.dim()))