fix layer_norm decomp precision for cpu (#140557)

xref: https://fb.workplace.com/groups/1075192433118967/posts/1540519826586223/?comment_id=1543752356262970&reply_comment_id=1544425069529032

The issue is that our decomp needs to branch on device (it only upcasts for CPU), but the device shows up as "meta" because the decomp is registered as a meta tensor rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140557
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2024-11-13 07:41:39 -08:00
committed by PyTorch MergeBot
parent 240aa77ad0
commit 9ae19ffbed
2 changed files with 30 additions and 0 deletions

View File

@@ -1202,6 +1202,31 @@ class DecompOneOffTests(TestCase):
self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol))
@onlyCPU
def test_native_layer_norm_cpu_decomp(self, device):
    """The native_layer_norm decomposition under FakeTensorMode must produce
    the same output dtypes as the real CPU kernel (the CPU path upcasts
    bf16), i.e. the fake impl must see the real device rather than "meta".
    """

    def run_layer_norm(inp, weight, bias):
        # Call the aten overload directly so the decomposition path is hit.
        return torch.ops.aten.native_layer_norm.default(
            inp, [1, 2, 3], weight, bias, eps=0.5
        )

    def make_inputs():
        inp = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu")
        weight = torch.randn(
            1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu"
        )
        bias = torch.randn(
            1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu"
        )
        return inp, weight, bias

    # Eager CPU run is the dtype reference.
    eager_outs = run_layer_norm(*make_inputs())

    from torch._subclasses.fake_tensor import FakeTensorMode

    # Fake-tensor run: fresh (fake) inputs, same op invocation.
    with enable_python_dispatcher(), FakeTensorMode():
        fake_outs = run_layer_norm(*make_inputs())

    # native_layer_norm returns a tuple (out, mean, rstd); compare dtypes
    # element-wise between the eager and fake runs.
    for eager_out, fake_out in zip(eager_outs, fake_outs):
        self.assertEqual(eager_out.dtype, fake_out.dtype)
# Generate per-device-type variants of DecompOneOffTests in this module's
# namespace (device-decorated tests such as @onlyCPU ones only run in the
# matching instantiation).
instantiate_device_type_tests(DecompOneOffTests, globals())

View File

@@ -3306,6 +3306,11 @@ def native_layer_norm(
return (out, mean, rstd)
@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default)
def native_layer_norm_fake(fake_mode, func, *args, **kwargs):
    """Fake-tensor impl for ``aten.native_layer_norm.default``.

    Registered as a fake-tensor rule (rather than a meta rule) so the
    decomposition can branch on the real device: it only upcasts for CPU,
    and under a meta registration the device would show up as "meta".

    Args:
        fake_mode: the active FakeTensorMode (unused; dispatch plumbing).
        func: the op being dispatched (unused; dispatch plumbing).
        *args/**kwargs: forwarded to the ``native_layer_norm`` decomposition.
    """
    # Forward **kwargs as well: dropping them would silently ignore any
    # argument passed by keyword (e.g. `eps`).
    return native_layer_norm(*args, **kwargs)
# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(aten.permute)