[MPS] Fix batch norm incorrect gradient (#156867)

Fixes #156555
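
As a quick illustration of the bug, below is a minimal repro sketch adapted from the regression test added in this commit; the standalone-script form is mine, and it assumes a machine with an MPS-capable build of PyTorch. In training mode, the weight and bias gradients of a BatchNorm1d computed on MPS could disagree with the CPU reference.

import torch
import torch.nn as nn

# Shapes taken from the regression test below.
N, C, L = 4, 3, 5
x = torch.randn(N, C, L)
y = torch.randn(N, C, L)

# Identical BatchNorm1d modules on CPU and MPS, both in training mode.
bn_cpu = nn.BatchNorm1d(C, affine=True).train()
bn_mps = nn.BatchNorm1d(C, affine=True).to("mps").train()
bn_mps.load_state_dict(bn_cpu.state_dict())  # copy parameters across devices

# Same MSE-style loss on both devices, backpropagated independently.
((bn_cpu(x) - y) ** 2).mean().backward()
((bn_mps(x.to("mps")) - y.to("mps")) ** 2).mean().backward()

# On an affected build these differ; after this fix they match within tolerance.
print(torch.allclose(bn_cpu.weight.grad, bn_mps.weight.grad.cpu(), atol=1e-5))
print(torch.allclose(bn_cpu.bias.grad, bn_mps.bias.grad.cpu(), atol=1e-5))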

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156867
Approved by: https://github.com/malfet
Author: Isalia20
Date: 2025-06-25 23:05:49 +00:00
Committed by: PyTorch MergeBot
Parent: acaf6ba3c6
Commit: 653c52fe52
2 changed files with 24 additions and 1 deletion


@@ -1748,6 +1748,26 @@ class TestMPS(TestCaseMPS):
         self.assertEqual(res_cpu, res_mps)
 
+    def test_batch_norm_backward_weight_bias_gradients(self):
+        # See issue: https://github.com/pytorch/pytorch/issues/156555
+        N, C, L = 4, 3, 5
+        x = torch.randn(N, C, L)
+        y = torch.randn(N, C, L)
+        bn_cpu = nn.BatchNorm1d(C, affine=True).cpu().train()
+        bn_mps = nn.BatchNorm1d(C, affine=True).to('mps').train()
+        bn_mps.load_state_dict(bn_cpu.state_dict())
+        out_cpu = bn_cpu(x)
+        out_mps = bn_mps(x.to('mps'))
+        loss_cpu = ((out_cpu - y) ** 2).mean()
+        loss_mps = ((out_mps - y.to('mps')) ** 2).mean()
+        loss_cpu.backward()
+        loss_mps.backward()
+        self.assertEqual(bn_cpu.weight.grad, bn_mps.weight.grad, atol=1e-5, rtol=1e-5)
+        self.assertEqual(bn_cpu.bias.grad, bn_mps.bias.grad, atol=1e-5, rtol=1e-5)
+
     def test_layer_norm_backward(self):
         inputs = torch.rand(4, 4, device="mps", requires_grad=True)
         x = torch.nn.LayerNorm(4).to("mps")
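
The new regression test should be runnable on its own from a PyTorch checkout on an Apple-silicon machine with MPS available, e.g. via the standard unittest-style entry point:

    python test/test_mps.py TestMPS.test_batch_norm_backward_weight_bias_gradients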