Add failing bitwise equivalence UT for aot_eager on rms_norm (#164280)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164280
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2025-10-01 07:48:21 -07:00
committed by PyTorch MergeBot
parent cfd46d13e6
commit 39c340ec9e

View File

@ -28,6 +28,7 @@ from common_utils import (
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.nn.functional as F
import torch.utils._pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
@ -7199,6 +7200,27 @@ metadata incorrectly.
torch.compile(fn, backend="inductor", fullgraph=True)(x)
torch.compile(fn_, backend="inductor", fullgraph=True)(x)
def test_layer_norm(self):
def fn(x):
return F.layer_norm(x, normalized_shape=(8,))
x = torch.randn(2, 4, 8)
eager = fn(x)
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed
def fn(x):
return F.rms_norm(x, normalized_shape=(8,))
x = torch.randn(2, 4, 8, device="cuda")
eager = fn(x)
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
def test_subclass_parameters(self):
class _M(torch.nn.Module):
def __init__(self):