Add a grad_output shape check for fractional_max_pool2d_backward (#141666)

Fixes https://github.com/pytorch/pytorch/issues/141102.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141666
Approved by: https://github.com/mingfeima, https://github.com/malfet
This commit is contained in:
Sun, Jiayi
2024-12-19 09:43:23 +00:00
committed by PyTorch MergeBot
parent 2def1f6f74
commit d2b83aa122
2 changed files with 20 additions and 4 deletions

View File

@ -109,10 +109,13 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
TORCH_CHECK(outputW == gradOutput.size(widthDim),
"fractional_max_pool2d_backward(): gradOutput width unexpected");
TORCH_CHECK(outputH == gradOutput.size(heightDim),
"fractional_max_pool2d_backward(): gradOutput height unexpected");
auto expectedOutputShape = IntArrayRef(input.sizes().data(), ndims - 2).vec();
expectedOutputShape.push_back(outputH);
expectedOutputShape.push_back(outputW);
TORCH_CHECK(gradOutput.sizes().equals(expectedOutputShape),
"fractional_max_pool2d_backward(): gradOutput sizes unexpected");
TORCH_CHECK(indices.sizes().equals(expectedOutputShape),
"fractional_max_pool2d_backward(): indices sizes unexpected");
/* resize */
if (ndims == 3) {

View File

@ -1783,6 +1783,19 @@ torch.cuda.synchronize()
x, (2, 2), output_size=output_size, _random_samples=samples
)
@onlyNativeDeviceTypes
def test_fractional_max_pool2d_backward_fails(self, device):
    """Regression test: fractional_max_pool2d_backward must reject a
    grad_output whose sizes do not match the expected pooled-output shape
    instead of crashing (see gh-141102).
    """
    # For a (1, 2, 7, 7) input pooled to output_size (3, 3), the expected
    # grad_output shape is (1, 2, 3, 3); this tensor has an extra leading dim.
    bad_grad = torch.randn(1, 1, 2, 3, 3, device=device)
    pool_input = torch.randn(1, 2, 7, 7, device=device)
    idx = torch.ones(1, 2, 3, 3, dtype=torch.long, device=device)
    with self.assertRaisesRegex(RuntimeError, "gradOutput sizes unexpected"):
        torch.ops.aten.fractional_max_pool2d_backward(
            bad_grad, pool_input, (2, 2), (3, 3), idx
        )
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
@onlyNativeDeviceTypes
def test_fractional_max_pool3d(self, device):