Files
pytorch/torch/distributions/utils.py
Randolf Scholz 6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat` or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00

219 lines
7.6 KiB
Python

# mypy: allow-untyped-defs
from functools import update_wrapper
from typing import Any, Callable, Generic, overload, Union
from typing_extensions import TypeVar
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.overrides import is_tensor_like
from torch.types import _Number, Number
# Euler-Mascheroni constant (gamma, approximately 0.5772), kept as a module-level float.
euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant

# Names exported by ``from <this module> import *``. The underscore-prefixed
# helpers and ``euler_constant`` defined in this module are not listed here.
__all__ = [
    "broadcast_all",
    "logits_to_probs",
    "clamp_probs",
    "probs_to_logits",
    "lazy_property",
    "tril_matrix_to_vec",
    "vec_to_tril_matrix",
]
# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added.
# See https://github.com/python/typing/issues/1216#issuecomment-2126153831
def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
r"""
Given a list of values (possibly containing numbers), returns a list where each
value is broadcasted based on the following rules:
- `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
- Number instances (scalars) are upcast to tensors having
the same size and type as the first tensor passed to `values`. If all the
values are scalars, then they are upcasted to scalar Tensors.
Args:
values (list of `Number`, `torch.*Tensor` or objects implementing __torch_function__)
Raises:
ValueError: if any of the values is not a `Number` instance,
a `torch.*Tensor` instance, or an instance implementing __torch_function__
"""
if not all(is_tensor_like(v) or isinstance(v, _Number) for v in values):
raise ValueError(
"Input arguments must all be instances of Number, "
"torch.Tensor or objects implementing __torch_function__."
)
if not all(is_tensor_like(v) for v in values):
options: dict[str, Any] = dict(dtype=torch.get_default_dtype())
for value in values:
if isinstance(value, torch.Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
new_values = [
v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
]
return torch.broadcast_tensors(*new_values)
return torch.broadcast_tensors(*values)
def _standard_normal(shape, dtype, device):
if torch._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return torch.normal(
torch.zeros(shape, dtype=dtype, device=device),
torch.ones(shape, dtype=dtype, device=device),
)
return torch.empty(shape, dtype=dtype, device=device).normal_()
def _sum_rightmost(value, dim):
r"""
Sum out ``dim`` many rightmost dimensions of a given tensor.
Args:
value (Tensor): A tensor of ``.dim()`` at least ``dim``.
dim (int): The number of rightmost dims to sum out.
"""
if dim == 0:
return value
required_shape = value.shape[:-dim] + (-1,)
return value.reshape(required_shape).sum(-1)
def logits_to_probs(logits, is_binary=False):
    r"""
    Map logits to probabilities.

    With ``is_binary=True`` each entry is treated as a log-odds value and is
    passed through a sigmoid. Otherwise the entries along the last dimension
    are treated as (possibly unnormalized) log probabilities of the events and
    a softmax is taken over that dimension.
    """
    return torch.sigmoid(logits) if is_binary else F.softmax(logits, dim=-1)
def clamp_probs(probs):
    """Clamp probabilities into the closed range ``[eps, 1 - eps]``.

    ``eps`` is ``torch.finfo(probs.dtype).eps``, i.e. the machine epsilon of
    the input's floating-point dtype, so the result stays strictly inside the
    open interval `(0, 1)`.

    Args:
        probs (Tensor): A tensor of probabilities.

    Returns:
        Tensor: The clamped probabilities.

    Examples:
        >>> probs = torch.tensor([0.0, 0.5, 1.0])
        >>> clamp_probs(probs)
        tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])

        >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
        >>> clamp_probs(probs)
        tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
    """
    margin = torch.finfo(probs.dtype).eps
    lower, upper = margin, 1 - margin
    return probs.clamp(min=lower, max=upper)
def probs_to_logits(probs, is_binary=False):
    r"""
    Map probabilities to logits.

    The input is first passed through ``clamp_probs``. With ``is_binary=True``
    each entry is the probability of the event indexed by `1` and the result
    is the log-odds ``log(p) - log1p(-p)``. Otherwise the entries along the
    last dimension are the probabilities of each event and the result is
    simply ``log(p)``.
    """
    safe_probs = clamp_probs(probs)
    log_p = torch.log(safe_probs)
    if not is_binary:
        return log_p
    return log_p - torch.log1p(-safe_probs)
T = TypeVar("T", contravariant=True)
R = TypeVar("R", covariant=True)


class lazy_property(Generic[T, R]):
    r"""
    Decorator that turns a method into a lazily computed attribute.

    The object is a non-data descriptor: the first attribute access on an
    instance calls the wrapped method (with gradients enabled) and stores the
    result on the instance under the method's name, so subsequent accesses
    find the stored value in the instance instead of reaching this descriptor.
    """

    def __init__(self, wrapped: Callable[[T], R]) -> None:
        self.wrapped: Callable[[T], R] = wrapped
        # Copy the wrapped function's metadata onto the descriptor.
        update_wrapper(self, wrapped)  # type:ignore[arg-type]

    @overload
    def __get__(
        self, instance: None, obj_type: Any = None
    ) -> "_lazy_property_and_property[T, R]": ...

    @overload
    def __get__(self, instance: T, obj_type: Any = None) -> R: ...

    def __get__(
        self, instance: Union[T, None], obj_type: Any = None
    ) -> "R | _lazy_property_and_property[T, R]":
        if instance is not None:
            with torch.enable_grad():
                computed = self.wrapped(instance)
            # Cache on the instance under the wrapped method's name.
            setattr(instance, self.wrapped.__name__, computed)
            return computed
        # Accessed on the class itself rather than on an instance.
        return _lazy_property_and_property(self.wrapped)
class _lazy_property_and_property(lazy_property[T, R], property):
    """Hybrid of ``lazy_property`` and the built-in ``property``.

    We want lazy properties to look like multiple things:

    * property when Sphinx autodoc looks
    * lazy_property when Distribution validate_args looks
    """

    def __init__(self, wrapped: Callable[[T], R]) -> None:
        # Only the built-in property initialiser runs; wrapped becomes its getter.
        property.__init__(self, fget=wrapped)
def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor:
    r"""
    Flatten the lower-triangular part of a `D x D` matrix (or a batch of such
    matrices) into a (batched) vector, taking elements in row order.

    ``diag`` shifts the boundary diagonal: an element at ``(row, col)`` is
    kept when ``col <= row + diag``.

    Raises:
        ValueError: if ``diag`` is outside ``[-D, D - 1]`` (this check is
            skipped while tracing).
    """
    size = mat.shape[-1]
    if not torch._C._get_tracing_state() and not (-size <= diag < size):
        raise ValueError(f"diag ({diag}) provided is outside [{-size}, {size - 1}].")
    cols = torch.arange(size, device=mat.device)
    rows = cols.view(-1, 1)
    # Boolean (D, D) mask selecting the entries on or below the chosen diagonal.
    keep = cols < rows + (diag + 1)
    return mat[..., keep]
def vec_to_tril_matrix(vec: Tensor, diag: int = 0) -> Tensor:
    r"""
    Convert a vector or a batch of vectors into a batched `D x D`
    lower triangular matrix containing elements from the vector in row order.

    Args:
        vec (Tensor): Tensor whose last dimension holds the triangular
            elements. ``torch.finfo`` is applied to its dtype.
        diag (int): Offset of the boundary diagonal; an element at
            ``(row, col)`` is filled when ``col <= row + diag``.

    Returns:
        Tensor: A tensor of shape ``vec.shape[:-1] + (D, D)`` that is zero
        outside the selected triangular region.

    Raises:
        ValueError: if the size of the last dimension does not correspond to
            an integer ``D`` (this check is skipped while tracing).
    """
    # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
    n = (
        -(1 + 2 * diag)
        + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
    ) / 2
    eps = torch.finfo(vec.dtype).eps
    # Compare the absolute distance to the nearest integer: a one-sided
    # ``round(n) - n > eps`` test misses roots that round *down* (e.g. a
    # length-4 vector gives n ~= 2.37), letting invalid sizes through.
    if not torch._C._get_tracing_state() and abs(round(n) - n) > eps:
        raise ValueError(
            f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
            + "the lower triangular part of a square D x D matrix."
        )
    n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
    mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
    arange = torch.arange(n, device=vec.device)
    # Boolean (D, D) mask of the positions that receive the vector's elements.
    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
    mat[..., tril_mask] = vec
    return mat