Fix invalid indices bug for max_unpool2d/3d on MPS (#163036)

Fixes #163035
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163036
Approved by: https://github.com/kulinseth, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
can-gaa-hou
2025-09-19 05:13:17 +00:00
committed by PyTorch MergeBot
parent c91f59b1a0
commit ce5637be29
2 changed files with 39 additions and 0 deletions

View File

@ -534,6 +534,18 @@ static void max_unpool_out_mps_template(const Tensor& input,
output.resize_(output_size, memory_format);
output.fill_(0);
if (indices.defined() && indices.numel() > 0) {
auto output_image_size = c10::multiply_integers(output_size_);
int64_t min_idx = indices.min().item<int64_t>();
int64_t max_idx = indices.max().item<int64_t>();
if (min_idx < 0 || max_idx >= output_image_size) {
int64_t error_idx = (min_idx < 0) ? min_idx : max_idx;
TORCH_CHECK(false, "Found an invalid max index: ", error_idx, " for output tensor of shape ", output_size_);
}
}
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = input.numel();

View File

@ -1411,6 +1411,33 @@ torch.cuda.synchronize()
indices,
)
def test_max_unpool_invalid_indices(self):
input = torch.randn(1, 1, 2, 2)
negative_indices = torch.tensor([[[[-1, 0], [0, 2]]]], dtype=torch.int64)
large_indices = torch.tensor([[[[10000, 10], [0, 2]]]], dtype=torch.int64)
output_size = (2, 2)
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
F.max_unpool2d(input, negative_indices, output_size)
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
F.max_unpool2d(input, large_indices, output_size)
input = torch.randn(1, 1, 2, 2, 2)
negative_indices = torch.tensor(
[[[[[-1, 10], [0, 2]], [[1, 3], [4, 5]]]]], dtype=torch.int64
)
large_indices = torch.tensor(
[[[[[10000, 10], [0, 2]], [[1, 3], [4, 5]]]]], dtype=torch.int64
)
output_size = (2, 2, 2)
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
F.max_unpool3d(input, negative_indices, output_size)
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
F.max_unpool3d(input, large_indices, output_size)
@onlyCPU
@dtypes(torch.half, torch.bfloat16)
def test_avg_pool2d_reduced_floating(self, device, dtype):