Fix gumbel cdf (#91698)

Fix `Gumbel.cdf` function.

**Description**
When transformed parameters is outside of the support of underlying Uniform distribution. This makes behavior of `Gumbel.cdf` consistent with other `TransformedDistribution` that pass value of validate_args to the base distribution.

**Issue**
running `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` would cause `ValueError` exception from `_validate_sample`

**Testing**
Test was added to the `test_distributions.py` to check if `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` successfully returns `1.0`

This is a second attempt to push changes , after https://github.com/pytorch/pytorch/pull/82488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91698
Approved by: https://github.com/fritzo, https://github.com/zou3519
This commit is contained in:
Vladimir S. FONOV
2023-03-07 23:04:47 +00:00
committed by PyTorch MergeBot
parent 203890e1e0
commit b0b5f3c6c6
2 changed files with 15 additions and 2 deletions

View File

@ -2589,6 +2589,18 @@ class TestDistributions(DistributionsTestCase):
self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
torch.tensor(1.0, dtype=torch.float32),
validate_args=False).cdf(20.0), 1.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
torch.tensor(1.0, dtype=torch.float64),
validate_args=False).cdf(50.0), 1.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
torch.tensor(1.0, dtype=torch.float32),
validate_args=False).cdf(-5.0), 0.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
torch.tensor(1.0, dtype=torch.float64),
validate_args=False).cdf(-10.0), 0.0, atol=1e-8, rtol=0)
def ref_log_prob(idx, x, log_prob):
l = loc.view(-1)[idx].detach()

View File

@ -31,10 +31,11 @@ class Gumbel(TransformedDistribution):
self.loc, self.scale = broadcast_all(loc, scale)
finfo = torch.finfo(self.loc.dtype)
if isinstance(loc, Number) and isinstance(scale, Number):
base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
else:
base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
torch.full_like(self.loc, 1 - finfo.eps))
torch.full_like(self.loc, 1 - finfo.eps),
validate_args=validate_args)
transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
super().__init__(base_dist, transforms, validate_args=validate_args)