mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[nn] Assert parsed iterable arguments are an appropriate length (#162340)
Fixes #162327 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
fefc406a3d
commit
b5e6e58050
@ -15,7 +15,7 @@ import torch
|
||||
from torch import _VF
|
||||
import torch.jit
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.utils import _single, _pair
|
||||
from torch.nn.modules.utils import _ntuple, _pair, _single
|
||||
|
||||
from hypothesis import settings, HealthCheck
|
||||
from hypothesis import assume, given, note
|
||||
@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
# Padded input size should be at least as big as dilated kernel
|
||||
kernels = _single(kernels)
|
||||
strides = _single(strides)
|
||||
pads = _single(pads)
|
||||
dilations = _single(dilations)
|
||||
input_dimension_function = _ntuple(len(input_feature_map_shape))
|
||||
kernels = input_dimension_function(kernels)
|
||||
strides = input_dimension_function(strides)
|
||||
pads = input_dimension_function(pads)
|
||||
dilations = input_dimension_function(dilations)
|
||||
for i in range(len(kernels)):
|
||||
assume(input_feature_map_shape[i] + 2 * pads[i]
|
||||
>= dilations[i] * (kernels[i] - 1) + 1)
|
||||
@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
# Padded input size should be at least as big as dilated kernel
|
||||
kernels = _single(kernels)
|
||||
strides = _single(strides)
|
||||
pads = _single(pads)
|
||||
dilations = _single(dilations)
|
||||
input_dimension_function = _ntuple(len(input_feature_map_shape))
|
||||
kernels = input_dimension_function(kernels)
|
||||
strides = input_dimension_function(strides)
|
||||
pads = input_dimension_function(pads)
|
||||
dilations = input_dimension_function(dilations)
|
||||
for i in range(len(kernels)):
|
||||
assume(input_feature_map_shape[i] + 2 * pads[i]
|
||||
>= dilations[i] * (kernels[i] - 1) + 1)
|
||||
|
Reference in New Issue
Block a user