mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fixes #144196 Extends #144106 and #144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat` or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769 # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402 - `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`. 
## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places. [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
179 lines
6.2 KiB
Python
179 lines
6.2 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.distributions import constraints
|
|
from torch.distributions.distribution import Distribution
|
|
from torch.distributions.utils import (
|
|
broadcast_all,
|
|
lazy_property,
|
|
logits_to_probs,
|
|
probs_to_logits,
|
|
)
|
|
|
|
|
|
# Public names exported by a wildcard import (``from ... import *``) of this module.
__all__ = ["Binomial"]
|
|
|
|
|
|
def _clamp_by_zero(x):
|
|
# works like clamp(x, min=0) but has grad at 0 is 0.5
|
|
return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
|
|
|
|
|
|
class Binomial(Distribution):
    r"""
    Creates a Binomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
    broadcastable with :attr:`probs`/:attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
        >>> m.sample()
        tensor([  0.,  22.,  71., 100.])

        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
        >>> m.sample()
        tensor([[ 4.,  5.],
                [ 7.,  6.]])

    Args:
        total_count (int or Tensor): number of Bernoulli trials
        probs (Tensor): Event probabilities
        logits (Tensor): Event log-odds
        validate_args (bool, optional): passed through to :class:`Distribution`
    """

    arg_constraints = {
        "total_count": constraints.nonnegative_integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    has_enumerate_support = True

    def __init__(
        self,
        total_count: Union[Tensor, int] = 1,
        probs: Optional[Tensor] = None,
        logits: Optional[Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        # Exactly one of ``probs`` / ``logits`` may be supplied.
        neither_given = probs is None and logits is None
        both_given = probs is not None and logits is not None
        if neither_given or both_given:
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )

        # Broadcast the trial count against whichever parameter was supplied and
        # store that parameter directly; the other one is derived on first
        # access via the ``lazy_property`` definitions further down.
        if probs is not None:
            count, param = broadcast_all(total_count, probs)
            self.probs = param
        else:
            assert logits is not None  # helps mypy
            count, param = broadcast_all(total_count, logits)
            self.logits = param
        # Give the count the same tensor type as the supplied parameter.
        self.total_count = count.type_as(param)
        self._param = param
        super().__init__(param.size(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Binomial, _instance)
        shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(shape)
        # Expand only the parameters already materialized on ``self``; one that
        # has not been computed yet stays lazy on ``new``. When both exist,
        # ``logits`` is visited last and therefore becomes ``new._param``.
        for name in ("probs", "logits"):
            if name in self.__dict__:
                expanded = getattr(self, name).expand(shape)
                setattr(new, name, expanded)
                new._param = expanded
        super(Binomial, new).__init__(shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Construct a tensor through ``Tensor.new`` on the stored parameter.
        return self._param.new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        # Integers from 0 up to this instance's ``total_count``.
        return constraints.integer_interval(0, self.total_count)

    @property
    def mean(self) -> Tensor:
        # n * p
        return self.total_count * self.probs

    @property
    def mode(self) -> Tensor:
        # floor((n + 1) * p), capped at n so the result never exceeds the
        # trial count (e.g. when p == 1).
        n = self.total_count
        return ((n + 1) * self.probs).floor().clamp(max=n)

    @property
    def variance(self) -> Tensor:
        # n * p * (1 - p)
        p = self.probs
        return self.total_count * p * (1 - p)

    @lazy_property
    def logits(self) -> Tensor:
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self) -> Tensor:
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self) -> torch.Size:
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # Samples are drawn without recording an autograd graph.
        with torch.no_grad():
            counts = self.total_count.expand(shape)
            success_probs = self.probs.expand(shape)
            return torch.binomial(counts, success_probs)

    def log_prob(self, value):
        """Log-probability of ``value`` successes.

        Evaluates ``log C(n, k) + k * logit - n * log(1 + exp(logit))`` using
        ``lgamma`` for the binomial coefficient and a numerically stable split
        of the softplus term.
        """
        if self._validate_args:
            self._validate_sample(value)
        n = self.total_count
        log_factorial_n = torch.lgamma(n + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(n - value + 1)
        logits = self.logits
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (
            n * _clamp_by_zero(logits)
            + n * torch.log1p(torch.exp(-torch.abs(logits)))
            - log_factorial_n
        )
        return value * logits - log_factorial_k - log_factorial_nmk - normalize_term

    def _require_homogeneous_total_count(self, error_message):
        """Return ``total_count`` as a Python int if it is identical across the batch.

        Raises:
            NotImplementedError: carrying ``error_message`` when the minimum and
                maximum of ``self.total_count`` differ.
        """
        total_count = int(self.total_count.max())
        if not self.total_count.min() == total_count:
            raise NotImplementedError(error_message)
        return total_count

    def entropy(self):
        self._require_homogeneous_total_count(
            "Inhomogeneous total count not supported by `entropy`."
        )
        # Sum -p(k) * log p(k) over every k in the shared support (dim 0).
        log_pmf = self.log_prob(self.enumerate_support(False))
        return -(torch.exp(log_pmf) * log_pmf).sum(0)

    def enumerate_support(self, expand=True):
        total_count = self._require_homogeneous_total_count(
            "Inhomogeneous total count not supported by `enumerate_support`."
        )
        # Values 0..total_count laid out as (total_count + 1, 1, ..., 1), one
        # singleton per batch dimension.
        batch_rank = len(self._batch_shape)
        values = torch.arange(
            1 + total_count, dtype=self._param.dtype, device=self._param.device
        ).view((-1,) + (1,) * batch_rank)
        return values.expand((-1,) + self._batch_shape) if expand else values
|