mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix gumbel cdf (#91698)
Fix `Gumbel.cdf` function. **Description** When transformed parameters is outside of the support of underlying Uniform distribution. This makes behavior of `Gumbel.cdf` consistent with other `TransformedDistribution` that pass value of validate_args to the base distribution. **Issue** running `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` would cause `ValueError` exception from `_validate_sample` **Testing** Test was added to the `test_distributions.py` to check if `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` successfully returns `1.0` This is a second attempt to push changes , after https://github.com/pytorch/pytorch/pull/82488 Pull Request resolved: https://github.com/pytorch/pytorch/pull/91698 Approved by: https://github.com/fritzo, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
203890e1e0
commit
b0b5f3c6c6
@ -2589,6 +2589,18 @@ class TestDistributions(DistributionsTestCase):
|
||||
self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
|
||||
self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
|
||||
self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
|
||||
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
|
||||
torch.tensor(1.0, dtype=torch.float32),
|
||||
validate_args=False).cdf(20.0), 1.0, atol=1e-4, rtol=0)
|
||||
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
|
||||
torch.tensor(1.0, dtype=torch.float64),
|
||||
validate_args=False).cdf(50.0), 1.0, atol=1e-4, rtol=0)
|
||||
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
|
||||
torch.tensor(1.0, dtype=torch.float32),
|
||||
validate_args=False).cdf(-5.0), 0.0, atol=1e-4, rtol=0)
|
||||
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
|
||||
torch.tensor(1.0, dtype=torch.float64),
|
||||
validate_args=False).cdf(-10.0), 0.0, atol=1e-8, rtol=0)
|
||||
|
||||
def ref_log_prob(idx, x, log_prob):
|
||||
l = loc.view(-1)[idx].detach()
|
||||
|
@ -31,10 +31,11 @@ class Gumbel(TransformedDistribution):
|
||||
self.loc, self.scale = broadcast_all(loc, scale)
|
||||
finfo = torch.finfo(self.loc.dtype)
|
||||
if isinstance(loc, Number) and isinstance(scale, Number):
|
||||
base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
|
||||
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
|
||||
else:
|
||||
base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
|
||||
torch.full_like(self.loc, 1 - finfo.eps))
|
||||
torch.full_like(self.loc, 1 - finfo.eps),
|
||||
validate_args=validate_args)
|
||||
transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
|
||||
ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
|
||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||
|
Reference in New Issue
Block a user