mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #144196 Extends #144106 and #144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat` or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769 # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402 - `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places. [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.distributions import constraints
|
|
from torch.distributions.distribution import Distribution
|
|
from torch.distributions.gamma import Gamma
|
|
from torch.distributions.utils import (
|
|
broadcast_all,
|
|
lazy_property,
|
|
logits_to_probs,
|
|
probs_to_logits,
|
|
)
|
|
|
|
|
|
# Names exported by ``from ... import *``: only the distribution class.
__all__ = ["NegativeBinomial"]
|
|
|
|
|
|
class NegativeBinomial(Distribution):
    r"""
    Creates a Negative Binomial distribution, i.e. the distribution of the
    number of successes in a sequence of independent and identically
    distributed Bernoulli trials observed before :attr:`total_count`
    failures occur. Each trial succeeds with probability :attr:`probs`.

    Exactly one of :attr:`probs` or :attr:`logits` must be specified.

    Args:
        total_count (float or Tensor): non-negative number of failures after
            which the trials stop; real (non-integer) values are accepted and
            still yield a valid distribution
        probs (Tensor): success probability of each trial, in the half-open
            interval [0, 1)
        logits (Tensor): log-odds of success of each trial
    """

    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "probs": constraints.half_open_interval(0.0, 1.0),
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: Union[Tensor, float],
        probs: Optional[Tensor] = None,
        logits: Optional[Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        # Exactly one parameterization must be supplied: neither or both is an error.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        given = probs if probs is not None else logits
        assert given is not None  # helps mypy narrow Optional[Tensor]
        broadcast_count, param = broadcast_all(total_count, given)
        # Keep the count in the same dtype/device as the parameter tensor.
        self.total_count = broadcast_count.type_as(param)
        # The assignment lands in the instance ``__dict__`` (which ``expand``
        # inspects), so the other parameterization stays lazy until accessed.
        if probs is not None:
            self.probs = param
        else:
            self.logits = param
        self._param = param
        super().__init__(param.size(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``."""
        new = self._get_checked_instance(NegativeBinomial, _instance)
        shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(shape)
        # Copy only the parameterizations already materialized on ``self``;
        # when both exist, ``_param`` ends up pointing at the logits.
        materialized = self.__dict__
        for attr in ("probs", "logits"):
            if attr in materialized:
                expanded = materialized[attr].expand(shape)
                setattr(new, attr, expanded)
                new._param = expanded
        super(NegativeBinomial, new).__init__(shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Allocate a tensor sharing dtype/device with the stored parameter.
        template = self._param
        return template.new(*args, **kwargs)

    @property
    def mean(self) -> Tensor:
        # r * p / (1 - p), with p / (1 - p) == exp(logits).
        odds = self.logits.exp()
        return self.total_count * odds

    @property
    def mode(self) -> Tensor:
        # floor((r - 1) * p / (1 - p)), clamped so that r <= 1 yields 0.
        scaled_odds = (self.total_count - 1) * torch.exp(self.logits)
        return torch.clamp(torch.floor(scaled_odds), min=0.0)

    @property
    def variance(self) -> Tensor:
        # Variance is mean / (1 - p); sigmoid(-logits) gives 1 - p.
        failure_prob = torch.sigmoid(-self.logits)
        return self.mean / failure_prob

    @lazy_property
    def logits(self) -> Tensor:
        # Derived from ``probs`` on first access when only probs was given.
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self) -> Tensor:
        # Derived from ``logits`` on first access when only logits was given.
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self) -> torch.Size:
        return self._param.size()

    @lazy_property
    def _gamma(self) -> Gamma:
        # Validation is intentionally disabled: self.total_count may be zero.
        rate = torch.exp(-self.logits)
        return Gamma(
            concentration=self.total_count,
            rate=rate,
            validate_args=False,
        )

    def sample(self, sample_shape=torch.Size()):
        # Draw a Gamma rate, then a Poisson count with that rate (Gamma-Poisson mixture).
        with torch.no_grad():
            poisson_rate = self._gamma.sample(sample_shape=sample_shape)
            return torch.poisson(poisson_rate)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        count = self.total_count
        log_odds = self.logits
        # log[(1 - p)^r * p^k] via numerically stable log-sigmoids.
        log_unnormalized_prob = count * F.logsigmoid(-log_odds) + value * F.logsigmoid(
            log_odds
        )
        # log of 1 / C(k + r - 1, k), written with lgamma so real-valued r works.
        log_normalization = (
            -torch.lgamma(count + value)
            + torch.lgamma(1.0 + value)
            + torch.lgamma(count)
        )
        # r == 0 and k == 0 has probability 1, but lgamma(0) is infinite.
        # masked_fill (not an in-place write) keeps this JIT-compatible.
        log_normalization = log_normalization.masked_fill(count + value == 0.0, 0.0)

        return log_unnormalized_prob - log_normalization
|