[nn] Assert parsed iterable arguments are an appropriate length (#162340)

Fixes #162327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340
Approved by: https://github.com/Skylion007
This commit is contained in:
Benjamin Glass
2025-09-08 20:55:19 +00:00
committed by PyTorch MergeBot
parent fefc406a3d
commit b5e6e58050
6 changed files with 31 additions and 24 deletions

View File

@ -15,7 +15,7 @@ import torch
from torch import _VF
import torch.jit
import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair
from torch.nn.modules.utils import _ntuple, _pair, _single
from hypothesis import settings, HealthCheck
from hypothesis import assume, given, note
@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel
kernels = _single(kernels)
strides = _single(strides)
pads = _single(pads)
dilations = _single(dilations)
input_dimension_function = _ntuple(len(input_feature_map_shape))
kernels = input_dimension_function(kernels)
strides = input_dimension_function(strides)
pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1)
@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel
kernels = _single(kernels)
strides = _single(strides)
pads = _single(pads)
dilations = _single(dilations)
input_dimension_function = _ntuple(len(input_feature_map_shape))
kernels = input_dimension_function(kernels)
strides = input_dimension_function(strides)
pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1)