pytorch/torch/distributions/independent.py
Randolf Scholz 6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.support` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
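
To illustrate the note about making `Independent` generic over its base distribution, here is a minimal sketch using plain-Python stand-ins. `ToyDistribution`, `ToyNormal`, and `ToyIndependent` are hypothetical names introduced only for this example; they are not part of `torch.distributions` and carry none of its behavior.

```python
from typing import Generic, TypeVar


class ToyDistribution:
    """Hypothetical stand-in for a distribution base class."""


class ToyNormal(ToyDistribution):
    """Hypothetical stand-in with an attribute the base class lacks."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc = loc
        self.scale = scale


D = TypeVar("D", bound=ToyDistribution)


class ToyIndependent(ToyDistribution, Generic[D]):
    # The class-level annotation ties the attribute type to the type parameter.
    base_dist: D

    def __init__(self, base_distribution: D, reinterpreted_batch_ndims: int) -> None:
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims


wrapped = ToyIndependent(ToyNormal(0.0, 1.0), 0)
# A static type checker can infer ToyIndependent[ToyNormal] here, so access to
# `.loc` type-checks without a cast; with a non-generic wrapper, `base_dist`
# would only be known as ToyDistribution.
loc = wrapped.base_dist.loc
```

The benefit is purely static: runtime behavior is the same as for a non-generic wrapper.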

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.
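
The shape bookkeeping behind this remark can be sketched without `torch`. The helper below mirrors the arithmetic in `Independent.__init__` using plain tuples; `split_shapes` is a hypothetical name used only for illustration.

```python
def split_shapes(
    batch_shape: tuple[int, ...],
    event_shape: tuple[int, ...],
    reinterpreted_batch_ndims: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (batch_shape, event_shape) after reinterpreting batch dims."""
    if reinterpreted_batch_ndims > len(batch_shape):
        raise ValueError(
            "Expected reinterpreted_batch_ndims <= len(batch_shape), "
            f"actual {reinterpreted_batch_ndims} vs {len(batch_shape)}"
        )
    shape = batch_shape + event_shape
    event_dim = reinterpreted_batch_ndims + len(event_shape)
    # The rightmost `event_dim` dims become the event shape.
    return shape[: len(shape) - event_dim], shape[len(shape) - event_dim :]


# A Normal with batch_shape (1,) wrapped with one reinterpreted dim ...
wrapped_normal = split_shapes((1,), (), 1)
# ... ends up with the same shapes as a MultivariateNormal with event_shape (1,).
plain_mvn = split_shapes((), (1,), 0)
```

Both calls yield an empty batch shape and an event shape of `(1,)`, which is why the wrapped `Normal` and the unwrapped `MultivariateNormal` are interchangeable in shape even though only the former carries an extra `Independent` layer.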

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.
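
As a rough illustration of the `else`-to-`if` conversion described in [^2], consider the toy class below. `ToyCategorical` is a hypothetical, list-based stand-in and not the code from this PR; it only demonstrates the narrowing pattern.

```python
import math
from typing import Optional


class ToyCategorical:
    probs: list[float]

    def __init__(
        self,
        probs: Optional[list[float]] = None,
        logits: Optional[list[float]] = None,
    ) -> None:
        # A type checker does not use this check to narrow either argument.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            total = sum(probs)
            self.probs = [p / total for p in probs]
        # Written as `else:`, `logits` would still be Optional[list[float]]
        # for the type checker; the explicit check narrows it.
        if logits is not None:
            weights = [math.exp(x) for x in logits]
            total = sum(weights)
            self.probs = [w / total for w in weights]


from_probs = ToyCategorical(probs=[1.0, 3.0])
from_logits = ToyCategorical(logits=[0.0, 0.0])
```

At runtime the two forms behave identically, since the exclusivity check guarantees that exactly one branch runs.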

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00


# mypy: allow-untyped-defs
from typing import Generic, Optional, TypeVar

import torch
from torch import Size, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost
from torch.types import _size


__all__ = ["Independent"]


D = TypeVar("D", bound=Distribution)


class Independent(Distribution, Generic[D]):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """

    arg_constraints: dict[str, constraints.Constraint] = {}
    base_dist: D

    def __init__(
        self,
        base_distribution: D,
        reinterpreted_batch_ndims: int,
        validate_args: Optional[bool] = None,
    ) -> None:
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        shape: Size = base_distribution.batch_shape + base_distribution.event_shape
        event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[: len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim :]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self) -> bool:  # type: ignore[override]
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self) -> bool:  # type: ignore[override]
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        result = self.base_dist.support
        if self.reinterpreted_batch_ndims:
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
        return result

    @property
    def mean(self) -> Tensor:
        return self.base_dist.mean

    @property
    def mode(self) -> Tensor:
        return self.base_dist.mode

    @property
    def variance(self) -> Tensor:
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
        )