Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
fix layer_norm decomp precision for cpu (#140557)
xref: https://fb.workplace.com/groups/1075192433118967/posts/1540519826586223/?comment_id=1543752356262970&reply_comment_id=1544425069529032

The issue is that our decomp needs to branch on device (it only upcasts for cpu), but the device shows up as "meta" because the decomp is registered as a meta tensor rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140557
Approved by: https://github.com/ezyang
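For context, a minimal standalone sketch of the failure mode. This is hypothetical illustration code, not the actual PyTorch decomposition; the function name `layer_norm_decomp_sketch` and the exact upcast rule are assumptions based on the commit message. A decomposition that branches on device behaves correctly in eager mode, but once registered as a meta tensor rule it only ever sees tensors whose device.type is "meta", so the CPU-only branch is silently skipped:

    import torch

    def layer_norm_decomp_sketch(x, normalized_shape, weight, bias, eps):
        # Hypothetical device-branching decomp: upcast reduced-precision
        # inputs to float32 on CPU only, per the commit message.
        upcast = x.device.type == "cpu" and x.dtype in (torch.bfloat16, torch.float16)
        a = x.float() if upcast else x
        # Normalize over the trailing `normalized_shape` dimensions.
        dims = tuple(range(a.dim() - len(normalized_shape), a.dim()))
        mean = a.mean(dim=dims, keepdim=True)
        rstd = torch.rsqrt(a.var(dim=dims, unbiased=False, keepdim=True) + eps)
        out = (a - mean) * rstd
        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias
        # When this runs on meta tensors, x.device.type == "meta", so
        # `upcast` stays False and the traced dtypes diverge from what
        # the real CPU kernel produces.
        return out.to(x.dtype), mean, rstd

Registering the decomp as a fake-tensor op impl instead (the diff below) means it receives FakeTensors, which still report the original device.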
Committed by: PyTorch MergeBot
Parent: 240aa77ad0
Commit: 9ae19ffbed
@@ -1202,6 +1202,31 @@ class DecompOneOffTests(TestCase):
         self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol))

+    @onlyCPU
+    def test_native_layer_norm_cpu_decomp(self, device):
+        def f(x, w, b):
+            return torch.ops.aten.native_layer_norm.default(x, [1, 2, 3], w, b, eps=0.5)
+
+        x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu")
+        w = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu")
+        b = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu")
+        out_ref = f(x, w, b)
+
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        with enable_python_dispatcher(), FakeTensorMode():
+            x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu")
+            w = torch.randn(
+                1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu"
+            )
+            b = torch.randn(
+                1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu"
+            )
+            out = f(x, w, b)
+
+        for o_ref, o in zip(out_ref, out):
+            self.assertEqual(o_ref.dtype, o.dtype)
+

 instantiate_device_type_tests(DecompOneOffTests, globals())
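One detail the hunk above does not show: enable_python_dispatcher is imported at the top of the test file, outside this diff context. Assuming the usual location in the codebase, that import is:

    from torch._dispatch.python import enable_python_dispatcher

It routes dispatch through the Python-level decompositions, so the decomposition path (rather than only C++ kernels) runs under FakeTensorMode, which is exactly the path whose output dtypes the test compares against eager.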
@@ -3306,6 +3306,11 @@ def native_layer_norm(
     return (out, mean, rstd)


+@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default)
+def native_layer_norm_fake(fake_mode, func, *args, **kwargs):
+    return native_layer_norm(*args)
+
+
 # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
 # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
 @register_decomposition(aten.permute)
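Why the fix works: a FakeTensor is backed by meta storage but still reports the device of the real tensor it stands in for, while a plain meta tensor reports "meta". Registering the decomposition via register_op_impl therefore lets its device branch fire correctly. A small illustration of the distinction:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # A plain meta tensor forgets the original device:
    m = torch.randn(2, 3, device="meta")
    print(m.device.type)  # "meta"

    # A FakeTensor remembers it, so device-dependent logic (like the
    # CPU-only upcast in the layer_norm decomp) can branch correctly:
    with FakeTensorMode():
        f = torch.randn(2, 3, device="cpu")
        print(f.device.type)  # "cpu"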