Add failing bitwise equivalence UT for aot_eager on rms_norm (#164280)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164280
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: cfd46d13e6
Commit: 39c340ec9e
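The failing case can be reproduced outside the test harness. A minimal sketch, assuming a CUDA build of PyTorch with torch.compile and F.rms_norm available; only the shapes, backend, and the rms_norm call are taken from the diff below, the standalone script wrapper is illustrative:

# Illustrative sketch, not part of the commit: compare eager rms_norm
# against the aot_eager-compiled version with zero tolerance.
import torch
import torch.nn.functional as F

def fn(x):
    return F.rms_norm(x, normalized_shape=(8,))

x = torch.randn(2, 4, 8, device="cuda")
eager = fn(x)
aot_eager = torch.compile(fn, backend="aot_eager")(x)

# torch.equal is True only if every element matches exactly, which is
# the condition assertEqual(..., atol=0, rtol=0) enforces in the test.
print(torch.equal(eager, aot_eager))  # currently False, hence expectedFailure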
@@ -28,6 +28,7 @@ from common_utils import (
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils._pytree as pytree
 from functorch import grad, jacrev, make_fx, vjp, vmap
 from functorch.compile import (
@@ -7199,6 +7200,27 @@ metadata incorrectly.
         torch.compile(fn, backend="inductor", fullgraph=True)(x)
         torch.compile(fn_, backend="inductor", fullgraph=True)(x)
 
+    def test_layer_norm(self):
+        def fn(x):
+            return F.layer_norm(x, normalized_shape=(8,))
+
+        x = torch.randn(2, 4, 8)
+        eager = fn(x)
+        aot_eager = torch.compile(backend="aot_eager")(fn)(x)
+        self.assertEqual(eager, aot_eager, atol=0, rtol=0)
+
+    @unittest.expectedFailure
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    def test_rms_norm(self):
+        # Only CUDA rms norm fails to be decomposed
+        def fn(x):
+            return F.rms_norm(x, normalized_shape=(8,))
+
+        x = torch.randn(2, 4, 8, device="cuda")
+        eager = fn(x)
+        aot_eager = torch.compile(backend="aot_eager")(fn)(x)
+        self.assertEqual(eager, aot_eager, atol=0, rtol=0)
+
     def test_subclass_parameters(self):
         class _M(torch.nn.Module):
             def __init__(self):
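Note on the assertion: passing atol=0 and rtol=0 to assertEqual disables all floating-point tolerance, so the comparison succeeds only on exact bitwise equality. test_layer_norm establishes that layer_norm already meets this bar under aot_eager on CPU; test_rms_norm is marked expectedFailure because, per its inline comment, only the CUDA rms norm fails to be decomposed, so the compiled result diverges from eager at the bit level. In terms of the public testing API, the same zero-tolerance check can be written with torch.testing.assert_close; a sketch, not from the commit:

import torch

a = torch.tensor([1.0])
b = a + torch.finfo(torch.float32).eps  # differs from a in the last bit
# With atol=0 and rtol=0, even a one-ulp difference raises AssertionError.
torch.testing.assert_close(a, b, atol=0, rtol=0)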