Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert behavior of Dropout2d on 3D inputs to 1D channel-wise dropout behavior & warn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79549 Approved by: https://github.com/ngimel, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent commit: 2d73c8e6e0
This commit: 5953fd9133
@@ -14432,10 +14432,16 @@ class TestNNDeviceType(NNTestCase):
         with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
             nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

-        # no batch dims
-        input = torch.rand(50, 2, 2, device=device)
-        self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
-        self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)
+        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
+        # For now, the historical dropout1d behavior is performed for 3D inputs.
+        # See https://github.com/pytorch/pytorch/issues/77081
+
+        # input = torch.rand(50, 2, 2, device=device)
+        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
+        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)
+
+        with self.assertWarnsRegex(UserWarning, "assuming that channel-wise 1D dropout behavior is desired"):
+            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

         # check that complete channels are dropped
         input = torch.ones(10, 4, 2, 2, device=device)
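As a quick illustration of the behavior exercised by this test (my own sketch, not part of the PR), the snippet below feeds a 3-D input to nn.Dropout2d, checks for the new warning quoted above, and verifies that whole channels are either zeroed or rescaled by 1 / (1 - p):

import warnings
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout2d(p=0.5)  # modules default to training mode, so dropout is active

# After this change, a 3D input is interpreted as (N, C, L) and triggers a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = drop(torch.ones(1, 4, 8))
assert any("channel-wise 1D dropout behavior" in str(w.message) for w in caught)

# Channel-wise dropout: each channel is either all zeros or all 1 / (1 - p) = 2.0.
for channel in out[0]:
    assert channel.eq(channel[0]).all()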
@@ -1330,15 +1330,19 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo
                     "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).")
         warnings.warn(warn_msg)

-    is_batched = inp_dim == 4
-    if not is_batched:
-        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
+    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
+    # a 3D input will perform dropout1d behavior instead. This was done historically and the
+    # behavior is maintained here for now.
+    # See https://github.com/pytorch/pytorch/issues/77081
+    if inp_dim == 3:
+        warnings.warn("dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
+                      "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
+                      "is the channel dim. This behavior will change in a future release to interpret the "
+                      "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
+                      "channel-wise dropout behavior, please switch to using dropout1d instead.")

     result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training)

-    if not is_batched:
-        result = result.squeeze_(0) if inplace else result.squeeze(0)
-
     return result
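For context, _VF.feature_dropout zeroes entire channels rather than individual elements, which is why a 3D (N, C, L) input here behaves like 1D channel-wise dropout. A rough, self-contained reference of that idea (my sketch under that assumption, not PyTorch's actual implementation) looks like this:

import torch
from torch import Tensor

def feature_dropout_reference(input: Tensor, p: float = 0.5, training: bool = True) -> Tensor:
    # Conceptual stand-in for _VF.feature_dropout: one Bernoulli keep/drop draw per
    # (batch, channel) pair, broadcast over all trailing dims, rescaled by 1 / (1 - p).
    if not training or p == 0.0:
        return input
    mask_shape = input.shape[:2] + (1,) * (input.dim() - 2)
    keep = (torch.rand(mask_shape, device=input.device) >= p).to(input.dtype)
    return input * keep / (1.0 - p)

x = torch.ones(2, 3, 5)              # a 3D input, read as (N, C, L)
print(feature_dropout_reference(x))  # each length-5 row is all 0.0 or all 2.0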
@@ -124,9 +124,16 @@ class Dropout2d(_DropoutNd):
         inplace (bool, optional): If set to ``True``, will do this operation
             in-place

+    .. warning ::
+        Due to historical reasons, this class will perform 1D channel-wise dropout
+        for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
+        support inputs without a batch dimension of shape :math:`(C, H, W)`. This
+        behavior will change in a future release to interpret 3D inputs as no-batch-dim
+        inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
+
     Shape:
-        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
-        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
+        - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
+        - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).

     Examples::
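A small usage sketch matching the documented shapes (my example, not taken from the diff): use Dropout2d on batched 4D inputs, and switch to Dropout1d when channel-wise dropout on (N, C, L) inputs is what you actually want.

import torch
import torch.nn as nn

m2d = nn.Dropout2d(p=0.25)
images = torch.randn(20, 16, 32, 32)  # (N, C, H, W): the intended Dropout2d input
out = m2d(images)                      # whole (H, W) feature maps are zeroed

m1d = nn.Dropout1d(p=0.25)             # recommended for channel-wise dropout on sequences
seq = torch.randn(20, 16, 128)         # (N, C, L)
out_seq = m1d(seq)                     # whole length-L channels are zeroed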