pytorch/torch/distributions/laplace.py
Randolf Scholz 6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769
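
To make the trade-off above concrete, here is a minimal sketch, assuming only the standard library. The helper `as_float` is hypothetical and not part of `torch.distributions`. `numbers.Number` works for runtime `isinstance` checks, but built-in numeric types are only registered with it as virtual subclasses, which static checkers such as `mypy` do not follow (see the linked discussion). `typing.SupportsFloat` is a structural protocol that type checkers do understand.

```python
import numbers
from decimal import Decimal
from fractions import Fraction
from typing import SupportsFloat


def as_float(x: SupportsFloat) -> float:
    """Hypothetical helper: accept anything that defines ``__float__``."""
    return float(x)


# Runtime checks against the ABC succeed for the built-in numeric types...
runtime_ok = all(isinstance(v, numbers.Number) for v in (1, 1.5, Fraction(1, 2)))

# ...while the protocol is matched structurally, both at runtime and statically.
protocol_ok = all(
    isinstance(v, SupportsFloat) for v in (1, 1.5, Fraction(1, 2), Decimal("0.25"))
)

# Strings do not define ``__float__``, so they are rejected by the protocol.
string_rejected = not isinstance("1.5", SupportsFloat)
```

Whether `float`, `SupportsFloat`, or a custom `Protocol` is the right annotation for the distribution constructors is exactly the open question listed above; this sketch only shows how the candidates behave.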

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
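
The `else`-to-`if` conversion mentioned for `categorical.py`, `geometric.py`, and `multivariate_normal.py` can be sketched as follows. The class `ProbsOrLogits` and its attributes are hypothetical stand-ins, not the actual PyTorch code; the point is only that a second explicit `is not None` check lets a type checker narrow `Optional[float]` to `float`, which it cannot infer from the preceding exclusivity check.

```python
from typing import Optional


class ProbsOrLogits:
    """Hypothetical class taking exactly one of two mutually exclusive arguments."""

    def __init__(
        self, probs: Optional[float] = None, logits: Optional[float] = None
    ) -> None:
        # Exactly one of the two arguments must be given.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            self.value: float = probs
            self.is_logits = False
        # Previously an `else:` branch. A checker would still see
        # `Optional[float]` there, because it cannot derive
        # `logits is not None` from the equality test above.
        if logits is not None:
            self.value = logits
            self.is_logits = True
```

At runtime the two forms behave identically, since the first check guarantees that exactly one branch executes.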

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). As a result, we have to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00


# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size


__all__ = ["Laplace"]


class Laplace(Distribution):
    r"""
    Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # Laplace distributed with loc=0, scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    @property
    def mean(self) -> Tensor:
        return self.loc

    @property
    def mode(self) -> Tensor:
        return self.loc

    @property
    def variance(self) -> Tensor:
        return 2 * self.scale.pow(2)

    @property
    def stddev(self) -> Tensor:
        return (2**0.5) * self.scale

    def __init__(
        self,
        loc: Union[Tensor, float],
        scale: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, _Number) and isinstance(scale, _Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Laplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Laplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        shape = self._extended_shape(sample_shape)
        finfo = torch.finfo(self.loc.dtype)
        if torch._C._get_tracing_state():
            # [JIT WORKAROUND] lack of support for .uniform_()
            u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
            return self.loc - self.scale * u.sign() * torch.log1p(
                -u.abs().clamp(min=finfo.tiny)
            )
        u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
        # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
        # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
        return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(
            -(value - self.loc).abs() / self.scale
        )

    def icdf(self, value):
        term = value - 0.5
        return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs())

    def entropy(self):
        return 1 + torch.log(2 * self.scale)
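
The closed-form expressions in `log_prob`, `cdf`, `icdf`, and `rsample` can be sanity-checked on plain floats. The following is a minimal sketch that mirrors those formulas with the standard `math` module; the function names are hypothetical and this is not how the library evaluates them (it operates on tensors). One deliberate difference: `math.copysign(1.0, 0.0)` is `1.0`, whereas `Tensor.sign()` returns `0` at zero. The results agree anyway, because the `expm1` and `log1p` factors vanish at that point.

```python
import math


def laplace_log_prob(x: float, loc: float, scale: float) -> float:
    # Mirrors Laplace.log_prob: -log(2b) - |x - mu| / b
    return -math.log(2 * scale) - abs(x - loc) / scale


def laplace_cdf(x: float, loc: float, scale: float) -> float:
    # Mirrors Laplace.cdf: 0.5 - 0.5 * sign(z) * expm1(-|z| / b), with z = x - mu
    z = x - loc
    return 0.5 - 0.5 * math.copysign(1.0, z) * math.expm1(-abs(z) / scale)


def laplace_icdf(p: float, loc: float, scale: float) -> float:
    # Mirrors Laplace.icdf; p must lie strictly between 0 and 1.
    term = p - 0.5
    return loc - scale * math.copysign(1.0, term) * math.log1p(-2 * abs(term))


def laplace_from_uniform(u: float, loc: float, scale: float) -> float:
    # Mirrors the transform in Laplace.rsample for a draw u in the open
    # interval (-1, 1). Substituting u = 2p - 1 recovers laplace_icdf(p).
    return loc - scale * math.copysign(1.0, u) * math.log1p(-abs(u))


def laplace_entropy(scale: float) -> float:
    # Mirrors Laplace.entropy: 1 + log(2b)
    return 1 + math.log(2 * scale)
```

For example, `laplace_cdf(laplace_icdf(0.3, 1.0, 2.0), 1.0, 2.0)` should recover `0.3` up to floating-point error, and `laplace_from_uniform(-0.4, 1.0, 2.0)` should match `laplace_icdf(0.3, 1.0, 2.0)`, which is why the uniform draw in `rsample` yields Laplace-distributed samples.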