Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[nn] Assert parsed iterable arguments are an appropriate length (#162340)

Fixes #162327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340
Approved by: https://github.com/Skylion007

Committed by: PyTorch MergeBot
Parent: fefc406a3d
Commit: b5e6e58050
@@ -481,7 +481,7 @@ class TestPoolingNN(NNTestCase):
 
     def test_max_unpool3d_input_check(self):
         x = torch.ones(1, 3, 1, 1, 1)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(AssertionError):
             F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])
 
     def test_quantized_max_pool1d_empty_kernel(self):
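A minimal sketch of what the updated test expects, assuming F.max_unpool3d still normalizes its kernel_size argument through the _triple helper: a 2-element list is now rejected while the argument is parsed, before the operator's own shape checks can raise a RuntimeError.

import torch
import torch.nn.functional as F

x = torch.ones(1, 3, 1, 1, 1)
indices = torch.zeros(x.shape, dtype=torch.int64)

try:
    F.max_unpool3d(x, indices, [1, 1])  # kernel_size has 2 entries, but 3 are required
except AssertionError as e:
    print(e)  # Expected an iterable of length 3, but got length 2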
@@ -15,7 +15,7 @@ import torch
 from torch import _VF
 import torch.jit
 import torch.nn.functional as F
-from torch.nn.modules.utils import _single, _pair
+from torch.nn.modules.utils import _ntuple, _pair, _single
 
 from hypothesis import settings, HealthCheck
 from hypothesis import assume, given, note
@@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         # Padded input size should be at least as big as dilated kernel
-        kernels = _single(kernels)
-        strides = _single(strides)
-        pads = _single(pads)
-        dilations = _single(dilations)
+        input_dimension_function = _ntuple(len(input_feature_map_shape))
+        kernels = input_dimension_function(kernels)
+        strides = input_dimension_function(strides)
+        pads = input_dimension_function(pads)
+        dilations = input_dimension_function(dilations)
         for i in range(len(kernels)):
             assume(input_feature_map_shape[i] + 2 * pads[i]
                    >= dilations[i] * (kernels[i] - 1) + 1)
@@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         # Padded input size should be at least as big as dilated kernel
-        kernels = _single(kernels)
-        strides = _single(strides)
-        pads = _single(pads)
-        dilations = _single(dilations)
+        input_dimension_function = _ntuple(len(input_feature_map_shape))
+        kernels = input_dimension_function(kernels)
+        strides = input_dimension_function(strides)
+        pads = input_dimension_function(pads)
+        dilations = input_dimension_function(dilations)
         for i in range(len(kernels)):
             assume(input_feature_map_shape[i] + 2 * pads[i]
                    >= dilations[i] * (kernels[i] - 1) + 1)
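The replacement of _single with _ntuple(len(input_feature_map_shape)) in both hunks above builds a parser sized to the number of spatial dimensions under test, so the hypothesis-generated parameters are expanded (or length-checked) consistently. A small sketch with hypothetical values:

from torch.nn.modules.utils import _ntuple

input_feature_map_shape = (7, 7)  # hypothetical 2D feature map
input_dimension_function = _ntuple(len(input_feature_map_shape))

print(input_dimension_function(3))       # (3, 3)  scalar broadcast to every spatial dim
print(input_dimension_function((2, 4)))  # (2, 4)  exact length passes through
input_dimension_function((2, 4, 8))      # AssertionError: Expected an iterable of length 2, but got length 3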
@@ -8957,9 +8957,9 @@ class TestPad(TestCaseMPS):
         # pad dims == input dims
         helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
         # input.numel() == 0 but output.numel() > 0
-        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
+        helper((0, 3, 3), 1, nn.ConstantPad2d)
         # pad dims < input dims - 2
-        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
+        helper((1, 2, 3, 4, 5), (1, 2, 0, 0), nn.ConstantPad2d)
 
         # 3D Padding
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@@ -8972,7 +8972,7 @@ class TestPad(TestCaseMPS):
         # input size < pad size
         helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
         # check the workaround for the right padding bug in Monterey
-        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
+        helper((1, 2, 2, 2, 2), (0, 1, 0, 1, 0, 1), nn.ConstantPad3d)
 
     def test_constant_pad_nd_preserves_memory_format(self):
         nchw_tensor = torch.rand((1, 2, 5, 3))
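The new pad arguments in the two TestPad hunks above follow from the same length check, assuming nn.ConstantPad2d and nn.ConstantPad3d normalize their padding with _ntuple(4) and _ntuple(6) respectively: a wrong-length pad tuple now fails when the module is constructed instead of being passed through silently.

import torch.nn as nn

nn.ConstantPad2d(1, 0.0)                   # fine: scalar is broadcast to (1, 1, 1, 1)
nn.ConstantPad3d((0, 1, 0, 1, 0, 1), 0.0)  # fine: exactly six pad values
nn.ConstantPad2d((1, 2), 0.0)              # AssertionError: Expected an iterable of length 4, but got length 2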
@@ -7466,14 +7466,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
     def test_fractional_max_pool2d_invalid_output_ratio(self):
         arg_1 = [2, 1]
         arg_2 = [0.5, 0.5, 0.6]
-        arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
-        arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32)
-        arg_3_0 = arg_3_0_tensor.clone()
-        arg_3 = [arg_3_0,]
-
-        with self.assertRaisesRegex(ValueError,
-                                    "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
-            res = arg_class(*arg_3)
+        with self.assertRaisesRegex(AssertionError, "Expected an iterable of length 2, but got length 3"):
+            arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
 
     def test_max_pool1d_invalid_output_size(self):
         arg_1 = 3
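A condensed sketch of the rewritten test's expectation, assuming FractionalMaxPool2d passes output_ratio through _pair in its constructor: the length mismatch is now caught as an AssertionError when the module is built, so no input tensor or forward call is needed to trigger it.

import torch

try:
    torch.nn.FractionalMaxPool2d(kernel_size=[2, 1], output_ratio=[0.5, 0.5, 0.6])
except AssertionError as e:
    print(e)  # Expected an iterable of length 2, but got length 3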
@@ -768,7 +768,7 @@ class _ConvTransposeNd(_ConvNd):
         dilation: Optional[list[int]] = None,
     ) -> list[int]:
         if output_size is None:
-            ret = _single(self.output_padding)  # converting to list if was not already
+            ret = list(self.output_padding)  # converting to list if was not already
         else:
             has_batch_dim = input.dim() == num_spatial_dims + 2
             num_non_spatial_dims = 2 if has_batch_dim else 1
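The switch from _single to a plain list() call above sidesteps the new check: _single now asserts whenever output_padding holds more than one element, as it does for the 2d and 3d transposed convolutions. A sketch with hypothetical values:

from torch.nn.modules.utils import _single

output_padding = (1, 1)  # hypothetical value for a transposed 2d convolution

print(list(output_padding))  # [1, 1] -- the plain conversion the new code uses
_single(output_padding)      # AssertionError: Expected an iterable of length 1, but got length 2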
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-import collections
+import collections.abc
 from itertools import repeat
 from typing import Any
 
@@ -10,7 +10,18 @@ __all__ = ["consume_prefix_in_state_dict_if_present"]
 def _ntuple(n, name="parse"):
     def parse(x):
         if isinstance(x, collections.abc.Iterable):
-            return tuple(x)
+            ret = tuple(x)
+
+            # If the iterable is length 1, automatically expand to fill. This
+            # matches the behavior of expand_param_if_needed.
+            if len(ret) == 1:
+                return tuple(repeat(ret[0], n))
+
+            # Otherwise assert the correct length.
+            assert len(ret) == n, (
+                f"Expected an iterable of length {n}, but got length {len(ret)}"
+            )
+            return ret
         return tuple(repeat(x, n))
 
     parse.__name__ = name
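Taken together, the helpers built from _ntuple (_single, _pair, _triple, ...) now distinguish three cases instead of blindly wrapping any iterable in tuple(); a quick sketch of the resulting behavior:

from torch.nn.modules.utils import _pair, _triple

print(_pair(3))            # (3, 3)     scalars still broadcast
print(_pair([4]))          # (4, 4)     length-1 iterables expand to fill
print(_triple((1, 2, 3)))  # (1, 2, 3)  exact-length iterables pass through
_triple((1, 2))            # AssertionError: Expected an iterable of length 3, but got length 2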