[decomp] Fix native_batch_norm_backward dtype of dweight and dbias (#89740)

Discovered while debugging an accuracy issue for Inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89740
Approved by: https://github.com/soumith, https://github.com/ngimel
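
For context on the bug: under AMP, the input and grad_out are float16 while weight and bias are float32, and the eager native_batch_norm_backward kernel returns dweight and dbias in the weight's dtype. The decomposition upcasts to float32 internally but was casting every output back to the input dtype. Below is a minimal sketch of the corrected output casts; the function and helper names are illustrative, not the actual code in torch/_decomp/decompositions.py:

    import torch

    def _maybe_cast(t, dtype):
        # Illustrative helper: cast only when the tensor is present.
        return t.to(dtype) if t is not None else None

    def finalize_batch_norm_grads(grad_input, grad_weight, grad_bias, x, weight):
        # The decomposition computes gradients in float32 for accuracy. On the
        # way out, grad_input follows the input dtype (float16 under AMP), but
        # dweight and dbias must follow the weight dtype (float32 under AMP)
        # to match the eager kernel; casting them to x.dtype was the bug.
        weight_dtype = weight.dtype if weight is not None else x.dtype
        return (
            _maybe_cast(grad_input, x.dtype),
            _maybe_cast(grad_weight, weight_dtype),
            _maybe_cast(grad_bias, weight_dtype),
        )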
Author: Animesh Jain
Date: 2022-11-29 03:15:16 +00:00
Committed by: PyTorch MergeBot
Parent: 4d7ec30220
Commit: c1950620c5
2 changed files with 48 additions and 2 deletions

test/test_decomp.py

@@ -21,6 +21,7 @@ from torch.testing._internal.common_device_type import (
     onlyNativeDeviceTypes,
     ops,
     instantiate_device_type_tests,
+    onlyCUDA,
 )
 from torch.testing._internal.common_methods_invocations import op_db
 from torch._dispatch.python import enable_python_dispatcher
@@ -577,6 +578,47 @@ class DecompContiguousTests(TestCase):
 instantiate_device_type_tests(DecompContiguousTests, globals())

+
+class DecompAmpTests(TestCase):
+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+    @skipIfCrossRef
+    @onlyCUDA
+    def test_amp_batch_norm_backward(self):
+        device = "cuda"
+        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
+        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
+        weight = torch.randn((2,), dtype=torch.float32, device=device)
+        rmean = torch.randn((2,), dtype=torch.float32, device=device)
+        rvar = torch.randn((2,), dtype=torch.float32, device=device)
+        # save_mean/save_invstd are empty tensors when train=False
+        mean = torch.randn((0,), dtype=torch.float32, device=device)
+        ref = torch.ops.aten.native_batch_norm_backward(
+            grad_out,
+            x,
+            weight,
+            rmean,
+            rvar,
+            mean,
+            mean,
+            False,
+            1e-05,
+            [True, True, True])
+        res = torch._decomp.decompositions.native_batch_norm_backward(
+            grad_out,
+            x,
+            weight,
+            rmean,
+            rvar,
+            mean,
+            mean,
+            False,
+            1e-05,
+            [True, True, True])
+        for (a, b) in zip(ref, res):
+            self.assertEqual(a.stride(), b.stride())
+            self.assertEqual(a.dtype, b.dtype)
+
+
+instantiate_device_type_tests(DecompAmpTests, globals())
+
 if __name__ == "__main__":
     run_tests()
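
Reproducing locally requires a CUDA build of PyTorch; something like `python test/test_decomp.py -k test_amp_batch_norm_backward` should exercise the new check, with the device-type machinery instantiating it as DecompAmpTestsCUDA.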