mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add nearest_3d forward and backward (#156090)
Introduce generalizable `UpsampleParams` structure in `UpSample.h`, which could be shared between CPU and MPS Delete `upsample_nearest3d` MPS fallback and replace it with proper shader Pull Request resolved: https://github.com/pytorch/pytorch/pull/156090 Approved by: https://github.com/kulinseth, https://github.com/dcci ghstack dependencies: #156256
This commit is contained in:
committed by
PyTorch MergeBot
parent
a82c171bb2
commit
c28e74e457
@ -9799,7 +9799,6 @@ class TestNNDeviceType(NNTestCase):
|
||||
expected_out = expected_out.to(device=device)
|
||||
self.assertEqual(out_t, expected_out)
|
||||
|
||||
@expectedFailureMPS # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764
|
||||
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
|
||||
@parametrize_test("isize, osize", [(20, 11), (10, 15)])
|
||||
def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize):
|
||||
|
Reference in New Issue
Block a user