From b5e6e58050bd2a15f4173cfffa00c7e32e382b49 Mon Sep 17 00:00:00 2001
From: Benjamin Glass
Date: Mon, 8 Sep 2025 20:55:19 +0000
Subject: [PATCH] [nn] Assert parsed iterable arguments are an appropriate length (#162340)

Fixes #162327

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340
Approved by: https://github.com/Skylion007
---
 test/nn/test_pooling.py                     |  2 +-
 test/quantization/core/test_quantized_op.py | 20 +++++++++++---------
 test/test_mps.py                            |  6 +++---
 test/test_nn.py                             | 10 ++--------
 torch/nn/modules/conv.py                    |  2 +-
 torch/nn/modules/utils.py                   | 15 +++++++++++++--
 6 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index a8f77df22d31..2e85f2da2268 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -481,7 +481,7 @@ class TestPoolingNN(NNTestCase):
 
     def test_max_unpool3d_input_check(self):
         x = torch.ones(1, 3, 1, 1, 1)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(AssertionError):
             F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])
 
     def test_quantized_max_pool1d_empty_kernel(self):
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index b6df2089e87e..6b362bef365e 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -15,7 +15,7 @@ import torch
 from torch import _VF
 import torch.jit
 import torch.nn.functional as F
-from torch.nn.modules.utils import _single, _pair
+from torch.nn.modules.utils import _ntuple, _pair, _single
 
 from hypothesis import settings, HealthCheck
 from hypothesis import assume, given, note
@@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         # Padded input size should be at least as big as dilated kernel
-        kernels = _single(kernels)
-        strides = _single(strides)
-        pads = _single(pads)
-        dilations = _single(dilations)
+        input_dimension_function = _ntuple(len(input_feature_map_shape))
+        kernels = input_dimension_function(kernels)
+        strides = input_dimension_function(strides)
+        pads = input_dimension_function(pads)
+        dilations = input_dimension_function(dilations)
         for i in range(len(kernels)):
             assume(input_feature_map_shape[i] + 2 * pads[i]
                    >= dilations[i] * (kernels[i] - 1) + 1)
@@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         # Padded input size should be at least as big as dilated kernel
-        kernels = _single(kernels)
-        strides = _single(strides)
-        pads = _single(pads)
-        dilations = _single(dilations)
+        input_dimension_function = _ntuple(len(input_feature_map_shape))
+        kernels = input_dimension_function(kernels)
+        strides = input_dimension_function(strides)
+        pads = input_dimension_function(pads)
+        dilations = input_dimension_function(dilations)
         for i in range(len(kernels)):
             assume(input_feature_map_shape[i] + 2 * pads[i]
                    >= dilations[i] * (kernels[i] - 1) + 1)
diff --git a/test/test_mps.py b/test/test_mps.py
index 756b2cd20567..c172c8c119b2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8957,9 +8957,9 @@ class TestPad(TestCaseMPS):
         # pad dims == input dims
         helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
         # input.numel() == 0 but output.numel() > 0
-        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
+        helper((0, 3, 3), 1, nn.ConstantPad2d)
         # pad dims < input dims - 2
-        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
+        helper((1, 2, 3, 4, 5), (1, 2, 0, 0), nn.ConstantPad2d)
 
         # 3D Padding
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@@ -8972,7 +8972,7 @@ class TestPad(TestCaseMPS):
         # input size < pad size
         helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
         # check the workaround for the right padding bug in Monterey
-        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
+        helper((1, 2, 2, 2, 2), (0, 1, 0, 1, 0, 1), nn.ConstantPad3d)
 
     def test_constant_pad_nd_preserves_memory_format(self):
         nchw_tensor = torch.rand((1, 2, 5, 3))
diff --git a/test/test_nn.py b/test/test_nn.py
index c17f7cb668b6..13ee5c2e2a42 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -7466,14 +7466,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
     def test_fractional_max_pool2d_invalid_output_ratio(self):
         arg_1 = [2, 1]
         arg_2 = [0.5, 0.5, 0.6]
-        arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
-        arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32)
-        arg_3_0 = arg_3_0_tensor.clone()
-        arg_3 = [arg_3_0,]
-
-        with self.assertRaisesRegex(ValueError,
-                                    "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
-            res = arg_class(*arg_3)
+        with self.assertRaisesRegex(AssertionError, "Expected an iterable of length 2, but got length 3"):
+            arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
 
     def test_max_pool1d_invalid_output_size(self):
         arg_1 = 3
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py
index 2f15c3d488f7..ffb6f21e6714 100644
--- a/torch/nn/modules/conv.py
+++ b/torch/nn/modules/conv.py
@@ -768,7 +768,7 @@ class _ConvTransposeNd(_ConvNd):
         dilation: Optional[list[int]] = None,
     ) -> list[int]:
         if output_size is None:
-            ret = _single(self.output_padding)  # converting to list if was not already
+            ret = list(self.output_padding)  # converting to list if was not already
         else:
             has_batch_dim = input.dim() == num_spatial_dims + 2
             num_non_spatial_dims = 2 if has_batch_dim else 1
diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py
index 220b8f206b19..492556dab01e 100644
--- a/torch/nn/modules/utils.py
+++ b/torch/nn/modules/utils.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-import collections
+import collections.abc
 from itertools import repeat
 from typing import Any
 
@@ -10,7 +10,18 @@ __all__ = ["consume_prefix_in_state_dict_if_present"]
 def _ntuple(n, name="parse"):
     def parse(x):
         if isinstance(x, collections.abc.Iterable):
-            return tuple(x)
+            ret = tuple(x)
+
+            # If the iterable is length 1, automatically expand to fill. This
+            # matches the behavior of expand_param_if_needed.
+            if len(ret) == 1:
+                return tuple(repeat(ret[0], n))
+
+            # Otherwise assert the correct length.
+            assert len(ret) == n, (
+                f"Expected an iterable of length {n}, but got length {len(ret)}"
+            )
+            return ret
         return tuple(repeat(x, n))
 
     parse.__name__ = name
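
Note (reviewer aid, not part of the patch): a minimal sketch of how the tightened `_ntuple` parsing behaves once this change is applied, assuming the patched torch/nn/modules/utils.py is importable. `_pair` and `_triple` are the existing length-2 and length-3 wrappers built on `_ntuple`; the assertion message is the one added above.

    # Sketch only: exercises the new length handling added to _ntuple in this patch.
    from torch.nn.modules.utils import _pair, _triple

    print(_pair(3))        # scalar -> repeated: (3, 3)
    print(_pair([4]))      # length-1 iterable -> expanded: (4, 4)
    print(_pair((1, 2)))   # correct length -> passed through: (1, 2)

    try:
        _triple((1, 2))    # wrong length now fails fast instead of passing through
    except AssertionError as err:
        print(err)         # Expected an iterable of length 3, but got length 2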