[nn] Assert parsed iterable arguments are an appropriate length (#162340)

Fixes #162327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340
Approved by: https://github.com/Skylion007
This commit is contained in:
Benjamin Glass
2025-09-08 20:55:19 +00:00
committed by PyTorch MergeBot
parent fefc406a3d
commit b5e6e58050
6 changed files with 31 additions and 24 deletions

View File

@ -481,7 +481,7 @@ class TestPoolingNN(NNTestCase):
def test_max_unpool3d_input_check(self): def test_max_unpool3d_input_check(self):
x = torch.ones(1, 3, 1, 1, 1) x = torch.ones(1, 3, 1, 1, 1)
with self.assertRaises(RuntimeError): with self.assertRaises(AssertionError):
F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1]) F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])
def test_quantized_max_pool1d_empty_kernel(self): def test_quantized_max_pool1d_empty_kernel(self):

View File

@ -15,7 +15,7 @@ import torch
from torch import _VF from torch import _VF
import torch.jit import torch.jit
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair from torch.nn.modules.utils import _ntuple, _pair, _single
from hypothesis import settings, HealthCheck from hypothesis import settings, HealthCheck
from hypothesis import assume, given, note from hypothesis import assume, given, note
@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel # Padded input size should be at least as big as dilated kernel
kernels = _single(kernels) input_dimension_function = _ntuple(len(input_feature_map_shape))
strides = _single(strides) kernels = input_dimension_function(kernels)
pads = _single(pads) strides = input_dimension_function(strides)
dilations = _single(dilations) pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)): for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i] assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1) >= dilations[i] * (kernels[i] - 1) + 1)
@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel # Padded input size should be at least as big as dilated kernel
kernels = _single(kernels) input_dimension_function = _ntuple(len(input_feature_map_shape))
strides = _single(strides) kernels = input_dimension_function(kernels)
pads = _single(pads) strides = input_dimension_function(strides)
dilations = _single(dilations) pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)): for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i] assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1) >= dilations[i] * (kernels[i] - 1) + 1)

View File

@ -8957,9 +8957,9 @@ class TestPad(TestCaseMPS):
# pad dims == input dims # pad dims == input dims
helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d) helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
# input.numel() == 0 but output.numel() > 0 # input.numel() == 0 but output.numel() > 0
helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d) helper((0, 3, 3), 1, nn.ConstantPad2d)
# pad dims < input dims - 2 # pad dims < input dims - 2
helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d) helper((1, 2, 3, 4, 5), (1, 2, 0, 0), nn.ConstantPad2d)
# 3D Padding # 3D Padding
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d) helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@ -8972,7 +8972,7 @@ class TestPad(TestCaseMPS):
# input size < pad size # input size < pad size
helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
# check the workaround for the right padding bug in Monterey # check the workaround for the right padding bug in Monterey
helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d) helper((1, 2, 2, 2, 2), (0, 1, 0, 1, 0, 1), nn.ConstantPad3d)
def test_constant_pad_nd_preserves_memory_format(self): def test_constant_pad_nd_preserves_memory_format(self):
nchw_tensor = torch.rand((1, 2, 5, 3)) nchw_tensor = torch.rand((1, 2, 5, 3))

View File

@ -7466,14 +7466,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
def test_fractional_max_pool2d_invalid_output_ratio(self): def test_fractional_max_pool2d_invalid_output_ratio(self):
arg_1 = [2, 1] arg_1 = [2, 1]
arg_2 = [0.5, 0.5, 0.6] arg_2 = [0.5, 0.5, 0.6]
arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,) with self.assertRaisesRegex(AssertionError, "Expected an iterable of length 2, but got length 3"):
arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32) arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
arg_3_0 = arg_3_0_tensor.clone()
arg_3 = [arg_3_0,]
with self.assertRaisesRegex(ValueError,
"fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
res = arg_class(*arg_3)
def test_max_pool1d_invalid_output_size(self): def test_max_pool1d_invalid_output_size(self):
arg_1 = 3 arg_1 = 3

View File

@ -768,7 +768,7 @@ class _ConvTransposeNd(_ConvNd):
dilation: Optional[list[int]] = None, dilation: Optional[list[int]] = None,
) -> list[int]: ) -> list[int]:
if output_size is None: if output_size is None:
ret = _single(self.output_padding) # converting to list if was not already ret = list(self.output_padding) # converting to list if was not already
else: else:
has_batch_dim = input.dim() == num_spatial_dims + 2 has_batch_dim = input.dim() == num_spatial_dims + 2
num_non_spatial_dims = 2 if has_batch_dim else 1 num_non_spatial_dims = 2 if has_batch_dim else 1

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import collections import collections.abc
from itertools import repeat from itertools import repeat
from typing import Any from typing import Any
@ -10,7 +10,18 @@ __all__ = ["consume_prefix_in_state_dict_if_present"]
def _ntuple(n, name="parse"): def _ntuple(n, name="parse"):
def parse(x): def parse(x):
if isinstance(x, collections.abc.Iterable): if isinstance(x, collections.abc.Iterable):
return tuple(x) ret = tuple(x)
# If the iterable is length 1, automatically expand to fill. This
# matches the behavior of expand_param_if_needed.
if len(ret) == 1:
return tuple(repeat(ret[0], n))
# Otherwise assert the correct length.
assert len(ret) == n, (
f"Expected an iterable of length {n}, but got length {len(ret)}"
)
return ret
return tuple(repeat(x, n)) return tuple(repeat(x, n))
parse.__name__ = name parse.__name__ = name