[aotd] capture rrelu_with_noise noise mutation in compile (#141867)

Rebased copy of the long-standing, already-approved PR https://github.com/pytorch/pytorch/pull/138503, whose landing was blocked by XLA build issues.

Opened a new PR with the same content because the ghstack checkout was failing due to changed submodules.

Corresponding xla PR:
https://github.com/pytorch/xla/pull/8363

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867
Approved by: https://github.com/bdhirsh
This commit is contained in:
IvanKobzarev
2024-12-03 02:54:06 -08:00
committed by PyTorch MergeBot
parent 61dc5e9c0a
commit f85e238186
14 changed files with 149 additions and 102 deletions

View File

@ -656,50 +656,6 @@ class TestDecomp(TestCase):
for dim in (-1, 0, 1):
self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
def test_rrelu_with_noise(self, device):
# rrelu_with_noise behavior depends on a) whether elements in the input
# are <= 0, and b) whether we're in training mode. Cover all cases:
dtype = torch.float64
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
lower = 1.0
upper = 4.0
training = False
torch.manual_seed(123)
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
# Now with training=True:
training = True
torch.manual_seed(123)
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
@suppress_warnings
@tf32_off()
# only tests RNNs since we have Python dispatcher decompositions for them