[typing] Add type hints to __init__ methods in torch.distributions. (#144197)

Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
Randolf Scholz
2025-04-06 17:50:35 +00:00
committed by PyTorch MergeBot
parent 49f6cce736
commit 6c38b9be73
42 changed files with 378 additions and 80 deletions

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import nan, Tensor
from torch.distributions import constraints
@ -10,7 +12,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number
from torch.types import _Number, Number
__all__ = ["Bernoulli"]
@ -41,7 +43,12 @@ class Bernoulli(ExponentialFamily):
has_enumerate_support = True
_mean_carrier_measure = 0
def __init__(self, probs=None, logits=None, validate_args=None):
def __init__(
self,
probs: Optional[Union[Tensor, Number]] = None,
logits: Optional[Union[Tensor, Number]] = None,
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
@ -50,6 +57,7 @@ class Bernoulli(ExponentialFamily):
is_scalar = isinstance(probs, _Number)
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -36,7 +38,12 @@ class Beta(ExponentialFamily):
support = constraints.unit_interval
has_rsample = True
def __init__(self, concentration1, concentration0, validate_args=None):
def __init__(
self,
concentration1: Union[Tensor, float],
concentration0: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
concentration1_concentration0 = torch.tensor(
[float(concentration1), float(concentration0)]

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -50,7 +52,13 @@ class Binomial(Distribution):
}
has_enumerate_support = True
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
def __init__(
self,
total_count: Union[Tensor, int] = 1,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
@ -62,6 +70,7 @@ class Binomial(Distribution):
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
else:
assert logits is not None # helps mypy
(
self.total_count,
self.logits,

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import nan, Tensor
from torch.distributions import constraints
@ -51,7 +53,12 @@ class Categorical(Distribution):
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True
def __init__(self, probs=None, logits=None, validate_args=None):
def __init__(
self,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
@ -61,6 +68,7 @@ class Categorical(Distribution):
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs / probs.sum(-1, keepdim=True)
else:
assert logits is not None # helps mypy
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import inf, nan, Tensor
@ -34,7 +35,12 @@ class Cauchy(Distribution):
support = constraints.real
has_rsample = True
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
batch_shape = torch.Size()

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.gamma import Gamma
@ -25,7 +27,11 @@ class Chi2(Gamma):
arg_constraints = {"df": constraints.positive}
def __init__(self, df, validate_args=None):
def __init__(
self,
df: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
super().__init__(0.5 * df, 0.5, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import Tensor
@ -13,7 +14,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number, _size
from torch.types import _Number, _size, Number
__all__ = ["ContinuousBernoulli"]
@ -52,7 +53,11 @@ class ContinuousBernoulli(ExponentialFamily):
has_rsample = True
def __init__(
self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
self,
probs: Optional[Union[Tensor, Number]] = None,
logits: Optional[Union[Tensor, Number]] = None,
lims: tuple[float, float] = (0.499, 0.501),
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
@ -68,6 +73,7 @@ class ContinuousBernoulli(ExponentialFamily):
raise ValueError("The parameter probs has invalid values")
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import Tensor
from torch.autograd import Function
@ -54,7 +56,11 @@ class Dirichlet(ExponentialFamily):
support = constraints.simplex
has_rsample = True
def __init__(self, concentration, validate_args=None):
def __init__(
self,
concentration: Tensor,
validate_args: Optional[bool] = None,
) -> None:
if concentration.dim() < 1:
raise ValueError(
"`concentration` parameter must be at least one-dimensional."

View File

@ -44,7 +44,7 @@ class Distribution:
batch_shape: torch.Size = torch.Size(),
event_shape: torch.Size = torch.Size(),
validate_args: Optional[bool] = None,
):
) -> None:
self._batch_shape = batch_shape
self._event_shape = event_shape
if validate_args is not None:

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -46,7 +48,11 @@ class Exponential(ExponentialFamily):
def variance(self) -> Tensor:
return self.rate.pow(-2)
def __init__(self, rate, validate_args=None):
def __init__(
self,
rate: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
(self.rate,) = broadcast_all(rate)
batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
super().__init__(batch_shape, validate_args=validate_args)

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import nan, Tensor
from torch.distributions import constraints
@ -31,7 +33,12 @@ class FisherSnedecor(Distribution):
support = constraints.positive
has_rsample = True
def __init__(self, df1, df2, validate_args=None):
def __init__(
self,
df1: Union[Tensor, float],
df2: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.df1, self.df2 = broadcast_all(df1, df2)
self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -52,7 +54,12 @@ class Gamma(ExponentialFamily):
def variance(self) -> Tensor:
return self.concentration / self.rate.pow(2)
def __init__(self, concentration, rate, validate_args=None):
def __init__(
self,
concentration: Union[Tensor, float],
rate: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, _Number) and isinstance(rate, _Number):
batch_shape = torch.Size()

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -10,7 +12,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number
from torch.types import _Number, Number
__all__ = ["Geometric"]
@ -45,7 +47,12 @@ class Geometric(Distribution):
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer
def __init__(self, probs=None, logits=None, validate_args=None):
def __init__(
self,
probs: Optional[Union[Tensor, Number]] = None,
logits: Optional[Union[Tensor, Number]] = None,
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
@ -53,11 +60,13 @@ class Geometric(Distribution):
if probs is not None:
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):
batch_shape = torch.Size()
else:
assert probs_or_logits is not None # helps mypy
batch_shape = probs_or_logits.size()
super().__init__(batch_shape, validate_args=validate_args)
if self._validate_args and probs is not None:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import Tensor
@ -33,7 +34,12 @@ class Gumbel(TransformedDistribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
finfo = torch.finfo(self.loc.dtype)
if isinstance(loc, _Number) and isinstance(scale, _Number):

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import inf, Tensor
@ -33,8 +34,13 @@ class HalfCauchy(TransformedDistribution):
arg_constraints = {"scale": constraints.positive}
support = constraints.nonnegative
has_rsample = True
base_dist: Cauchy
def __init__(self, scale, validate_args=None):
def __init__(
self,
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
base_dist = Cauchy(0, scale, validate_args=False)
super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import inf, Tensor
@ -33,8 +34,13 @@ class HalfNormal(TransformedDistribution):
arg_constraints = {"scale": constraints.positive}
support = constraints.nonnegative
has_rsample = True
base_dist: Normal
def __init__(self, scale, validate_args=None):
def __init__(
self,
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
base_dist = Normal(0, scale, validate_args=False)
super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
from typing import Generic, Optional, TypeVar
import torch
from torch import Tensor
from torch import Size, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost
@ -11,7 +12,10 @@ from torch.types import _size
__all__ = ["Independent"]
class Independent(Distribution):
D = TypeVar("D", bound=Distribution)
class Independent(Distribution, Generic[D]):
r"""
Reinterprets some of the batch dims of a distribution as event dims.
@ -42,17 +46,21 @@ class Independent(Distribution):
"""
arg_constraints: dict[str, constraints.Constraint] = {}
base_dist: D
def __init__(
self, base_distribution, reinterpreted_batch_ndims, validate_args=None
):
self,
base_distribution: D,
reinterpreted_batch_ndims: int,
validate_args: Optional[bool] = None,
) -> None:
if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
raise ValueError(
"Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
)
shape = base_distribution.batch_shape + base_distribution.event_shape
event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
shape: Size = base_distribution.batch_shape + base_distribution.event_shape
event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape)
batch_shape = shape[: len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim :]
self.base_dist = base_distribution

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -38,8 +40,14 @@ class InverseGamma(TransformedDistribution):
}
support = constraints.positive
has_rsample = True
base_dist: Gamma
def __init__(self, concentration, rate, validate_args=None):
def __init__(
self,
concentration: Union[Tensor, float],
rate: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
base_dist = Gamma(concentration, rate, validate_args=validate_args)
neg_one = -base_dist.rate.new_ones(())
super().__init__(

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import nan, Tensor
from torch.distributions import constraints
@ -45,7 +47,12 @@ class Kumaraswamy(TransformedDistribution):
support = constraints.unit_interval
has_rsample = True
def __init__(self, concentration1, concentration0, validate_args=None):
def __init__(
self,
concentration1: Union[Tensor, float],
concentration0: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.concentration1, self.concentration0 = broadcast_all(
concentration1, concentration0
)

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -46,7 +48,12 @@ class Laplace(Distribution):
def stddev(self) -> Tensor:
return (2**0.5) * self.scale
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
batch_shape = torch.Size()

View File

@ -9,8 +9,10 @@ Original copyright notice:
"""
import math
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import Beta, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@ -61,7 +63,12 @@ class LKJCholesky(Distribution):
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky
def __init__(self, dim, concentration=1.0, validate_args=None):
def __init__(
self,
dim: int,
concentration: Union[Tensor, float] = 1.0,
validate_args: Optional[bool] = None,
) -> None:
if dim < 2:
raise ValueError(
f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.normal import Normal
@ -32,8 +34,14 @@ class LogNormal(TransformedDistribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.positive
has_rsample = True
base_dist: Normal
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
base_dist = Normal(loc, scale, validate_args=validate_args)
super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
from torch import Tensor
from torch.distributions import constraints
from torch.distributions import constraints, Independent
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import StickBreakingTransform
@ -36,8 +38,14 @@ class LogisticNormal(TransformedDistribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.simplex
has_rsample = True
base_dist: Independent[Normal]
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
base_dist = Normal(loc, scale, validate_args=validate_args)
if not base_dist.batch_shape:
base_dist = base_dist.expand([1])

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional
import torch
from torch import Tensor
@ -93,7 +94,13 @@ class LowRankMultivariateNormal(Distribution):
support = constraints.real_vector
has_rsample = True
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
def __init__(
self,
loc: Tensor,
cov_factor: Tensor,
cov_diag: Tensor,
validate_args: Optional[bool] = None,
) -> None:
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")
event_shape = loc.shape[-1:]

View File

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import Tensor
@ -59,7 +60,7 @@ class MixtureSameFamily(Distribution):
self,
mixture_distribution: Categorical,
component_distribution: Distribution,
validate_args=None,
validate_args: Optional[bool] = None,
) -> None:
self._mixture_distribution = mixture_distribution
self._component_distribution = component_distribution

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import inf, Tensor
from torch.distributions import Categorical, constraints
@ -59,7 +61,13 @@ class Multinomial(Distribution):
def variance(self) -> Tensor:
return self.total_count * self.probs * (1 - self.probs)
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
def __init__(
self,
total_count: int = 1,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
if not isinstance(total_count, int):
raise NotImplementedError("inhomogeneous total_count is not supported")
self.total_count = total_count

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional
import torch
from torch import Tensor
@ -133,12 +134,12 @@ class MultivariateNormal(Distribution):
def __init__(
self,
loc,
covariance_matrix=None,
precision_matrix=None,
scale_tril=None,
validate_args=None,
):
loc: Tensor,
covariance_matrix: Optional[Tensor] = None,
precision_matrix: Optional[Tensor] = None,
scale_tril: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")
if (covariance_matrix is not None) + (scale_tril is not None) + (
@ -167,6 +168,7 @@ class MultivariateNormal(Distribution):
)
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy
if precision_matrix.dim() < 2:
raise ValueError(
"precision_matrix must be at least two-dimensional, "

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor
@ -38,7 +40,13 @@ class NegativeBinomial(Distribution):
}
support = constraints.nonnegative_integer
def __init__(self, total_count, probs=None, logits=None, validate_args=None):
def __init__(
self,
total_count: Union[Tensor, float],
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
@ -50,6 +58,7 @@ class NegativeBinomial(Distribution):
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
else:
assert logits is not None # helps mypy
(
self.total_count,
self.logits,

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import Tensor
@ -51,7 +52,12 @@ class Normal(ExponentialFamily):
def variance(self) -> Tensor:
return self.stddev.pow(2)
def __init__(self, loc, scale, validate_args=None):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
batch_shape = torch.Size()

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import Tensor
from torch.distributions import constraints
@ -44,7 +46,12 @@ class OneHotCategorical(Distribution):
support = constraints.one_hot
has_enumerate_support = True
def __init__(self, probs=None, logits=None, validate_args=None):
def __init__(
self,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
self._categorical = Categorical(probs, logits)
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union
from torch import Tensor
from torch.distributions import constraints
@ -31,7 +31,10 @@ class Pareto(TransformedDistribution):
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
def __init__(
self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
self,
scale: Union[Tensor, float],
alpha: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)

View File

@ -1,10 +1,12 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number
from torch.types import _Number, Number
__all__ = ["Poisson"]
@ -45,7 +47,11 @@ class Poisson(ExponentialFamily):
def variance(self) -> Tensor:
return self.rate
def __init__(self, rate, validate_args=None):
def __init__(
self,
rate: Union[Tensor, Number],
validate_args: Optional[bool] = None,
) -> None:
(self.rate,) = broadcast_all(rate)
if isinstance(rate, _Number):
batch_shape = torch.Size()

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -12,7 +14,7 @@ from torch.distributions.utils import (
logits_to_probs,
probs_to_logits,
)
from torch.types import _Number, _size
from torch.types import _Number, _size, Number
__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
@ -41,7 +43,13 @@ class LogitRelaxedBernoulli(Distribution):
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
def __init__(
self,
temperature: Tensor,
probs: Optional[Union[Tensor, Number]] = None,
logits: Optional[Union[Tensor, Number]] = None,
validate_args: Optional[bool] = None,
) -> None:
self.temperature = temperature
if (probs is None) == (logits is None):
raise ValueError(
@ -51,6 +59,7 @@ class LogitRelaxedBernoulli(Distribution):
is_scalar = isinstance(probs, _Number)
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
@ -131,8 +140,15 @@ class RelaxedBernoulli(TransformedDistribution):
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
has_rsample = True
base_dist: LogitRelaxedBernoulli
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
def __init__(
self,
temperature: Tensor,
probs: Optional[Union[Tensor, Number]] = None,
logits: Optional[Union[Tensor, Number]] = None,
validate_args: Optional[bool] = None,
) -> None:
base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args)

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import Tensor
from torch.distributions import constraints
@ -42,7 +44,13 @@ class ExpRelaxedCategorical(Distribution):
) # The true support is actually a submanifold of this.
has_rsample = True
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
def __init__(
self,
temperature: Tensor,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
self._categorical = Categorical(probs, logits)
self.temperature = temperature
batch_shape = self._categorical.batch_shape
@ -121,8 +129,15 @@ class RelaxedOneHotCategorical(TransformedDistribution):
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.simplex
has_rsample = True
base_dist: ExpRelaxedCategorical
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
def __init__(
self,
temperature: Tensor,
probs: Optional[Tensor] = None,
logits: Optional[Tensor] = None,
validate_args: Optional[bool] = None,
) -> None:
base_dist = ExpRelaxedCategorical(
temperature, probs, logits, validate_args=validate_args
)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
from torch import inf, nan, Tensor
@ -60,7 +61,13 @@ class StudentT(Distribution):
m[self.df <= 1] = nan
return m
def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
def __init__(
self,
df: Union[Tensor, float],
loc: Union[Tensor, float] = 0.0,
scale: Union[Tensor, float] = 1.0,
validate_args: Optional[bool] = None,
) -> None:
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
self._chi2 = Chi2(self.df)
batch_shape = self.df.size()

View File

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
@ -49,7 +50,12 @@ class TransformedDistribution(Distribution):
arg_constraints: dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, transforms, validate_args=None):
def __init__(
self,
base_distribution: Distribution,
transforms: Union[Transform, list[Transform]],
validate_args: Optional[bool] = None,
) -> None:
if isinstance(transforms, Transform):
self.transforms = [
transforms,

View File

@ -3,11 +3,14 @@ import functools
import math
import operator
import weakref
from typing import Optional
from collections.abc import Sequence
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
_sum_rightmost,
broadcast_all,
@ -92,7 +95,7 @@ class Transform:
domain: constraints.Constraint
codomain: constraints.Constraint
def __init__(self, cache_size=0):
def __init__(self, cache_size: int = 0) -> None:
self._cache_size = cache_size
self._inv: Optional[weakref.ReferenceType[Transform]] = None
if cache_size == 0:
@ -218,7 +221,7 @@ class _InverseTransform(Transform):
This class is private; please instead use the ``Transform.inv`` property.
"""
def __init__(self, transform: Transform):
def __init__(self, transform: Transform) -> None:
super().__init__(cache_size=transform._cache_size)
self._inv: Transform = transform # type: ignore[assignment]
@ -285,7 +288,7 @@ class ComposeTransform(Transform):
the latest single value is cached. Only 0 and 1 are supported.
"""
def __init__(self, parts: list[Transform], cache_size=0):
def __init__(self, parts: list[Transform], cache_size: int = 0) -> None:
if cache_size:
parts = [part.with_cache(cache_size) for part in parts]
super().__init__(cache_size=cache_size)
@ -413,7 +416,12 @@ class IndependentTransform(Transform):
dimensions to treat as dependent.
"""
def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
def __init__(
self,
base_transform: Transform,
reinterpreted_batch_ndims: int,
cache_size: int = 0,
) -> None:
super().__init__(cache_size=cache_size)
self.base_transform = base_transform.with_cache(cache_size)
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
@ -442,7 +450,7 @@ class IndependentTransform(Transform):
return self.base_transform.bijective
@property
def sign(self) -> int: # type: ignore[override]
def sign(self) -> int:
return self.base_transform.sign
def _call(self, x):
@ -486,7 +494,12 @@ class ReshapeTransform(Transform):
bijective = True
def __init__(self, in_shape, out_shape, cache_size=0):
def __init__(
self,
in_shape: torch.Size,
out_shape: torch.Size,
cache_size: int = 0,
) -> None:
self.in_shape = torch.Size(in_shape)
self.out_shape = torch.Size(out_shape)
if self.in_shape.numel() != self.out_shape.numel():
@ -571,7 +584,7 @@ class PowerTransform(Transform):
codomain = constraints.positive
bijective = True
def __init__(self, exponent, cache_size=0):
def __init__(self, exponent: Tensor, cache_size: int = 0) -> None:
super().__init__(cache_size=cache_size)
(self.exponent,) = broadcast_all(exponent)
@ -582,7 +595,7 @@ class PowerTransform(Transform):
@lazy_property
def sign(self) -> int: # type: ignore[override]
return self.exponent.sign()
return self.exponent.sign() # type: ignore[return-value]
def __eq__(self, other):
if not isinstance(other, PowerTransform):
@ -734,7 +747,13 @@ class AffineTransform(Transform):
bijective = True
def __init__(self, loc, scale, event_dim=0, cache_size=0):
def __init__(
self,
loc: Union[Tensor, float],
scale: Union[Tensor, float],
event_dim: int = 0,
cache_size: int = 0,
) -> None:
super().__init__(cache_size=cache_size)
self.loc = loc
self.scale = scale
@ -771,20 +790,20 @@ class AffineTransform(Transform):
if self.loc != other.loc:
return False
else:
if not (self.loc == other.loc).all().item():
if not (self.loc == other.loc).all().item(): # type: ignore[union-attr]
return False
if isinstance(self.scale, _Number) and isinstance(other.scale, _Number):
if self.scale != other.scale:
return False
else:
if not (self.scale == other.scale).all().item():
if not (self.scale == other.scale).all().item(): # type: ignore[union-attr]
return False
return True
@property
def sign(self) -> int:
def sign(self) -> Union[Tensor, int]: # type: ignore[override]
if isinstance(self.scale, _Number):
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
return self.scale.sign()
@ -1022,7 +1041,7 @@ class PositiveDefiniteTransform(Transform):
"""
domain = constraints.independent(constraints.real, 2)
codomain = constraints.positive_definite # type: ignore[assignment]
codomain = constraints.positive_definite
def __eq__(self, other):
return isinstance(other, PositiveDefiniteTransform)
@ -1053,7 +1072,13 @@ class CatTransform(Transform):
transforms: list[Transform]
def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
def __init__(
self,
tseq: Sequence[Transform],
dim: int = 0,
lengths: Optional[Sequence[int]] = None,
cache_size: int = 0,
) -> None:
assert all(isinstance(t, Transform) for t in tseq)
if cache_size:
tseq = [t.with_cache(cache_size) for t in tseq]
@ -1157,7 +1182,9 @@ class StackTransform(Transform):
transforms: list[Transform]
def __init__(self, tseq, dim=0, cache_size=0):
def __init__(
self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0
) -> None:
assert all(isinstance(t, Transform) for t in tseq)
if cache_size:
tseq = [t.with_cache(cache_size) for t in tseq]
@ -1237,12 +1264,12 @@ class CumulativeDistributionTransform(Transform):
codomain = constraints.unit_interval
sign = +1
def __init__(self, distribution, cache_size=0):
def __init__(self, distribution: Distribution, cache_size: int = 0) -> None:
super().__init__(cache_size=cache_size)
self.distribution = distribution
@property
def domain(self) -> constraints.Constraint: # type: ignore[override]
def domain(self) -> Optional[constraints.Constraint]: # type: ignore[override]
return self.distribution.support
def _call(self, x):

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import nan, Tensor
from torch.distributions import constraints
@ -50,7 +52,12 @@ class Uniform(Distribution):
def variance(self) -> Tensor:
return (self.high - self.low).pow(2) / 12
def __init__(self, low, high, validate_args=None):
def __init__(
self,
low: Union[Tensor, float],
high: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.low, self.high = broadcast_all(low, high)
if isinstance(low, _Number) and isinstance(high, _Number):

View File

@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F
from torch import Tensor
from torch.overrides import is_tensor_like
from torch.types import _Number
from torch.types import _Number, Number
euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
@ -23,7 +23,9 @@ __all__ = [
]
def broadcast_all(*values):
# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added.
# See https://github.com/python/typing/issues/1216#issuecomment-2126153831
def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
r"""
Given a list of values (possibly containing numbers), returns a list where each
value is broadcasted based on the following rules:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional
import torch
import torch.jit
@ -126,7 +127,12 @@ class VonMises(Distribution):
support = constraints.real
has_rsample = False
def __init__(self, loc, concentration, validate_args=None):
def __init__(
self,
loc: Tensor,
concentration: Tensor,
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.concentration = broadcast_all(loc, concentration)
batch_shape = self.loc.shape
event_shape = torch.Size()

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from torch.distributions import constraints
@ -34,7 +36,12 @@ class Weibull(TransformedDistribution):
}
support = constraints.positive
def __init__(self, scale, concentration, validate_args=None):
def __init__(
self,
scale: Union[Tensor, float],
concentration: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.scale, self.concentration = broadcast_all(scale, concentration)
self.concentration_reciprocal = self.concentration.reciprocal()
base_dist = Exponential(

View File

@ -80,8 +80,8 @@ class Wishart(ExponentialFamily):
covariance_matrix: Optional[Tensor] = None,
precision_matrix: Optional[Tensor] = None,
scale_tril: Optional[Tensor] = None,
validate_args=None,
):
validate_args: Optional[bool] = None,
) -> None:
assert (covariance_matrix is not None) + (scale_tril is not None) + (
precision_matrix is not None
) == 1, (