mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make InstanceNorm1d raise an error if the input is 2D (#11992)
Summary: Resolves #11991 . Any comment is welcome. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11992 Differential Revision: D14680974 Pulled By: soumith fbshipit-source-id: 8e287a9c32bf43b35edc9d127f16ed6b72c61d91
This commit is contained in:
committed by
Facebook Github Bot
parent
c189eba3e1
commit
cf444f3544
@ -57,7 +57,7 @@ class _InstanceNorm(_BatchNorm):
|
||||
|
||||
@weak_module
|
||||
class InstanceNorm1d(_InstanceNorm):
|
||||
r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D
|
||||
r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
|
||||
inputs with optional additional channel dimension) as described in the paper
|
||||
`Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
|
||||
|
||||
@ -126,8 +126,15 @@ class InstanceNorm1d(_InstanceNorm):
|
||||
|
||||
@weak_script_method
|
||||
def _check_input_dim(self, input):
|
||||
if input.dim() != 2 and input.dim() != 3:
|
||||
raise ValueError('expected 2D or 3D input (got {}D input)'
|
||||
if input.dim() == 2:
|
||||
raise ValueError(
|
||||
'InstanceNorm1d returns 0-filled tensor to 2D tensor.'
|
||||
'This is because InstanceNorm1d reshapes inputs to'
|
||||
'(1, N * C, ...) from (N, C,...) and this makes'
|
||||
'variances 0.'
|
||||
)
|
||||
if input.dim() != 3:
|
||||
raise ValueError('expected 3D input (got {}D input)'
|
||||
.format(input.dim()))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user