mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BUG] MaxUnpool2d/3d should check output dim before accessing its elements (#163507)
Fixes #163409 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163507 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
da05aa7a9d
commit
0256f91558
@ -23,8 +23,6 @@ Tensor& max_unpooling2d_forward_out_cpu(
|
||||
// Nondeterministic with duplicate indices
|
||||
at::globalContext().alertNotDeterministic("max_unpooling2d_forward_out");
|
||||
|
||||
auto oheight = output_size[0];
|
||||
auto owidth = output_size[1];
|
||||
TORCH_CHECK(
|
||||
indices_.scalar_type() == at::ScalarType::Long,
|
||||
"elements in indices should be type int64 but got: ", indices_.scalar_type());
|
||||
@ -45,6 +43,9 @@ Tensor& max_unpooling2d_forward_out_cpu(
|
||||
self_.sizes(), " with dimension ", i , " being empty.");
|
||||
}
|
||||
|
||||
auto oheight = output_size[0];
|
||||
auto owidth = output_size[1];
|
||||
|
||||
auto memory_format = self_.suggest_memory_format();
|
||||
auto self = self_.contiguous(memory_format);
|
||||
auto indices = indices_.contiguous(memory_format);
|
||||
|
@ -125,8 +125,6 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
|
||||
TORCH_CHECK(
|
||||
indices_.scalar_type() == at::ScalarType::Long,
|
||||
"elements in indices should be type int64 but got: ", indices_.scalar_type());
|
||||
auto oheight = output_size[0];
|
||||
auto owidth = output_size[1];
|
||||
|
||||
TensorArg output_arg{output, "output", 1}, self_arg{self_, "self_", 2},
|
||||
indices_arg{indices_, "indices_", 3};
|
||||
@ -149,6 +147,9 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
|
||||
output_size.size() == 2,
|
||||
"There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");
|
||||
|
||||
auto oheight = output_size[0];
|
||||
auto owidth = output_size[1];
|
||||
|
||||
int64_t dimw = 2;
|
||||
int64_t dimh = 1;
|
||||
int64_t numBatch = 1;
|
||||
@ -217,9 +218,6 @@ static void max_unpooling3d_shape_check(
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
const char *fn_name) {
|
||||
int64_t oT = output_size[0];
|
||||
int64_t oH = output_size[1];
|
||||
int64_t oW = output_size[2];
|
||||
TORCH_CHECK(
|
||||
indices.scalar_type() == at::ScalarType::Long,
|
||||
"elements in indices should be type int64 but got: ", indices.scalar_type());
|
||||
@ -250,6 +248,10 @@ static void max_unpooling3d_shape_check(
|
||||
"strides should be greater than zero, but got stride: ",
|
||||
stride);
|
||||
|
||||
int64_t oT = output_size[0];
|
||||
int64_t oH = output_size[1];
|
||||
int64_t oW = output_size[2];
|
||||
|
||||
int dimw = 3;
|
||||
int dimh = 2;
|
||||
int dimt = 1;
|
||||
@ -402,8 +404,6 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
|
||||
const Tensor& indices_,
|
||||
IntArrayRef output_size,
|
||||
Tensor& grad_input) {
|
||||
int64_t oheight = output_size[0];
|
||||
int64_t owidth = output_size[1];
|
||||
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
|
||||
TORCH_CHECK(
|
||||
indices_.scalar_type() == at::ScalarType::Long,
|
||||
@ -426,6 +426,9 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
|
||||
|
||||
TORCH_CHECK(output_size.size() == 2, "output_size must have two elements, got size: ", output_size.size());
|
||||
|
||||
int64_t oheight = output_size[0];
|
||||
int64_t owidth = output_size[1];
|
||||
|
||||
int64_t nInputCols, nInputRows, nInputPlane;
|
||||
|
||||
int dimw = 2;
|
||||
@ -505,13 +508,14 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
|
||||
int64_t oT = output_size[0];
|
||||
int64_t oH = output_size[1];
|
||||
int64_t oW = output_size[2];
|
||||
|
||||
max_unpooling3d_shape_check(
|
||||
self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");
|
||||
|
||||
int64_t oT = output_size[0];
|
||||
int64_t oH = output_size[1];
|
||||
int64_t oW = output_size[2];
|
||||
|
||||
int batchSize = 0;
|
||||
int inputSlices = 0;
|
||||
int inputTime = 0;
|
||||
|
@ -519,6 +519,13 @@ static void max_unpool_out_mps_template(const Tensor& input,
|
||||
Tensor& output,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
TORCH_CHECK(output_size_.size() == static_cast<size_t>(pooling_dims),
|
||||
op_name,
|
||||
"There should be exactly ",
|
||||
pooling_dims,
|
||||
" elements but got ",
|
||||
output_size_.size());
|
||||
|
||||
auto dims = input.dim();
|
||||
auto leading_dims = input.dim() - pooling_dims;
|
||||
|
||||
|
@ -857,6 +857,20 @@ torch.cuda.synchronize()
|
||||
else:
|
||||
unpool(output, indices)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/163409
|
||||
@onlyNativeDeviceTypes
|
||||
def test_MaxUnpool_invalid_output_size(self, device):
|
||||
input2d = torch.randn(1, 1, 1)
|
||||
input3d = torch.randn(1, 1, 1, 1, 1)
|
||||
unpool2d = torch.nn.MaxUnpool2d(())
|
||||
unpool3d = torch.nn.MaxUnpool3d(())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "There should be exactly"):
|
||||
unpool2d(input2d, torch.zeros_like(input2d, dtype=torch.int64))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "There should be exactly"):
|
||||
unpool3d(input3d, torch.zeros_like(input3d, dtype=torch.int64))
|
||||
|
||||
@expectedFailureMPS
|
||||
@onlyNativeDeviceTypes
|
||||
def test_AdaptiveMaxPool_zero_batch_dim(self, device):
|
||||
|
Reference in New Issue
Block a user