Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Fix overflow in slow_conv3d when kernel size is too large. (#162718)
Also adds a check on padding to avoid a segmentation fault caused by overflow. Fixes #141846. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162718. Approved by: https://github.com/jgong5, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
bfd21cd3e6
commit
d4e4f70768
@ -9,6 +9,7 @@
|
||||
#include <ATen/native/TransposeType.h>
|
||||
#include <ATen/native/Unfold3d.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/safe_numerics.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -174,6 +175,23 @@ static inline void slow_conv3d_shape_check(
|
||||
const int64_t input_height = input.size(dim_height);
|
||||
const int64_t input_width = input.size(dim_width);
|
||||
|
||||
constexpr int64_t MAX_SAFE_PAD = (1LL << 61);
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
pad_height <= MAX_SAFE_PAD,
|
||||
"Padding height too large: pad_height=",
|
||||
pad_height);
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
pad_width <= MAX_SAFE_PAD,
|
||||
"Padding width too large: pad_width=",
|
||||
pad_width);
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
pad_depth <= MAX_SAFE_PAD,
|
||||
"Padding depth too large: pad_depth=",
|
||||
pad_depth);
|
||||
|
||||
const int64_t exact_input_depth = input_depth + 2 * pad_depth;
|
||||
const int64_t exact_input_height = input_height + 2 * pad_height;
|
||||
const int64_t exact_input_width = input_width + 2 * pad_width;
|
||||
@ -221,6 +239,14 @@ static inline void slow_conv3d_shape_check(
|
||||
output_width,
|
||||
"). Output size is too small");
|
||||
|
||||
uint64_t kernel_product;
|
||||
TORCH_CHECK(
|
||||
!c10::mul_overflows(kernel_height, kernel_width, &kernel_product),
|
||||
"Kernel height x width product is too large: kernel_height=",
|
||||
kernel_height,
|
||||
", kernel_width=",
|
||||
kernel_width);
|
||||
|
||||
if (weight.defined()) {
|
||||
int64_t n_input_plane = weight.size(1);
|
||||
if (weight.dim() == 2) {
|
||||
|
@ -229,6 +229,52 @@ class TestConvolutionNN(NNTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
||||
torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
|
||||
|
||||
def test_conv3d_overflow_values(self):
    # Regression test for slow_conv3d integer-overflow handling: huge
    # padding or kernel sizes must be rejected with an explicit error
    # instead of overflowing int64 arithmetic (which previously could
    # segfault). Uses a zero-batch input so no real convolution runs.
    inp = torch.zeros((0, 7, 9, 1, 5), dtype=torch.float32)
    wt = torch.full((9, 1), 4.14214e16, dtype=torch.float32)
    strides = [5, 5, 5]

    # Padding of 2**62 exceeds the kernel's MAX_SAFE_PAD bound; the
    # shape check should raise before any memory is touched.
    with self.assertRaisesRegex(ValueError, "Padding height too large"):
        torch.ops.aten.slow_conv3d(
            inp,
            wt,
            kernel_size=[5, 5, 5],
            bias=None,
            stride=strides,
            padding=[2**62, 2**62, 2**62],
        )

    # kernel_height * kernel_width == 2**64 overflows uint64; expect a
    # RuntimeError from the overflow-checked multiply, not wraparound.
    with self.assertRaisesRegex(
        RuntimeError, "Kernel height x width product is too large:"
    ):
        torch.ops.aten.slow_conv3d(
            inp,
            wt,
            kernel_size=[2**32, 2**32, 2**32],
            bias=None,
            stride=strides,
            padding=[2**31, 2**31, 2**31],
        )
|
||||
|
||||
def test_Conv1d_module_same_padding(self):
|
||||
# Compare module against functional: without strides/dilation, asymmetric padding
|
||||
x = torch.rand(1, 1, 20)
|
||||
|
Reference in New Issue
Block a user