add grad_output shape check for adaptive_avg_pool2d_backward (#145241)

Fixes https://github.com/pytorch/pytorch/issues/145070.
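A minimal sketch of the shape mismatch the new check rejects, mirroring the test added in this PR (it calls the internal aten overload directly):

    import torch

    # grad_output is 3D while input is 4D; with this change the backward op
    # raises a clear RuntimeError up front instead of proceeding with
    # inconsistent shapes.
    grad_output = torch.randn(1, 2, 7)
    input = torch.randn(1, 2, 3, 3)
    try:
        torch.ops.aten._adaptive_avg_pool2d_backward(grad_output, input)
    except RuntimeError as e:
        print(e)  # "... Expected dimensions 4 for `grad_output` but got dimensions 3"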

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145241
Approved by: https://github.com/malfet, https://github.com/eqy
Author:    Sun, Jiayi
Date:      2025-03-19 19:15:56 -07:00
Committer: PyTorch MergeBot
Parent:    00a2c68f67
Commit:    496bbf38be

5 changed files with 26 additions and 12 deletions

@@ -63,20 +63,16 @@ namespace {
     const Tensor& grad_output,
     const Tensor& input)
 {
-  int64_t ndim = grad_output.ndimension();
-  for (const auto i : c10::irange(1, ndim)) {
-    TORCH_CHECK(grad_output.size(i) > 0,
-      "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
-      "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being "
-      "empty");
-  }
+  adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool2d_backward");
+  int64_t ndim = grad_output.dim();
+  TORCH_CHECK(input.dim() == ndim,
+    __func__, ": Expected dimensions ", input.dim(), " for `grad_output` but got dimensions ", ndim);
   TORCH_CHECK((ndim == 3 || ndim == 4),
-    "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ", input.sizes());
+    __func__, ": Expected 3D or 4D tensor, but got ", input.sizes());
   TORCH_CHECK(input.dtype() == grad_output.dtype(),
-    "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype());
+    __func__, ": Expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype());
   TORCH_CHECK(input.dtype() == grad_input.dtype(),
-    "expected dtype ", input.dtype(), " for `grad_input` but got dtype ", grad_input.dtype());
+    __func__, ": Expected dtype ", input.dtype(), " for `grad_input` but got dtype ", grad_input.dtype());
   grad_input.resize_(input.sizes(), input.suggest_memory_format());
   grad_input.zero_();

@@ -235,6 +235,8 @@ Tensor& adaptive_avg_pool3d_backward_out_cpu_template(
   auto gradOutput = gradOutput_.contiguous();
   adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");
+  TORCH_CHECK(input.dim() == gradOutput_.dim(),
+    __func__, ": Expected dimensions ", input.dim(), " for `gradOutput_` but got dimensions ", gradOutput_.dim());
   /* sizes */
   int64_t sizeD = input.size(-4);

@@ -608,6 +608,8 @@ namespace {
       input_arg{ input, "input", 3 };
   adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool2d_backward");
+  TORCH_CHECK(input.dim() == gradOutput_.dim(),
+    __func__, ": Expected dimensions ", input.dim(), " for `gradOutput_` but got dimensions ", gradOutput_.dim());
   checkAllSameGPU(__func__, {grad_input_arg, grad_output_arg, input_arg});

@@ -428,6 +428,8 @@ void adaptive_avg_pool3d_backward_out_cuda_template(
   TensorArg input_arg{input, "input", 3};
   adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");
+  TORCH_CHECK(input.dim() == gradOutput_.dim(),
+    __func__, ": Expected dimensions ", input.dim(), " for `gradOutput_` but got dimensions ", gradOutput_.dim());
   checkAllSameGPU(
     "adaptive_avg_pool3d_out_cuda",

@@ -557,7 +557,19 @@ class TestPoolingNNDeviceType(NNTestCase):
         fn(input2, output_size).sum().backward()
 
     @onlyNativeDeviceTypes
-    def test_adaptive_pooling_backward_fails(self, device):
+    def test_adaptive_avg_pooling_backward_fails(self, device):
+        grad_output = torch.randn(1, 2, 7, device=device)
+        input = torch.randn(1, 2, 3, 3, device=device)
+        with self.assertRaisesRegex(RuntimeError, "Expected dimensions"):
+            torch.ops.aten._adaptive_avg_pool2d_backward(grad_output, input)
+
+        grad_output = torch.randn(1, 2, 7, 7, device=device)
+        input = torch.randn(1, 2, 3, 3, 3, device=device)
+        with self.assertRaisesRegex(RuntimeError, "Expected dimensions"):
+            torch.ops.aten._adaptive_avg_pool3d_backward(grad_output, input)
+
+    @onlyNativeDeviceTypes
+    def test_adaptive_max_pooling_backward_fails(self, device):
         grad_output = torch.randn(1, 2, 7, 7, device=device)
         input = torch.randn(1, 2, 7, 7, device=device)
         indices = torch.ones(1, 2, 3, 3, dtype=torch.long, device=device)