mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fixes #144196 Extends #144106 and #144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769 # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402 - `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places. [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
147 lines
5.5 KiB
Python
147 lines
5.5 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch import inf, Tensor
|
|
from torch.distributions import Categorical, constraints
|
|
from torch.distributions.binomial import Binomial
|
|
from torch.distributions.distribution import Distribution
|
|
from torch.distributions.utils import broadcast_all
|
|
|
|
|
|
# Public API of this module: only the distribution class is exported.
__all__ = [
    "Multinomial",
]
|
|
|
|
|
|
class Multinomial(Distribution):
    r"""
    Creates a Multinomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
    :attr:`probs` indexes over categories. All other dimensions index over batches.

    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
    called (see example below).

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    -   :meth:`sample` requires a single shared `total_count` for all
        parameters and samples.
    -   :meth:`log_prob` allows different `total_count` for each parameter and
        sample.

    Example::

        >>> # xdoctest: +SKIP("FIXME: found invalid values")
        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 21.,  24.,  30.,  25.])

        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
        tensor([-4.1338])

    Args:
        total_count (int): number of trials
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    # Number of trials; a single Python int shared by every batch element
    # (enforced in __init__).
    total_count: int

    @property
    def mean(self) -> Tensor:
        # Per-category expectation n * p_k.
        return self.probs * self.total_count

    @property
    def variance(self) -> Tensor:
        # Per-category variance n * p_k * (1 - p_k), i.e. the variance of each
        # category's Binomial(n, p_k) marginal.
        return self.total_count * self.probs * (1 - self.probs)

    def __init__(
        self,
        total_count: int = 1,
        probs: Optional[Tensor] = None,
        logits: Optional[Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        """
        Args:
            total_count: number of trials; must be a Python ``int``.
            probs: event probabilities (normalized along the last dimension).
            logits: unnormalized event log probabilities.
            validate_args: forwarded to :class:`Distribution`.

        Raises:
            NotImplementedError: if ``total_count`` is not an ``int``
                (e.g. a tensor of per-batch counts).
        """
        if not isinstance(total_count, int):
            raise NotImplementedError("inhomogeneous total_count is not supported")
        self.total_count = total_count
        # The categorical owns the parameters; probs/logits/param_shape delegate to it.
        self._categorical = Categorical(probs=probs, logits=logits)
        # Per-category Binomial(total_count, p_k) marginals, used by entropy().
        self._binomial = Binomial(total_count=total_count, probs=self.probs)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution expanded to ``batch_shape``."""
        new = self._get_checked_instance(Multinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count
        new._categorical = self._categorical.expand(batch_shape)
        # The binomial marginals are built from ``probs`` (shape
        # ``batch_shape + event_shape``), so they must be expanded to match.
        # Without this, ``entropy()`` on the expanded instance raises
        # AttributeError for the missing ``_binomial``.
        new._binomial = self._binomial.expand(batch_shape + self.event_shape)
        super(Multinomial, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Allocate a new tensor like the underlying categorical's parameter.
        return self._categorical._new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        return constraints.multinomial(self.total_count)

    @property
    def logits(self) -> Tensor:
        """Normalized logits, delegated to the underlying categorical."""
        return self._categorical.logits

    @property
    def probs(self) -> Tensor:
        """Normalized probabilities, delegated to the underlying categorical."""
        return self._categorical.probs

    @property
    def param_shape(self) -> torch.Size:
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw count vectors by histogramming ``total_count`` categorical draws."""
        sample_shape = torch.Size(sample_shape)
        samples = self._categorical.sample(
            torch.Size((self.total_count,)) + sample_shape
        )
        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
        # (sample_shape, batch_shape, total_count)
        shifted_idx = list(range(samples.dim()))
        shifted_idx.append(shifted_idx.pop(0))
        samples = samples.permute(*shifted_idx)
        # Accumulate a 1 into the category slot of each drawn index.
        counts = samples.new_zeros(self._extended_shape(sample_shape))
        counts.scatter_add_(-1, samples, torch.ones_like(samples))
        return counts.type_as(self.probs)

    def entropy(self):
        """
        Entropy computed as
        ``n * H(Categorical) - lgamma(n + 1)
        + sum_k sum_{x=1..n} Binomial(x; n, p_k) * lgamma(x + 1)``.
        """
        # Build n with the parameters' dtype/device so lgamma(n + 1) is not
        # evaluated in the default (float32) dtype for float64 parameters.
        n = self.probs.new_tensor(self.total_count)

        cat_entropy = self._categorical.entropy()
        term1 = n * cat_entropy - torch.lgamma(n + 1)

        # x = 0 contributes lgamma(1) = 0, so it is dropped.
        support = self._binomial.enumerate_support(expand=False)[1:]
        binomial_probs = torch.exp(self._binomial.log_prob(support))
        weights = torch.lgamma(support + 1)
        # Sum over support values (dim 0) and categories (last dim).
        term2 = (binomial_probs * weights).sum([0, -1])

        return term1 + term2

    def log_prob(self, value):
        """
        Log-probability of count vectors ``value``. The number of trials is
        taken from ``value.sum(-1)``, so it may differ per sample.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # Clone before the in-place masked assignment below so self.logits is
        # never modified.
        logits = logits.clone(memory_format=torch.contiguous_format)
        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
        # Avoid 0 * -inf = nan for zero-probability categories that were not observed.
        logits[(value == 0) & (logits == -inf)] = 0
        log_powers = (logits * value).sum(-1)
        return log_factorial_n - log_factorial_xs + log_powers