pytorch/torch/distributions/beta.py
Randolf Scholz 6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769
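
As a rough illustration of the `SupportsFloat` option above, here is a minimal sketch. It assumes a hypothetical helper `as_concentration`, which is not part of `torch.distributions`. It accepts either a `Tensor` or anything that `float()` understands:

```python
from fractions import Fraction
from typing import SupportsFloat, Union

import torch
from torch import Tensor


def as_concentration(value: Union[Tensor, SupportsFloat]) -> Tensor:
    """Hypothetical helper: pass tensors through, convert float-like scalars."""
    # Check Tensor first: Tensor defines __float__, so it also satisfies SupportsFloat.
    if isinstance(value, Tensor):
        return value
    # SupportsFloat is a runtime-checkable Protocol from `typing`. Unlike the
    # `numbers` ABCs, mypy accepts int, float, Fraction, etc. for it statically.
    return torch.tensor(float(value))


print(as_concentration(0.5))
print(as_concentration(Fraction(1, 2)))
```

The design trade-off is that `SupportsFloat` is broader than `float`. Any object with `__float__` type-checks, which matches what the `float(...)` call in such code actually requires.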

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted the `else` branches for mutually exclusive arguments into explicit `if` branches[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted the `else` branches for mutually exclusive arguments into explicit `if` branches[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** it turns out the bug is related to the typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted the `else` branches for mutually exclusive arguments into explicit `if` branches[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
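
The `else`-to-`if` conversions noted above follow roughly the pattern below. This is a minimal sketch using a hypothetical `ToyCategorical` class, not the real `Categorical` code. The comments mark where `mypy` can and cannot narrow `Optional[Tensor]` to `Tensor`:

```python
from typing import Optional

import torch
from torch import Tensor


class ToyCategorical:
    """Hypothetical stand-in for a probs/logits constructor (not torch's class)."""

    def __init__(
        self, probs: Optional[Tensor] = None, logits: Optional[Tensor] = None
    ) -> None:
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            # Inside this branch mypy narrows `probs` to Tensor.
            self.probs = probs / probs.sum(-1, keepdim=True)
            self.logits = torch.log(self.probs)
        if logits is not None:
            # With a plain `else:` here, `logits` would still be Optional[Tensor]:
            # mypy cannot infer "probs is None implies logits is not None".
            self.logits = logits - logits.logsumexp(-1, keepdim=True)
            self.probs = self.logits.exp()


print(ToyCategorical(probs=torch.tensor([1.0, 3.0])).probs)
```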

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.
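
The two comments in the snippet above can be turned into checkable assertions. This sketch repeats the same setup, using only public `torch.distributions` classes, and confirms that reaching the original `Normal` from `d1` takes two `base_dist` hops:

```python
import torch
from torch.distributions import (
    Independent,
    MultivariateNormal,
    Normal,
    StickBreakingTransform,
    TransformedDistribution,
)

b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)

# The univariate Normal gets wrapped: its single batch dim becomes an event dim.
assert isinstance(d1.base_dist, Independent)
assert d1.base_dist.reinterpreted_batch_ndims == 1
# Hence the `base_dist.base_dist` pattern needed to reach the original Normal.
assert isinstance(d1.base_dist.base_dist, Normal)
# The MultivariateNormal already has event_dim == 1, so it is used unchanged.
assert isinstance(d2.base_dist, MultivariateNormal)
```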

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). As a result, type-ignore comments are needed in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00


```python
# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size


__all__ = ["Beta"]


class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(
        self,
        concentration1: Union[Tensor, float],
        concentration0: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> Tensor:
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self) -> Tensor:
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self) -> Tensor:
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> Tensor:
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        return self._dirichlet.entropy()

    @property
    def concentration1(self) -> Tensor:
        result = self._dirichlet.concentration[..., 0]
        if isinstance(result, _Number):
            return torch.tensor([result])
        else:
            return result

    @property
    def concentration0(self) -> Tensor:
        result = self._dirichlet.concentration[..., 1]
        if isinstance(result, _Number):
            return torch.tensor([result])
        else:
            return result

    @property
    def _natural_params(self) -> tuple[Tensor, Tensor]:
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
```
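
The class above delegates nearly everything to a two-component `Dirichlet`. A quick sanity check is therefore to compare it against the textbook Beta formulas. This sketch uses the public `Beta` constructor, `batch_shape`, `mean`, `variance` and `log_prob`. It also calls the private `_log_normalizer`, which may change without notice. The values `a = 2`, `b = 3`, `x = 0.3` are arbitrary choices for illustration:

```python
import math

import torch
from torch.distributions import Beta

a, b = 2.0, 3.0

# Python scalars take the torch.tensor([...]) branch of __init__: batch_shape is ().
scalar = Beta(a, b)
# Tensors are broadcast and stacked along a new last dim: batch_shape is (2,).
batched = Beta(torch.tensor([a, 0.5]), torch.tensor(b))

# Closed-form moments of Beta(a, b).
expected_mean = a / (a + b)
expected_var = a * b / ((a + b) ** 2 * (a + b + 1))

# Closed-form log-density: (a - 1) log x + (b - 1) log(1 - x) - log B(a, b),
# where log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b), i.e. the same
# expression that Beta._log_normalizer computes.
x = 0.3
log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
expected_log_prob = (a - 1) * math.log(x) + (b - 1) * math.log(1 - x) - log_beta_fn

print(scalar.batch_shape, batched.batch_shape)
print(scalar.mean.item(), expected_mean)
print(scalar.log_prob(torch.tensor(x)).item(), expected_log_prob)
```

Agreement here only reflects float32 precision. Comparisons should use a tolerance rather than exact equality.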