raise error when groups is not positive in Conv modules

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77919

Approved by: https://github.com/jbschlosser
This commit is contained in:
yuguo68
2022-05-19 17:14:40 -07:00
committed by PyTorch MergeBot
parent 018982318c
commit c186250d95
2 changed files with 10 additions and 0 deletions

View File

@ -980,6 +980,14 @@ class TestNN(NNTestCase):
with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
module(input)
def test_conv_invalid_groups(self):
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
def test_Conv1d_module_same_padding(self):
# Compare module against functional: without strides/dilation, asymmetric padding
x = torch.rand(1, 1, 20)

View File

@ -80,6 +80,8 @@ class _ConvNd(Module):
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_ConvNd, self).__init__()
if groups <= 0:
raise ValueError('groups must be a positive integer')
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0: