mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[MPS] Fix nan behavior in grid_sampler_3d
(#163881)
Fixes #163851 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163881 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
62b0ebd8f9
commit
a0136f149c
@ -223,9 +223,6 @@ void grid_sampler_single_element(
|
||||
auto input_size = input_sizes[input_dim];
|
||||
auto coord = static_cast<opmath_t<T>>(coords[coord_dim]);
|
||||
|
||||
// Interpret nan as -1
|
||||
coord = isnan(coord) ? -1 : coord;
|
||||
|
||||
if (!align_corners) {
|
||||
// Map unaligned grid space to aligned grid space
|
||||
auto corner_alignment_factor = static_cast<opmath_t<T>>(input_size) /
|
||||
|
@ -12546,6 +12546,13 @@ class TestConsistency(TestCaseMPS):
|
||||
|
||||
self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol)
|
||||
|
||||
def test_grid_sampler_3d_nan(self, device):
|
||||
input = torch.ones(1, 1, 3, 3, 3)
|
||||
grid_nan = torch.tensor([[[[[torch.nan, 1., 1.], [1., 1., 1.]]]]])
|
||||
out_cpu = torch.grid_sampler_3d(input, grid_nan, 0, 0, True)
|
||||
out_mps = torch.grid_sampler_3d(input.to(device), grid_nan.to(device), 0, 0, True)
|
||||
self.assertEqual(out_mps, out_cpu)
|
||||
|
||||
def test_fmax_mixed_dtypes(self, device):
|
||||
# Regression tesing for https://github.com/pytorch/pytorch/issues/149951
|
||||
# fmax and fmin are implemented as binary metal shaders and they were implemented
|
||||
|
Reference in New Issue
Block a user