mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aotd] capture rrelu_with_noise noise mutation in compile (#141867)
Rebase-copy of long standing already approved PR https://github.com/pytorch/pytorch/pull/138503 that was blocked on landing by xla build issues. Got a new PR with the same content (ghstack checkout was failing due to changed submodules) Corresponding xla PR: https://github.com/pytorch/xla/pull/8363 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
61dc5e9c0a
commit
f85e238186
@ -656,50 +656,6 @@ class TestDecomp(TestCase):
|
||||
for dim in (-1, 0, 1):
|
||||
self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
|
||||
|
||||
def test_rrelu_with_noise(self, device):
|
||||
# rrelu_with_noise behavior depends on a) whether elements in the input
|
||||
# are <= 0, and b) whether we're in training mode. Cover all cases:
|
||||
dtype = torch.float64
|
||||
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
|
||||
lower = 1.0
|
||||
upper = 4.0
|
||||
training = False
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
res = torch._decomp.decompositions.rrelu_with_noise(
|
||||
x,
|
||||
noise_res,
|
||||
lower,
|
||||
upper,
|
||||
training,
|
||||
)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(noise_ref, noise_res)
|
||||
|
||||
# Now with training=True:
|
||||
training = True
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
|
||||
|
||||
torch.manual_seed(123)
|
||||
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
|
||||
res = torch._decomp.decompositions.rrelu_with_noise(
|
||||
x,
|
||||
noise_res,
|
||||
lower,
|
||||
upper,
|
||||
training,
|
||||
)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(noise_ref, noise_res)
|
||||
|
||||
@suppress_warnings
|
||||
@tf32_off()
|
||||
# only tests RNNs since we have py dispsatcher decomps for them
|
||||
|
Reference in New Issue
Block a user