[MPS] Add native_dropout and native_dropout_backward (#162108)

Fixes #162002
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108
Approved by: https://github.com/malfet
Author: Kurt Mohler
Date: 2025-09-08 15:22:37 -05:00
Committed by: PyTorch MergeBot
Parent: e025c0f459
Commit: 583bbf7761

7 changed files with 96 additions and 1 deletion
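For context, a minimal usage sketch of the op this commit enables (not part of
the diff below; it assumes an MPS-capable machine): torch.native_dropout
returns the scaled output together with the boolean keep-mask, and in training
mode kept elements are scaled by 1 / (1 - p).

    import torch

    # native_dropout returns the output and the boolean keep-mask in one call.
    # In training mode, kept elements are scaled by 1 / (1 - p) so the
    # expected value of the output matches the input.
    x = torch.randn(4, 8, device='mps', requires_grad=True)
    out, mask = torch.native_dropout(x, p=0.5, train=True)

    assert (out[~mask] == 0).all()                        # dropped elements are zeroed
    torch.testing.assert_close(out[mask], x[mask] * 2.0)  # 1 / (1 - 0.5) == 2

    # With train=False the op is the identity and the mask is all-True.
    out_eval, mask_eval = torch.native_dropout(x, p=0.5, train=False)
    assert mask_eval.all()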


@@ -7524,6 +7524,39 @@ class TestMPS(TestCaseMPS):
        uniq = mps_out.unique()
        self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))

    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_dropout(self, dtype):
        shapes = [
            (100_000,),
            (100, 1000),
            (10, 100, 100),
            (10, 10, 10, 10, 10),
        ]
        p_list = [0, 0.34, 0.78, 1]

        for shape, p, train in itertools.product(shapes, p_list, [False, True]):
            input = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True)
            output, mask = torch.native_dropout(input, p, train=train)
            p_actual_mps = 1 - (mask.sum() / mask.numel())

            if train:
                self.assertEqual(p_actual_mps, p, atol=1e-2, rtol=1e-2)
                self.assertTrue((output[mask.logical_not()] == 0).all())
                self.assertEqual(output[mask], input[mask] / (1 - p))
            else:
                self.assertEqual(output, input)
                self.assertTrue(mask.all())

            output_grad = torch.randn_like(output)
            output.backward(output_grad)
            grad_scale = 0 if p == 1 else 1 / (1 - p)

            if train:
                self.assertEqual(input.grad, output_grad * mask * grad_scale)
            else:
                self.assertEqual(input.grad, output_grad)

    def test_mps_generator(self):
        # explicit manual seeding by creating an MPS Generator
        g_mps = torch.Generator(device='mps')
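
The gradient checks in test_dropout above mirror what the new backward kernel
computes. A hedged sketch (assuming the standard ATen signature
native_dropout_backward(grad_output, mask, scale)) of invoking the backward op
directly and checking it against autograd:

    import torch

    x = torch.randn(16, device='mps', requires_grad=True)
    p = 0.34
    out, mask = torch.native_dropout(x, p, train=True)

    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    # The backward op applies the same mask and scale to the incoming
    # gradient: grad_input = grad_output * mask * (1 / (1 - p)).
    scale = 1.0 / (1.0 - p)
    grad_direct = torch.ops.aten.native_dropout_backward(grad_out, mask, scale)
    torch.testing.assert_close(grad_direct, x.grad)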