mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix invalid indices bug for max_unpool2d/3d on MPS (#163036)
Fixes #163035 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163036 Approved by: https://github.com/kulinseth, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
c91f59b1a0
commit
ce5637be29
@ -534,6 +534,18 @@ static void max_unpool_out_mps_template(const Tensor& input,
|
||||
output.resize_(output_size, memory_format);
|
||||
output.fill_(0);
|
||||
|
||||
if (indices.defined() && indices.numel() > 0) {
|
||||
auto output_image_size = c10::multiply_integers(output_size_);
|
||||
|
||||
int64_t min_idx = indices.min().item<int64_t>();
|
||||
int64_t max_idx = indices.max().item<int64_t>();
|
||||
|
||||
if (min_idx < 0 || max_idx >= output_image_size) {
|
||||
int64_t error_idx = (min_idx < 0) ? min_idx : max_idx;
|
||||
TORCH_CHECK(false, "Found an invalid max index: ", error_idx, " for output tensor of shape ", output_size_);
|
||||
}
|
||||
}
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
const auto numThreads = input.numel();
|
||||
|
@ -1411,6 +1411,33 @@ torch.cuda.synchronize()
|
||||
indices,
|
||||
)
|
||||
|
||||
def test_max_unpool_invalid_indices(self):
|
||||
input = torch.randn(1, 1, 2, 2)
|
||||
negative_indices = torch.tensor([[[[-1, 0], [0, 2]]]], dtype=torch.int64)
|
||||
large_indices = torch.tensor([[[[10000, 10], [0, 2]]]], dtype=torch.int64)
|
||||
output_size = (2, 2)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
|
||||
F.max_unpool2d(input, negative_indices, output_size)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
|
||||
F.max_unpool2d(input, large_indices, output_size)
|
||||
|
||||
input = torch.randn(1, 1, 2, 2, 2)
|
||||
negative_indices = torch.tensor(
|
||||
[[[[[-1, 10], [0, 2]], [[1, 3], [4, 5]]]]], dtype=torch.int64
|
||||
)
|
||||
large_indices = torch.tensor(
|
||||
[[[[[10000, 10], [0, 2]], [[1, 3], [4, 5]]]]], dtype=torch.int64
|
||||
)
|
||||
output_size = (2, 2, 2)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
|
||||
F.max_unpool3d(input, negative_indices, output_size)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Found an invalid max index"):
|
||||
F.max_unpool3d(input, large_indices, output_size)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.half, torch.bfloat16)
|
||||
def test_avg_pool2d_reduced_floating(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user