[decomp] Fix native_batch_norm_backward dtype of dweight and dbias (#89740)

Discovered while debugging an accuracy issue for Inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89740
Approved by: https://github.com/soumith, https://github.com/ngimel
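The gist of the fix: the decomposition does its math in a higher-precision computation dtype, and on the way out it must cast grad_input back to the input's dtype while casting grad_weight and grad_bias back to the weight's dtype, since under AMP the activations are fp16 but the affine parameters stay fp32. Below is a minimal sketch of that dtype rule for the eval-mode (train=False) path the new test exercises; this is not the actual decomposition code, and the helper name bn_backward_eval is made up for illustration:

```python
import torch

# Hedged sketch of the dtype rule, not the real decomposition: eval-mode
# batch-norm backward for an NCHW input. Math runs in fp32; grad_input is
# cast back to the input dtype, while grad_weight/grad_bias are cast to the
# weight dtype, matching what the eager kernel returns under AMP.
def bn_backward_eval(grad_out, x, weight, running_mean, running_var, eps=1e-5):
    compute = torch.float32
    g = grad_out.to(compute)
    xc = x.to(compute)
    rstd = (running_var.to(compute) + eps).rsqrt()  # 1/sqrt(var + eps), shape (C,)
    c = (1, -1, 1, 1)  # broadcast per-channel stats over N, H, W
    grad_input = g * (weight.to(compute) * rstd).view(c)
    grad_weight = (g * (xc - running_mean.to(compute).view(c)) * rstd.view(c)).sum((0, 2, 3))
    grad_bias = g.sum((0, 2, 3))
    return (
        grad_input.to(x.dtype),        # fp16 under AMP: follows the input
        grad_weight.to(weight.dtype),  # fp32: the fix, follow weight, not input
        grad_bias.to(weight.dtype),    # fp32: the fix, follow weight, not input
    )
```

The design point is that an optimizer step generally expects param.grad.dtype to match param.dtype, so fp32 BatchNorm weights need fp32 gradients even when the activations flow through in fp16.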
Committed by: PyTorch MergeBot
Parent: 4d7ec30220
Commit: c1950620c5
@@ -21,6 +21,7 @@ from torch.testing._internal.common_device_type import (
     onlyNativeDeviceTypes,
     ops,
     instantiate_device_type_tests,
+    onlyCUDA,
 )
 from torch.testing._internal.common_methods_invocations import op_db
 from torch._dispatch.python import enable_python_dispatcher
@@ -577,6 +578,47 @@ class DecompContiguousTests(TestCase):
 
 instantiate_device_type_tests(DecompContiguousTests, globals())
 
+class DecompAmpTests(TestCase):
+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+    @skipIfCrossRef
+    @onlyCUDA
+    def test_amp_batch_norm_backward(self):
+        device = "cuda"
+        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
+        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
+        weight = torch.randn((2,), dtype=torch.float32, device=device)
+        rmean = torch.randn((2,), dtype=torch.float32, device=device)
+        rvar = torch.randn((2,), dtype=torch.float32, device=device)
+        mean = torch.randn((0,), dtype=torch.float32, device=device)
+
+        ref = torch.ops.aten.native_batch_norm_backward(
+            grad_out,
+            x,
+            weight,
+            rmean,
+            rvar,
+            mean,
+            mean,
+            False,
+            1e-05,
+            [True, True, True])
+        res = torch._decomp.decompositions.native_batch_norm_backward(
+            grad_out,
+            x,
+            weight,
+            rmean,
+            rvar,
+            mean,
+            mean,
+            False,
+            1e-05,
+            [True, True, True])
+        for (a, b) in zip(ref, res):
+            self.assertEqual(a.stride(), b.stride())
+            self.assertEqual(a.dtype, b.dtype)
+
+
+instantiate_device_type_tests(DecompAmpTests, globals())
 
 if __name__ == "__main__":
     run_tests()
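For quick verification outside the test harness, the same comparison can be run standalone (assumes a CUDA build; shapes and dtypes mirror the test above, and with train=False the saved mean/invstd arguments are unused, which is why empty (0,) tensors are passed):

```python
import torch

device = "cuda"
grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
weight = torch.randn((2,), dtype=torch.float32, device=device)
rmean = torch.randn((2,), dtype=torch.float32, device=device)
rvar = torch.randn((2,), dtype=torch.float32, device=device)
empty = torch.randn((0,), dtype=torch.float32, device=device)

# Eager CUDA kernel vs. the Python decomposition, eval-mode (train=False).
ref = torch.ops.aten.native_batch_norm_backward(
    grad_out, x, weight, rmean, rvar, empty, empty, False, 1e-05,
    [True, True, True])
res = torch._decomp.decompositions.native_batch_norm_backward(
    grad_out, x, weight, rmean, rvar, empty, empty, False, 1e-05,
    [True, True, True])

for a, b in zip(ref, res):
    # After the fix, dweight and dbias come back fp32 from both paths.
    assert a.dtype == b.dtype and a.stride() == b.stride()
```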