[MPS] Add nearest_3d forward and backward (#156090)

Introduce generalizable `UpsampleParams` structure in `UpSample.h`, which could be shared between CPU and MPS
Delete `upsample_nearest3d` MPS fallback and replace it with proper shader
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156090
Approved by: https://github.com/kulinseth, https://github.com/dcci
ghstack dependencies: #156256
This commit is contained in:
Nikita Shulga
2025-06-17 17:51:00 -07:00
committed by PyTorch MergeBot
parent a82c171bb2
commit c28e74e457
7 changed files with 377 additions and 6 deletions

View File

@ -9799,7 +9799,6 @@ class TestNNDeviceType(NNTestCase):
expected_out = expected_out.to(device=device)
self.assertEqual(out_t, expected_out)
@expectedFailureMPS # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
@parametrize_test("isize, osize", [(20, 11), (10, 15)])
def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize):