mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add grad_output shape check for fractional_max_pool2d_backward (#141666)
Fix https://github.com/pytorch/pytorch/issues/141102. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141666 Approved by: https://github.com/mingfeima, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
2def1f6f74
commit
d2b83aa122
@ -109,10 +109,13 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
|
||||
/* get contiguous gradOutput */
|
||||
auto gradOutput = gradOutput_.contiguous();
|
||||
|
||||
TORCH_CHECK(outputW == gradOutput.size(widthDim),
|
||||
"fractional_max_pool2d_backward(): gradOutput width unexpected");
|
||||
TORCH_CHECK(outputH == gradOutput.size(heightDim),
|
||||
"fractional_max_pool2d_backward(): gradOutput height unexpected");
|
||||
auto expectedOutputShape = IntArrayRef(input.sizes().data(), ndims - 2).vec();
|
||||
expectedOutputShape.push_back(outputH);
|
||||
expectedOutputShape.push_back(outputW);
|
||||
TORCH_CHECK(gradOutput.sizes().equals(expectedOutputShape),
|
||||
"fractional_max_pool2d_backward(): gradOutput sizes unexpected");
|
||||
TORCH_CHECK(indices.sizes().equals(expectedOutputShape),
|
||||
"fractional_max_pool2d_backward(): indices sizes unexpected");
|
||||
|
||||
/* resize */
|
||||
if (ndims == 3) {
|
||||
|
@ -1783,6 +1783,19 @@ torch.cuda.synchronize()
|
||||
x, (2, 2), output_size=output_size, _random_samples=samples
|
||||
)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
def test_fractional_max_pool2d_backward_fails(self, device):
|
||||
grad_output = torch.randn(1, 1, 2, 3, 3, device=device)
|
||||
input = torch.randn(1, 2, 7, 7, device=device)
|
||||
kernel_size = (2, 2)
|
||||
output_size = (3, 3)
|
||||
indices = torch.ones(1, 2, 3, 3, dtype=torch.long, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "gradOutput sizes unexpected"):
|
||||
torch.ops.aten.fractional_max_pool2d_backward(
|
||||
grad_output, input, kernel_size, output_size, indices
|
||||
)
|
||||
|
||||
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_fractional_max_pool3d(self, device):
|
||||
|
Reference in New Issue
Block a user