mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix batch norm incorrect gradient (#156867)
Fixes #156555.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156867
Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
acaf6ba3c6
commit
653c52fe52
@@ -1748,6 +1748,26 @@ class TestMPS(TestCaseMPS):
|
||||
|
||||
self.assertEqual(res_cpu, res_mps)
|
||||
|
||||
def test_batch_norm_backward_weight_bias_gradients(self):
    """Compare affine BatchNorm1d weight/bias gradients on MPS with CPU.

    Two BatchNorm1d modules in training mode (one on CPU, one on MPS)
    share the same state dict, receive the same random input, and are
    driven by the same mean-squared-error loss. After backward, their
    ``weight.grad`` and ``bias.grad`` must agree within 1e-5.

    See issue: https://github.com/pytorch/pytorch/issues/156555
    """
    batch, channels, length = 4, 3, 5
    inp = torch.randn(batch, channels, length)
    target = torch.randn(batch, channels, length)

    # Build the CPU reference first, then copy its state into the MPS
    # module so both start from identical parameters and buffers.
    ref_bn = nn.BatchNorm1d(channels, affine=True).cpu().train()
    mps_bn = nn.BatchNorm1d(channels, affine=True).to('mps').train()
    mps_bn.load_state_dict(ref_bn.state_dict())

    ref_out = ref_bn(inp)
    mps_out = mps_bn(inp.to('mps'))

    # Same squared-error objective on each device.
    ref_loss = ((ref_out - target) ** 2).mean()
    mps_loss = ((mps_out - target.to('mps')) ** 2).mean()
    ref_loss.backward()
    mps_loss.backward()

    grad_pairs = (
        (ref_bn.weight.grad, mps_bn.weight.grad),
        (ref_bn.bias.grad, mps_bn.bias.grad),
    )
    for expected, actual in grad_pairs:
        self.assertEqual(expected, actual, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_layer_norm_backward(self):
|
||||
inputs = torch.rand(4, 4, device="mps", requires_grad=True)
|
||||
x = torch.nn.LayerNorm(4).to("mps")
|
||||
|
Reference in New Issue
Block a user