mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add native_dropout
and native_dropout_backward
(#162108)
Fixes #162002 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
e025c0f459
commit
583bbf7761
@ -7524,6 +7524,39 @@ class TestMPS(TestCaseMPS):
|
||||
uniq = mps_out.unique()
|
||||
self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
|
||||
|
||||
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_dropout(self, dtype):
|
||||
shapes = [
|
||||
(100_000,),
|
||||
(100, 1000),
|
||||
(10, 100, 100),
|
||||
(10, 10, 10, 10, 10),
|
||||
]
|
||||
p_list = [0, 0.34, 0.78, 1]
|
||||
|
||||
for shape, p, train in itertools.product(shapes, p_list, [False, True]):
|
||||
input = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True)
|
||||
output, mask = torch.native_dropout(input, p, train=train)
|
||||
|
||||
p_actual_mps = 1 - (mask.sum() / mask.numel())
|
||||
if train:
|
||||
self.assertEqual(p_actual_mps, p, atol=1e-2, rtol=1e-2)
|
||||
self.assertTrue((output[mask.logical_not()] == 0).all())
|
||||
self.assertEqual(output[mask], input[mask] / (1 - p))
|
||||
else:
|
||||
self.assertEqual(output, input)
|
||||
self.assertTrue(mask.all())
|
||||
|
||||
output_grad = torch.randn_like(output)
|
||||
output.backward(output_grad)
|
||||
|
||||
grad_scale = 0 if p == 1 else 1 / (1 - p)
|
||||
if train:
|
||||
self.assertEqual(input.grad, output_grad * mask * grad_scale)
|
||||
else:
|
||||
self.assertEqual(input.grad, output_grad)
|
||||
|
||||
|
||||
def test_mps_generator(self):
|
||||
# explicit manual seeding by creating an MPS Generator
|
||||
g_mps = torch.Generator(device='mps')
|
||||
|
Reference in New Issue
Block a user