mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is a new version of #15648 based on the latest master branch. Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR. In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.) Fixes https://github.com/pytorch/pytorch/issues/71105 @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797 Approved by: https://github.com/ezyang
114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
from numbers import Number
|
|
|
|
import torch
|
|
from torch.distributions import constraints
|
|
from torch.distributions.distribution import Distribution
|
|
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
|
|
from torch.nn.functional import binary_cross_entropy_with_logits
|
|
|
|
# Public names exported by ``from ... import *`` for this module.
__all__ = [
    "Geometric",
]
|
|
|
|
class Geometric(Distribution):
    r"""
    Creates a Geometric distribution parameterized by :attr:`probs`,
    where :attr:`probs` is the probability of success of Bernoulli trials.
    It represents the probability that in :math:`k + 1` Bernoulli trials, the
    first :math:`k` trials failed, before seeing a success.

    Samples are non-negative integers [0, :math:`\inf`).

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Geometric(torch.tensor([0.3]))
        >>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
        tensor([ 2.])

    Exactly one of :attr:`probs` or :attr:`logits` must be given.

    Args:
        probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
        logits (Number, Tensor): the log-odds of sampling `1`.
        validate_args (bool, optional): forwarded to
            :class:`~torch.distributions.distribution.Distribution`. When
            validation is enabled and :attr:`probs` is given, every entry of
            :attr:`probs` must additionally be strictly positive.

    Raises:
        ValueError: if both or neither of :attr:`probs` / :attr:`logits` are
            given, or if validation is enabled and some probability is 0.
    """
    arg_constraints = {'probs': constraints.unit_interval,
                       'logits': constraints.real}
    support = constraints.nonnegative_integer

    def __init__(self, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        # Store whichever parameterization was supplied; the other one is
        # derived on first access through the lazy properties below.
        if probs is not None:
            self.probs, = broadcast_all(probs)
        else:
            self.logits, = broadcast_all(logits)
        probs_or_logits = probs if probs is not None else logits
        if isinstance(probs_or_logits, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = probs_or_logits.size()
        super().__init__(batch_shape, validate_args=validate_args)
        if self._validate_args and probs is not None:
            # Extra check beyond ``unit_interval``, which admits 0: with
            # p == 0 no success ever occurs (the mean 1/p - 1 diverges), so
            # require strictly positive probabilities.
            value = self.probs
            valid = value > 0
            if not valid.all():
                invalid_value = value.data[~valid]
                raise ValueError(
                    "Expected parameter probs "
                    f"({type(value).__name__} of shape {tuple(value.shape)}) "
                    f"of distribution {repr(self)} "
                    f"to be positive but found invalid values:\n{invalid_value}"
                )

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose parameters are expanded to ``batch_shape``.

        Only the parameterization(s) already materialized on ``self`` are
        expanded; validation is not re-run on the expanded parameters.
        """
        new = self._get_checked_instance(Geometric, _instance)
        batch_shape = torch.Size(batch_shape)
        if 'probs' in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
        if 'logits' in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
        # ``new`` is not ``self``, so the explicit two-argument super() is required.
        super(Geometric, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        # (1 - p) / p, written as 1/p - 1.
        return 1. / self.probs - 1.

    @property
    def mode(self):
        # P(k) = (1 - p)^k * p is non-increasing in k, so the mode is always 0.
        return torch.zeros_like(self.probs)

    @property
    def variance(self):
        # (1 - p) / p**2.
        return (1. / self.probs - 1.) / self.probs

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    def sample(self, sample_shape=torch.Size()):
        """Draw samples by inverse-CDF: ``floor(log(u) / log(1 - p))``.

        ``u`` is kept in ``[tiny, 1)`` so that ``log(u)`` is finite.
        """
        shape = self._extended_shape(sample_shape)
        tiny = torch.finfo(self.probs.dtype).tiny
        with torch.no_grad():
            if torch._C._get_tracing_state():
                # [JIT WORKAROUND] lack of support for .uniform_()
                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
                u = u.clamp(min=tiny)
            else:
                u = self.probs.new(shape).uniform_(tiny, 1)
            return (u.log() / (-self.probs).log1p()).floor()

    def log_prob(self, value):
        """Return ``k * log(1 - p) + log(p)`` for each count ``k`` in ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        value, probs = broadcast_all(value, self.probs)
        # For p == 1 and k == 0 the first term would be 0 * log1p(-1) = 0 * -inf
        # = nan. Zero p at those positions in a private copy so the term is 0;
        # the log(p) term below still uses the original probabilities.
        probs = probs.clone(memory_format=torch.contiguous_format)
        probs[(probs == 1) & (value == 0)] = 0
        return value * (-probs).log1p() + self.probs.log()

    def entropy(self):
        # Entropy of Geometric(p) equals the Bernoulli(p) entropy divided by p;
        # BCE with the distribution's own probs as target is that Bernoulli entropy.
        return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none') / self.probs