From a0136f149c5fd0c1d968abff85fd133cb8a7fbc1 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 25 Sep 2025 12:37:08 -0500 Subject: [PATCH] [MPS] Fix nan behavior in `grid_sampler_3d` (#163881) Fixes #163851 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163881 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/kernels/GridSampler.metal | 3 --- test/test_mps.py | 7 +++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/GridSampler.metal b/aten/src/ATen/native/mps/kernels/GridSampler.metal index 331793e08d66..84bfbb57f8f0 100644 --- a/aten/src/ATen/native/mps/kernels/GridSampler.metal +++ b/aten/src/ATen/native/mps/kernels/GridSampler.metal @@ -223,9 +223,6 @@ void grid_sampler_single_element( auto input_size = input_sizes[input_dim]; auto coord = static_cast>(coords[coord_dim]); - // Interpret nan as -1 - coord = isnan(coord) ? -1 : coord; - if (!align_corners) { // Map unaligned grid space to aligned grid space auto corner_alignment_factor = static_cast>(input_size) / diff --git a/test/test_mps.py b/test/test_mps.py index 1a8f7af83e3f..e38159172c9d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12546,6 +12546,13 @@ class TestConsistency(TestCaseMPS): self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol) + def test_grid_sampler_3d_nan(self, device): + input = torch.ones(1, 1, 3, 3, 3) + grid_nan = torch.tensor([[[[[torch.nan, 1., 1.], [1., 1., 1.]]]]]) + out_cpu = torch.grid_sampler_3d(input, grid_nan, 0, 0, True) + out_mps = torch.grid_sampler_3d(input.to(device), grid_nan.to(device), 0, 0, True) + self.assertEqual(out_mps, out_cpu) + def test_fmax_mixed_dtypes(self, device): # Regression tesing for https://github.com/pytorch/pytorch/issues/149951 # fmax and fmin are implemented as binary metal shaders and they were implemented