mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Test] Adding a testcase for constant_pad_nd (#161259)
Fixes #161066 This PR adds a simple testcase for constant_pad_nd on MPS as mentioned in https://github.com/pytorch/pytorch/pull/161149#issuecomment-3211701274 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161259 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
47d267364c
commit
cee72119b2
@ -8904,6 +8904,12 @@ class TestPad(TestCaseMPS):
|
||||
nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
|
||||
self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
|
||||
|
||||
def test_constant_pad_nd_with_empty_pad(self):
|
||||
# Empty constant pad is no-op
|
||||
# See https://github.com/pytorch/pytorch/issues/161066
|
||||
input_mps = torch.randn((2, 3, 4), device="mps")
|
||||
output_mps = torch.constant_pad_nd(input_mps, [])
|
||||
self.assertEqual(output_mps, input_mps)
|
||||
|
||||
class TestLinalgMPS(TestCaseMPS):
|
||||
def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
|
||||
|
Reference in New Issue
Block a user