[MPS] Fix nan behavior in grid_sampler_3d (#163881)

Fixes #163851
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163881
Approved by: https://github.com/malfet
This commit is contained in:
Kurt Mohler
2025-09-25 12:37:08 -05:00
committed by PyTorch MergeBot
parent 62b0ebd8f9
commit a0136f149c
2 changed files with 7 additions and 3 deletions

View File

@ -223,9 +223,6 @@ void grid_sampler_single_element(
auto input_size = input_sizes[input_dim];
auto coord = static_cast<opmath_t<T>>(coords[coord_dim]);
// Interpret nan as -1
coord = isnan(coord) ? -1 : coord;
if (!align_corners) {
// Map unaligned grid space to aligned grid space
auto corner_alignment_factor = static_cast<opmath_t<T>>(input_size) /

View File

@ -12546,6 +12546,13 @@ class TestConsistency(TestCaseMPS):
self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol)
def test_grid_sampler_3d_nan(self, device):
input = torch.ones(1, 1, 3, 3, 3)
grid_nan = torch.tensor([[[[[torch.nan, 1., 1.], [1., 1., 1.]]]]])
out_cpu = torch.grid_sampler_3d(input, grid_nan, 0, 0, True)
out_mps = torch.grid_sampler_3d(input.to(device), grid_nan.to(device), 0, 0, True)
self.assertEqual(out_mps, out_cpu)
def test_fmax_mixed_dtypes(self, device):
# Regression tesing for https://github.com/pytorch/pytorch/issues/149951
# fmax and fmin are implemented as binary metal shaders and they were implemented