Fixes #144196

Extends #144106 and #144110

## Open Problems

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

## Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to the typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class over its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: added missing argument to docstring of `ReshapeTransform`.~~ #144401
- ~~`transforms.py`: fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraint_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution.__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *

b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>), so we have to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing (see the sketch after this description). Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
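
For illustration, a minimal, self-contained sketch of the narrowing issue from [^2] (the `init` function below is hypothetical, not code from this PR):

```python
from typing import Optional


def init(probs: Optional[float] = None, logits: Optional[float] = None) -> float:
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        return probs
    # mypy still types `logits` as `Optional[float]` here: it cannot infer the
    # mutual exclusivity established by the check above. An explicit `assert`
    # (or an `if logits is not None:` branch instead of a bare `else`) performs
    # the narrowing.
    assert logits is not None
    return logits
```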
# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number, Number


__all__ = ["Geometric"]


class Geometric(Distribution):
    r"""
    Creates a Geometric distribution parameterized by :attr:`probs`,
    where :attr:`probs` is the probability of success of Bernoulli trials.

    .. math::

        P(X=k) = (1-p)^{k} p, k = 0, 1, ...

    .. note::
        In :func:`torch.distributions.geometric.Geometric` the :math:`(k+1)`-th
        trial is the first success, hence it draws samples in
        :math:`\{0, 1, \ldots\}`, whereas in :func:`torch.Tensor.geometric_`
        the :math:`k`-th trial is the first success, hence it draws samples in
        :math:`\{1, 2, \ldots\}`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Geometric(torch.tensor([0.3]))
        >>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
        tensor([ 2.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1].
        logits (Number, Tensor): the log-odds of sampling `1`.
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.nonnegative_integer

    def __init__(
        self,
        probs: Optional[Union[Tensor, Number]] = None,
        logits: Optional[Union[Tensor, Number]] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            (self.probs,) = broadcast_all(probs)
        else:
            assert logits is not None  # helps mypy
            (self.logits,) = broadcast_all(logits)
        probs_or_logits = probs if probs is not None else logits
        if isinstance(probs_or_logits, _Number):
            batch_shape = torch.Size()
        else:
            assert probs_or_logits is not None  # helps mypy
            batch_shape = probs_or_logits.size()
        super().__init__(batch_shape, validate_args=validate_args)
        if self._validate_args and probs is not None:
            # Add an extra check beyond unit_interval
            value = self.probs
            valid = value > 0
            if not valid.all():
                invalid_value = value.data[~valid]
                raise ValueError(
                    "Expected parameter probs "
                    f"({type(value).__name__} of shape {tuple(value.shape)}) "
                    f"of distribution {repr(self)} "
                    f"to be positive but found invalid values:\n{invalid_value}"
                )
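
    # Usage sketch (illustrative comment, not upstream code):
    #   Geometric(0.3)                      -> batch_shape ()
    #   Geometric(torch.tensor([0.3, 0.5])) -> batch_shape (2,)
    #   Geometric(logits=torch.zeros(2))    -> probs = sigmoid(0) = 0.5 elementwise
    #   Geometric(0.3, logits=0.0)          -> ValueError (mutually exclusive)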

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Geometric, _instance)
        batch_shape = torch.Size(batch_shape)
        # Only the parametrization actually stored on this instance is
        # expanded; the other is recomputed lazily on first access.
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
        super(Geometric, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> Tensor:
        # E[X] = (1 - p) / p for the number-of-failures convention.
        return 1.0 / self.probs - 1.0

    @property
    def mode(self) -> Tensor:
        # The mode is always 0: P(X=k) is maximized at k = 0.
        return torch.zeros_like(self.probs)

    @property
    def variance(self) -> Tensor:
        # Var[X] = (1 - p) / p**2
        return (1.0 / self.probs - 1.0) / self.probs

    # Only one of `probs`/`logits` is stored at construction time; the other is
    # computed on first access and cached by `lazy_property` (hence the
    # `__dict__` checks in `expand` above).
    @lazy_property
    def logits(self) -> Tensor:
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self) -> Tensor:
        return logits_to_probs(self.logits, is_binary=True)

    def sample(self, sample_shape=torch.Size()):
        # Inverse-CDF sampling: if U ~ Uniform(0, 1], then
        # floor(log(U) / log(1 - p)) ~ Geometric(p). Clamping U away from 0
        # (at `tiny`) keeps log(U) finite.
        shape = self._extended_shape(sample_shape)
        tiny = torch.finfo(self.probs.dtype).tiny
        with torch.no_grad():
            if torch._C._get_tracing_state():
                # [JIT WORKAROUND] lack of support for .uniform_()
                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
                u = u.clamp(min=tiny)
            else:
                u = self.probs.new(shape).uniform_(tiny, 1)
            return (u.log() / (-self.probs).log1p()).floor()

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        value, probs = broadcast_all(value, self.probs)
        probs = probs.clone(memory_format=torch.contiguous_format)
        # When p == 1, log1p(-p) = -inf, and value * (-inf) would produce NaN
        # at value == 0; zeroing those entries yields the correct log_prob of 0.
        probs[(probs == 1) & (value == 0)] = 0
        return value * (-probs).log1p() + self.probs.log()

    def entropy(self):
        # H[X] = -(p * log(p) + (1 - p) * log(1 - p)) / p, i.e. the Bernoulli
        # entropy (computed via BCE with the distribution's own logits/probs)
        # divided by the success probability.
        return (
            binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
            / self.probs
        )
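
As a quick sanity check of the formulas above, a minimal usage sketch (output values rounded; sampling is non-deterministic):

```python
import torch
from torch.distributions import Geometric

m = Geometric(torch.tensor([0.3]))
x = m.sample((10_000,))  # shape (10000, 1)

print(m.mean)      # tensor([2.3333]) == (1 - 0.3) / 0.3
print(x.mean())    # empirical mean, close to 2.33
print(m.variance)  # tensor([7.7778]) == (1 - 0.3) / 0.3**2
print(m.log_prob(torch.tensor([0.0])))  # tensor([-1.2040]) == log(0.3)
```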