mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
better error handling for rrelu when lower or upper range is infinite (#160965)
Fixes #153281. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160965 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
7d59e37434
commit
ace89350fc
@ -12803,6 +12803,43 @@ if __name__ == '__main__':
|
||||
expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype)
|
||||
self.assertEqual(a_bf16.grad, expected_bf16)
|
||||
|
||||
@onlyCPU
def test_rrelu_bounds_validation(self, device):
    """Test RReLU bounds validation for finite and infinite values."""
    x = torch.randn(5, 5, device=device)

    # Finite bounds are accepted and the output keeps the input's shape.
    out = F.rrelu(x, lower=0.1, upper=0.3)
    self.assertEqual(out.shape, x.shape)

    # Any non-finite bound (inf, -inf, nan) must be rejected with a
    # RuntimeError naming which bound is bad and the offending value.
    non_finite_cases = [
        (float('inf'), 0.3, "rrelu: lower bound must be finite, got inf"),
        (0.1, float('inf'), "rrelu: upper bound must be finite, got inf"),
        (float('nan'), 0.3, "rrelu: lower bound must be finite, got nan"),
        (0.1, float('nan'), "rrelu: upper bound must be finite, got nan"),
        (float('-inf'), 0.3, "rrelu: lower bound must be finite, got -inf"),
        (0.1, float('-inf'), "rrelu: upper bound must be finite, got -inf"),
    ]
    for lo, hi, msg in non_finite_cases:
        with self.assertRaisesRegex(RuntimeError, msg):
            F.rrelu(x, lower=lo, upper=hi)

    # An inverted range (lower > upper) is rejected as well.
    with self.assertRaisesRegex(RuntimeError, "Lower bound should be less than or equal to the upper bound"):
        F.rrelu(x, lower=0.5, upper=0.3)
||||
@onlyCPU
|
||||
def test_softshrink(self, device):
|
||||
x = torch.tensor([[1.21, 0.56, 0.5001, 0.4999, 1.2357, -0.4999, -0.5001, -1.154,
|
||||
|
Reference in New Issue
Block a user