mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
raise error when groups is not positive in Conv modules
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77919 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
018982318c
commit
c186250d95
@ -980,6 +980,14 @@ class TestNN(NNTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
|
||||
module(input)
|
||||
|
||||
def test_conv_invalid_groups(self):
|
||||
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
|
||||
torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
|
||||
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
|
||||
torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
|
||||
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
|
||||
torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
|
||||
|
||||
def test_Conv1d_module_same_padding(self):
|
||||
# Compare module against functional: without strides/dilation, asymmetric padding
|
||||
x = torch.rand(1, 1, 20)
|
||||
|
||||
@ -80,6 +80,8 @@ class _ConvNd(Module):
|
||||
dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(_ConvNd, self).__init__()
|
||||
if groups <= 0:
|
||||
raise ValueError('groups must be a positive integer')
|
||||
if in_channels % groups != 0:
|
||||
raise ValueError('in_channels must be divisible by groups')
|
||||
if out_channels % groups != 0:
|
||||
|
||||
Reference in New Issue
Block a user