diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 105038641bcc..659f9a20b10e 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -10,7 +12,7 @@ from torch.distributions.utils import (
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number
+from torch.types import _Number, Number
 
 __all__ = ["Bernoulli"]
 
@@ -41,7 +43,12 @@ class Bernoulli(ExponentialFamily):
     has_enumerate_support = True
     _mean_carrier_measure = 0
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -50,6 +57,7 @@ class Bernoulli(ExponentialFamily):
             is_scalar = isinstance(probs, _Number)
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index e030b648a88e..e06a28ca5aa4 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -36,7 +38,12 @@ class Beta(ExponentialFamily):
     support = constraints.unit_interval
     has_rsample = True
 
-    def __init__(self, concentration1, concentration0, validate_args=None):
+    def __init__(
+        self,
+        concentration1: Union[Tensor, float],
+        concentration0: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
             concentration1_concentration0 = torch.tensor(
                 [float(concentration1), float(concentration0)]
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 6cbfae150844..90461784c06d 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -50,7 +52,13 @@ class Binomial(Distribution):
     }
     has_enumerate_support = True
 
-    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        total_count: Union[Tensor, int] = 1,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
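The hunks above (`Bernoulli`, `Beta`, and the first `Binomial` hunk) make the long-standing runtime contract visible to the type checker: `probs` and `logits` are mutually exclusive `Optional` parameters. A minimal usage sketch of the annotated constructors (illustration only, not part of the patch):

```python
import torch
from torch.distributions import Bernoulli, Binomial

b1 = Bernoulli(probs=0.3)                 # plain Python numbers satisfy Union[Tensor, Number]
b2 = Binomial(10, logits=torch.zeros(4))  # total_count is Union[Tensor, int]; logits is Optional[Tensor]
# Bernoulli(probs=0.3, logits=0.0) still raises ValueError at runtime;
# the annotations add static checking without changing that behavior.
```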
@@ -62,6 +70,7 @@ class Binomial(Distribution):
             ) = broadcast_all(total_count, probs)
             self.total_count = self.total_count.type_as(self.probs)
         else:
+            assert logits is not None  # helps mypy
             (
                 self.total_count,
                 self.logits,
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index 715429c66552..1c8fed2636ad 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -51,7 +53,12 @@ class Categorical(Distribution):
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     has_enumerate_support = True
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -61,6 +68,7 @@ class Categorical(Distribution):
                 raise ValueError("`probs` parameter must be at least one-dimensional.")
             self.probs = probs / probs.sum(-1, keepdim=True)
         else:
+            assert logits is not None  # helps mypy
             if logits.dim() < 1:
                 raise ValueError("`logits` parameter must be at least one-dimensional.")
             # Normalize
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 582c08ebb858..84c1d34bda79 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, nan, Tensor
@@ -34,7 +35,12 @@ class Cauchy(Distribution):
     support = constraints.real
     has_rsample = True
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
             batch_shape = torch.Size()
diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py
index f175bc44f69e..fa23115fc035 100644
--- a/torch/distributions/chi2.py
+++ b/torch/distributions/chi2.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.gamma import Gamma
@@ -25,7 +27,11 @@ class Chi2(Gamma):
 
     arg_constraints = {"df": constraints.positive}
 
-    def __init__(self, df, validate_args=None):
+    def __init__(
+        self,
+        df: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         super().__init__(0.5 * df, 0.5, validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py
index b1e8eddfb0ec..14d0d6a9c177 100644
--- a/torch/distributions/continuous_bernoulli.py
+++ b/torch/distributions/continuous_bernoulli.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -13,7 +14,7 @@ from torch.distributions.utils import (
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number, _size
+from torch.types import _Number, _size, Number
 
 __all__ = ["ContinuousBernoulli"]
 
@@ -52,7 +53,11 @@ class ContinuousBernoulli(ExponentialFamily):
     has_rsample = True
 
     def __init__(
-        self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        lims: tuple[float, float] = (0.499, 0.501),
+        validate_args: Optional[bool] = None,
     ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
@@ -68,6 +73,7 @@ class ContinuousBernoulli(ExponentialFamily):
                 raise ValueError("The parameter probs has invalid values")
             self.probs = clamp_probs(self.probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index f656a0582e89..414ad6efe47e 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.autograd import Function
@@ -54,7 +56,11 @@ class Dirichlet(ExponentialFamily):
     support = constraints.simplex
     has_rsample = True
 
-    def __init__(self, concentration, validate_args=None):
+    def __init__(
+        self,
+        concentration: Tensor,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if concentration.dim() < 1:
             raise ValueError(
                 "`concentration` parameter must be at least one-dimensional."
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 75ea50d24860..b2895cb3b0d7 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -44,7 +44,7 @@ class Distribution:
         batch_shape: torch.Size = torch.Size(),
         event_shape: torch.Size = torch.Size(),
         validate_args: Optional[bool] = None,
-    ):
+    ) -> None:
         self._batch_shape = batch_shape
         self._event_shape = event_shape
         if validate_args is not None:
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index 8ca2636e1f52..d15cb1f7a258 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -46,7 +48,11 @@ class Exponential(ExponentialFamily):
     def variance(self) -> Tensor:
         return self.rate.pow(-2)
 
-    def __init__(self, rate, validate_args=None):
+    def __init__(
+        self,
+        rate: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         (self.rate,) = broadcast_all(rate)
         batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
         super().__init__(batch_shape, validate_args=validate_args)
diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py
index 053686c6de07..4755bd0d8bde 100644
--- a/torch/distributions/fishersnedecor.py
+++ b/torch/distributions/fishersnedecor.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -31,7 +33,12 @@ class FisherSnedecor(Distribution):
     support = constraints.positive
     has_rsample = True
 
-    def __init__(self, df1, df2, validate_args=None):
+    def __init__(
+        self,
+        df1: Union[Tensor, float],
+        df2: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.df1, self.df2 = broadcast_all(df1, df2)
         self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
         self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 5e0fe3fc7823..9df91ebee640 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -52,7 +54,12 @@ class Gamma(ExponentialFamily):
     def variance(self) -> Tensor:
         return self.concentration / self.rate.pow(2)
 
-    def __init__(self, concentration, rate, validate_args=None):
+    def __init__(
+        self,
+        concentration: Union[Tensor, float],
+        rate: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.concentration, self.rate = broadcast_all(concentration, rate)
         if isinstance(concentration, _Number) and isinstance(rate, _Number):
             batch_shape = torch.Size()
diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py
index b8b05142db5b..b5ceac39e94e 100644
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -10,7 +12,7 @@ from torch.distributions.utils import (
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number
+from torch.types import _Number, Number
 
 __all__ = ["Geometric"]
 
@@ -45,7 +47,12 @@ class Geometric(Distribution):
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.nonnegative_integer
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
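The recurring `assert logits is not None  # helps mypy` additions exist because mypy cannot carry the `(probs is None) == (logits is None)` guard across branches. A standalone sketch of the same narrowing, using a hypothetical function that is not part of the repository:

```python
from typing import Optional


def pick(probs: Optional[float] = None, logits: Optional[float] = None) -> float:
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        return probs
    # mypy does not deduce from the XOR guard above that `logits` is set in
    # this branch, so an assert narrows Optional[float] to float for it.
    assert logits is not None  # helps mypy
    return logits
```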
@@ -53,11 +60,13 @@ class Geometric(Distribution):
         if probs is not None:
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             (self.logits,) = broadcast_all(logits)
         probs_or_logits = probs if probs is not None else logits
         if isinstance(probs_or_logits, _Number):
             batch_shape = torch.Size()
         else:
+            assert probs_or_logits is not None  # helps mypy
             batch_shape = probs_or_logits.size()
         super().__init__(batch_shape, validate_args=validate_args)
         if self._validate_args and probs is not None:
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index 623cc7edbda6..6d097c9324e2 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -33,7 +34,12 @@ class Gumbel(TransformedDistribution):
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.real
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         finfo = torch.finfo(self.loc.dtype)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py
index da17c40da2ed..572ae080ac3e 100644
--- a/torch/distributions/half_cauchy.py
+++ b/torch/distributions/half_cauchy.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, Tensor
@@ -33,8 +34,13 @@ class HalfCauchy(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive}
     support = constraints.nonnegative
     has_rsample = True
+    base_dist: Cauchy
 
-    def __init__(self, scale, validate_args=None):
+    def __init__(
+        self,
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Cauchy(0, scale, validate_args=False)
         super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
 
diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py
index 5850f883e908..21e1b9d2c506 100644
--- a/torch/distributions/half_normal.py
+++ b/torch/distributions/half_normal.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, Tensor
@@ -33,8 +34,13 @@ class HalfNormal(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive}
     support = constraints.nonnegative
     has_rsample = True
+    base_dist: Normal
 
-    def __init__(self, scale, validate_args=None):
+    def __init__(
+        self,
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Normal(0, scale, validate_args=False)
         super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
 
diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py
index 0442a4c1b483..b66406681bb8 100644
--- a/torch/distributions/independent.py
+++ b/torch/distributions/independent.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
+from typing import Generic, Optional, TypeVar
 
 import torch
-from torch import Tensor
+from torch import Size, Tensor
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import _sum_rightmost
@@ -11,7 +12,10 @@ from torch.types import _size
 __all__ = ["Independent"]
 
 
-class Independent(Distribution):
+D = TypeVar("D", bound=Distribution)
+
+
+class Independent(Distribution, Generic[D]):
     r"""
     Reinterprets some of the batch dims of a distribution as event dims.
 
@@ -42,17 +46,21 @@ class Independent(Distribution):
     """
 
     arg_constraints: dict[str, constraints.Constraint] = {}
+    base_dist: D
 
     def __init__(
-        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
-    ):
+        self,
+        base_distribution: D,
+        reinterpreted_batch_ndims: int,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
             raise ValueError(
                 "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                 f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
             )
-        shape = base_distribution.batch_shape + base_distribution.event_shape
-        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
+        shape: Size = base_distribution.batch_shape + base_distribution.event_shape
+        event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape)
         batch_shape = shape[: len(shape) - event_dim]
         event_shape = shape[len(shape) - event_dim :]
         self.base_dist = base_distribution
diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py
index aaee976b7f17..de432a34434e 100644
--- a/torch/distributions/inverse_gamma.py
+++ b/torch/distributions/inverse_gamma.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -38,8 +40,14 @@ class InverseGamma(TransformedDistribution):
     }
     support = constraints.positive
     has_rsample = True
+    base_dist: Gamma
 
-    def __init__(self, concentration, rate, validate_args=None):
+    def __init__(
+        self,
+        concentration: Union[Tensor, float],
+        rate: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Gamma(concentration, rate, validate_args=validate_args)
         neg_one = -base_dist.rate.new_ones(())
         super().__init__(
diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py
index d38efb631e86..53c09ab9870d 100644
--- a/torch/distributions/kumaraswamy.py
+++ b/torch/distributions/kumaraswamy.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -45,7 +47,12 @@ class Kumaraswamy(TransformedDistribution):
     support = constraints.unit_interval
     has_rsample = True
 
-    def __init__(self, concentration1, concentration0, validate_args=None):
+    def __init__(
+        self,
+        concentration1: Union[Tensor, float],
+        concentration0: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.concentration1, self.concentration0 = broadcast_all(
             concentration1, concentration0
         )
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 39ef9b1efdb7..0d50712fb26f 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -46,7 +48,12 @@ class Laplace(Distribution):
     def stddev(self) -> Tensor:
         return (2**0.5) * self.scale
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
             batch_shape = torch.Size()
diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py
index a18f2ed9f52a..d2c29a9286de 100644
--- a/torch/distributions/lkj_cholesky.py
+++ b/torch/distributions/lkj_cholesky.py
@@ -9,8 +9,10 @@ Original copyright notice:
 """
 
 import math
+from typing import Optional, Union
 
 import torch
+from torch import Tensor
 from torch.distributions import Beta, constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
@@ -61,7 +63,12 @@ class LKJCholesky(Distribution):
     arg_constraints = {"concentration": constraints.positive}
     support = constraints.corr_cholesky
 
-    def __init__(self, dim, concentration=1.0, validate_args=None):
+    def __init__(
+        self,
+        dim: int,
+        concentration: Union[Tensor, float] = 1.0,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if dim < 2:
             raise ValueError(
                 f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py
index a048f94286c8..2c6dbc6bf55c 100644
--- a/torch/distributions/log_normal.py
+++ b/torch/distributions/log_normal.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.normal import Normal
@@ -32,8 +34,14 @@ class LogNormal(TransformedDistribution):
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.positive
     has_rsample = True
+    base_dist: Normal
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Normal(loc, scale, validate_args=validate_args)
         super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
 
diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py
index a8f7c099d1e8..729e3a67419f 100644
--- a/torch/distributions/logistic_normal.py
+++ b/torch/distributions/logistic_normal.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 from torch import Tensor
-from torch.distributions import constraints
+from torch.distributions import constraints, Independent
 from torch.distributions.normal import Normal
 from torch.distributions.transformed_distribution import TransformedDistribution
 from torch.distributions.transforms import StickBreakingTransform
@@ -36,8 +38,14 @@ class LogisticNormal(TransformedDistribution):
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.simplex
     has_rsample = True
+    base_dist: Independent[Normal]
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Normal(loc, scale, validate_args=validate_args)
         if not base_dist.batch_shape:
             base_dist = base_dist.expand([1])
diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py
index c6f739a595a3..968e4634ba62 100644
--- a/torch/distributions/lowrank_multivariate_normal.py
+++ b/torch/distributions/lowrank_multivariate_normal.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -93,7 +94,13 @@ class LowRankMultivariateNormal(Distribution):
     support = constraints.real_vector
     has_rsample = True
 
-    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
+    def __init__(
+        self,
+        loc: Tensor,
+        cov_factor: Tensor,
+        cov_diag: Tensor,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if loc.dim() < 1:
             raise ValueError("loc must be at least one-dimensional.")
         event_shape = loc.shape[-1:]
diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py
index 1fc2c1052d03..79a7029e1d72 100644
--- a/torch/distributions/mixture_same_family.py
+++ b/torch/distributions/mixture_same_family.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -59,7 +60,7 @@ class MixtureSameFamily(Distribution):
         self,
         mixture_distribution: Categorical,
         component_distribution: Distribution,
-        validate_args=None,
+        validate_args: Optional[bool] = None,
     ) -> None:
         self._mixture_distribution = mixture_distribution
         self._component_distribution = component_distribution
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
index 85a227f5c403..41d8ded53fd6 100644
--- a/torch/distributions/multinomial.py
+++ b/torch/distributions/multinomial.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import inf, Tensor
 from torch.distributions import Categorical, constraints
@@ -59,7 +61,13 @@ class Multinomial(Distribution):
     def variance(self) -> Tensor:
         return self.total_count * self.probs * (1 - self.probs)
 
-    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        total_count: int = 1,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if not isinstance(total_count, int):
             raise NotImplementedError("inhomogeneous total_count is not supported")
         self.total_count = total_count
diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py
index 849ee4170015..c15a84815b06 100644
--- a/torch/distributions/multivariate_normal.py
+++ b/torch/distributions/multivariate_normal.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -133,12 +134,12 @@ class MultivariateNormal(Distribution):
 
     def __init__(
         self,
-        loc,
-        covariance_matrix=None,
-        precision_matrix=None,
-        scale_tril=None,
-        validate_args=None,
-    ):
+        loc: Tensor,
+        covariance_matrix: Optional[Tensor] = None,
+        precision_matrix: Optional[Tensor] = None,
+        scale_tril: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if loc.dim() < 1:
             raise ValueError("loc must be at least one-dimensional.")
         if (covariance_matrix is not None) + (scale_tril is not None) + (
@@ -167,6 +168,7 @@ class MultivariateNormal(Distribution):
             )
             self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
         else:
+            assert precision_matrix is not None  # helps mypy
            if precision_matrix.dim() < 2:
                 raise ValueError(
                     "precision_matrix must be at least two-dimensional, "
diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py
index e5b0e128efe6..f28222f92f78 100644
--- a/torch/distributions/negative_binomial.py
+++ b/torch/distributions/negative_binomial.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -38,7 +40,13 @@ class NegativeBinomial(Distribution):
     }
     support = constraints.nonnegative_integer
 
-    def __init__(self, total_count, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        total_count: Union[Tensor, float],
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -50,6 +58,7 @@ class NegativeBinomial(Distribution):
             ) = broadcast_all(total_count, probs)
             self.total_count = self.total_count.type_as(self.probs)
         else:
+            assert logits is not None  # helps mypy
             (
                 self.total_count,
                 self.logits,
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 86e30ba450f5..626358d14795 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -51,7 +52,12 @@ class Normal(ExponentialFamily):
     def variance(self) -> Tensor:
         return self.stddev.pow(2)
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
             batch_shape = torch.Size()
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index 7e0bc03c5aba..8edb6da0b8dd 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -44,7 +46,12 @@ class OneHotCategorical(Distribution):
     support = constraints.one_hot
     has_enumerate_support = True
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self._categorical = Categorical(probs, logits)
         batch_shape = self._categorical.batch_shape
         event_shape = self._categorical.param_shape[-1:]
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 2cc1e298ba25..bbca7e0cba35 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 from torch import Tensor
 from torch.distributions import constraints
@@ -31,7 +31,10 @@ class Pareto(TransformedDistribution):
     arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
 
     def __init__(
-        self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
+        self,
+        scale: Union[Tensor, float],
+        alpha: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
     ) -> None:
         self.scale, self.alpha = broadcast_all(scale, alpha)
         base_dist = Exponential(self.alpha, validate_args=validate_args)
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index c3b4bacc54cb..d3fb4446baf4 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -1,10 +1,12 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import broadcast_all
-from torch.types import _Number
+from torch.types import _Number, Number
 
 __all__ = ["Poisson"]
 
@@ -45,7 +47,11 @@ class Poisson(ExponentialFamily):
     def variance(self) -> Tensor:
         return self.rate
 
-    def __init__(self, rate, validate_args=None):
+    def __init__(
+        self,
+        rate: Union[Tensor, Number],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         (self.rate,) = broadcast_all(rate)
         if isinstance(rate, _Number):
             batch_shape = torch.Size()
diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py
index 4c1549660313..16ad4219627e 100644
--- a/torch/distributions/relaxed_bernoulli.py
+++ b/torch/distributions/relaxed_bernoulli.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -12,7 +14,7 @@ from torch.distributions.utils import (
     logits_to_probs,
     probs_to_logits,
 )
-from torch.types import _Number, _size
+from torch.types import _Number, _size, Number
 
 __all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
 
@@ -41,7 +43,13 @@ class LogitRelaxedBernoulli(Distribution):
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.real
 
-    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        temperature: Tensor,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.temperature = temperature
         if (probs is None) == (logits is None):
             raise ValueError(
@@ -51,6 +59,7 @@ class LogitRelaxedBernoulli(Distribution):
             is_scalar = isinstance(probs, _Number)
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
@@ -131,8 +140,15 @@ class RelaxedBernoulli(TransformedDistribution):
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.unit_interval
     has_rsample = True
+    base_dist: LogitRelaxedBernoulli
 
-    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        temperature: Tensor,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
         super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args)
 
diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py
index 97ae3ed1857b..47314be9e44a 100644
--- a/torch/distributions/relaxed_categorical.py
+++ b/torch/distributions/relaxed_categorical.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -42,7 +44,13 @@ class ExpRelaxedCategorical(Distribution):
     )  # The true support is actually a submanifold of this.
     has_rsample = True
 
-    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        temperature: Tensor,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self._categorical = Categorical(probs, logits)
         self.temperature = temperature
         batch_shape = self._categorical.batch_shape
@@ -121,8 +129,15 @@ class RelaxedOneHotCategorical(TransformedDistribution):
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     support = constraints.simplex
     has_rsample = True
+    base_dist: ExpRelaxedCategorical
 
-    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        temperature: Tensor,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = ExpRelaxedCategorical(
             temperature, probs, logits, validate_args=validate_args
         )
diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py
index e141939b2745..d98554f413c0 100644
--- a/torch/distributions/studentT.py
+++ b/torch/distributions/studentT.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, nan, Tensor
@@ -60,7 +61,13 @@ class StudentT(Distribution):
         m[self.df <= 1] = nan
         return m
 
-    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(
+        self,
+        df: Union[Tensor, float],
+        loc: Union[Tensor, float] = 0.0,
+        scale: Union[Tensor, float] = 1.0,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
         self._chi2 = Chi2(self.df)
         batch_shape = self.df.size()
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 02792ce9d309..d5fbff877413 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -49,7 +50,12 @@ class TransformedDistribution(Distribution):
 
     arg_constraints: dict[str, constraints.Constraint] = {}
 
-    def __init__(self, base_distribution, transforms, validate_args=None):
+    def __init__(
+        self,
+        base_distribution: Distribution,
+        transforms: Union[Transform, list[Transform]],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if isinstance(transforms, Transform):
             self.transforms = [
                 transforms,
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index 8958f1a63c87..a033ce14408b 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -3,11 +3,14 @@ import functools
 import math
 import operator
 import weakref
-from typing import Optional
+from collections.abc import Sequence
+from typing import Optional, Union
 
 import torch
 import torch.nn.functional as F
+from torch import Tensor
 from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
 from torch.distributions.utils import (
     _sum_rightmost,
     broadcast_all,
@@ -92,7 +95,7 @@ class Transform:
     domain: constraints.Constraint
     codomain: constraints.Constraint
 
-    def __init__(self, cache_size=0):
+    def __init__(self, cache_size: int = 0) -> None:
         self._cache_size = cache_size
         self._inv: Optional[weakref.ReferenceType[Transform]] = None
         if cache_size == 0:
@@ -218,7 +221,7 @@ class _InverseTransform(Transform):
     This class is private; please instead use the ``Transform.inv`` property.
     """
 
-    def __init__(self, transform: Transform):
+    def __init__(self, transform: Transform) -> None:
         super().__init__(cache_size=transform._cache_size)
         self._inv: Transform = transform  # type: ignore[assignment]
 
@@ -285,7 +288,7 @@ class ComposeTransform(Transform):
         the latest single value is cached. Only 0 and 1 are supported.
     """
 
-    def __init__(self, parts: list[Transform], cache_size=0):
+    def __init__(self, parts: list[Transform], cache_size: int = 0) -> None:
         if cache_size:
             parts = [part.with_cache(cache_size) for part in parts]
         super().__init__(cache_size=cache_size)
@@ -413,7 +416,12 @@ class IndependentTransform(Transform):
         dimensions to treat as dependent.
     """
 
-    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
+    def __init__(
+        self,
+        base_transform: Transform,
+        reinterpreted_batch_ndims: int,
+        cache_size: int = 0,
+    ) -> None:
         super().__init__(cache_size=cache_size)
         self.base_transform = base_transform.with_cache(cache_size)
         self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
@@ -442,7 +450,7 @@ class IndependentTransform(Transform):
         return self.base_transform.bijective
 
     @property
-    def sign(self) -> int:  # type: ignore[override]
+    def sign(self) -> int:
         return self.base_transform.sign
 
     def _call(self, x):
@@ -486,7 +494,12 @@ class ReshapeTransform(Transform):
 
     bijective = True
 
-    def __init__(self, in_shape, out_shape, cache_size=0):
+    def __init__(
+        self,
+        in_shape: torch.Size,
+        out_shape: torch.Size,
+        cache_size: int = 0,
+    ) -> None:
         self.in_shape = torch.Size(in_shape)
         self.out_shape = torch.Size(out_shape)
         if self.in_shape.numel() != self.out_shape.numel():
@@ -571,7 +584,7 @@ class PowerTransform(Transform):
     codomain = constraints.positive
     bijective = True
 
-    def __init__(self, exponent, cache_size=0):
+    def __init__(self, exponent: Tensor, cache_size: int = 0) -> None:
         super().__init__(cache_size=cache_size)
         (self.exponent,) = broadcast_all(exponent)
 
@@ -582,7 +595,7 @@
 
     @lazy_property
     def sign(self) -> int:  # type: ignore[override]
-        return self.exponent.sign()
+        return self.exponent.sign()  # type: ignore[return-value]
 
     def __eq__(self, other):
         if not isinstance(other, PowerTransform):
@@ -734,7 +747,13 @@ class AffineTransform(Transform):
 
     bijective = True
 
-    def __init__(self, loc, scale, event_dim=0, cache_size=0):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        event_dim: int = 0,
+        cache_size: int = 0,
+    ) -> None:
         super().__init__(cache_size=cache_size)
         self.loc = loc
         self.scale = scale
@@ -771,20 +790,20 @@
             if self.loc != other.loc:
                 return False
         else:
-            if not (self.loc == other.loc).all().item():
+            if not (self.loc == other.loc).all().item():  # type: ignore[union-attr]
                 return False
 
         if isinstance(self.scale, _Number) and isinstance(other.scale, _Number):
             if self.scale != other.scale:
                 return False
         else:
-            if not (self.scale == other.scale).all().item():
+            if not (self.scale == other.scale).all().item():  # type: ignore[union-attr]
                 return False
 
         return True
 
     @property
-    def sign(self) -> int:
+    def sign(self) -> Union[Tensor, int]:  # type: ignore[override]
         if isinstance(self.scale, _Number):
             return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
         return self.scale.sign()
@@ -1022,7 +1041,7 @@ class PositiveDefiniteTransform(Transform):
     """
 
     domain = constraints.independent(constraints.real, 2)
-    codomain = constraints.positive_definite  # type: ignore[assignment]
+    codomain = constraints.positive_definite
 
     def __eq__(self, other):
         return isinstance(other, PositiveDefiniteTransform)
@@ -1053,7 +1072,13 @@ class CatTransform(Transform):
 
     transforms: list[Transform]
 
-    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
+    def __init__(
+        self,
+        tseq: Sequence[Transform],
+        dim: int = 0,
+        lengths: Optional[Sequence[int]] = None,
+        cache_size: int = 0,
+    ) -> None:
         assert all(isinstance(t, Transform) for t in tseq)
         if cache_size:
             tseq = [t.with_cache(cache_size) for t in tseq]
@@ -1157,7 +1182,9 @@ class StackTransform(Transform):
 
     transforms: list[Transform]
 
-    def __init__(self, tseq, dim=0, cache_size=0):
+    def __init__(
+        self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0
+    ) -> None:
         assert all(isinstance(t, Transform) for t in tseq)
         if cache_size:
             tseq = [t.with_cache(cache_size) for t in tseq]
@@ -1237,12 +1264,12 @@ class CumulativeDistributionTransform(Transform):
     codomain = constraints.unit_interval
     sign = +1
 
-    def __init__(self, distribution, cache_size=0):
+    def __init__(self, distribution: Distribution, cache_size: int = 0) -> None:
         super().__init__(cache_size=cache_size)
         self.distribution = distribution
 
     @property
-    def domain(self) -> constraints.Constraint:  # type: ignore[override]
+    def domain(self) -> Optional[constraints.Constraint]:  # type: ignore[override]
         return self.distribution.support
 
     def _call(self, x):
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index 31007c924de0..37decbaadce5 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -50,7 +52,12 @@ class Uniform(Distribution):
     def variance(self) -> Tensor:
         return (self.high - self.low).pow(2) / 12
 
-    def __init__(self, low, high, validate_args=None):
+    def __init__(
+        self,
+        low: Union[Tensor, float],
+        high: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.low, self.high = broadcast_all(low, high)
 
         if isinstance(low, _Number) and isinstance(high, _Number):
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index f83d75c904ab..b53c4721ffc7 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.overrides import is_tensor_like
-from torch.types import _Number
+from torch.types import _Number, Number
 
 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
 
@@ -23,7 +23,9 @@ __all__ = [
 ]
 
 
-def broadcast_all(*values):
+# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added.
+# See https://github.com/python/typing/issues/1216#issuecomment-2126153831
+def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
     r"""
     Given a list of values (possibly containing numbers), returns a list where each
     value is broadcasted based on the following rules:
diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py
index 9a144fe10817..4f96a23cf55b 100644
--- a/torch/distributions/von_mises.py
+++ b/torch/distributions/von_mises.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional
 
 import torch
 import torch.jit
@@ -126,7 +127,12 @@ class VonMises(Distribution):
     support = constraints.real
     has_rsample = False
 
-    def __init__(self, loc, concentration, validate_args=None):
+    def __init__(
+        self,
+        loc: Tensor,
+        concentration: Tensor,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.concentration = broadcast_all(loc, concentration)
         batch_shape = self.loc.shape
         event_shape = torch.Size()
diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py
index e7b3c5e0cebe..98132472b4ee 100644
--- a/torch/distributions/weibull.py
+++ b/torch/distributions/weibull.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -34,7 +36,12 @@ class Weibull(TransformedDistribution):
     }
     support = constraints.positive
 
-    def __init__(self, scale, concentration, validate_args=None):
+    def __init__(
+        self,
+        scale: Union[Tensor, float],
+        concentration: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.scale, self.concentration = broadcast_all(scale, concentration)
         self.concentration_reciprocal = self.concentration.reciprocal()
         base_dist = Exponential(
diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py
index 225aeeb97430..1b5a51ea88f9 100644
--- a/torch/distributions/wishart.py
+++ b/torch/distributions/wishart.py
@@ -80,8 +80,8 @@ class Wishart(ExponentialFamily):
         covariance_matrix: Optional[Tensor] = None,
         precision_matrix: Optional[Tensor] = None,
         scale_tril: Optional[Tensor] = None,
-        validate_args=None,
-    ):
+        validate_args: Optional[bool] = None,
+    ) -> None:
         assert (covariance_matrix is not None) + (scale_tril is not None) + (
             precision_matrix is not None
         ) == 1, (
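Two of the more structural changes in the patch are `Independent` becoming `Generic[D]` and `broadcast_all` gaining a real signature. A minimal sketch of what both buy under mypy (illustration only, not code from the repository):

```python
import torch
from torch.distributions import Independent, Normal
from torch.distributions.utils import broadcast_all

# Independent[D] keeps the base distribution's concrete type, so attribute
# access on base_dist type-checks without a cast.
ind: Independent[Normal] = Independent(Normal(torch.zeros(3, 2), torch.ones(3, 2)), 1)
loc = ind.base_dist.loc  # statically known: base_dist is Normal, so .loc exists

# broadcast_all now advertises that plain numbers come back as Tensors.
(rate,) = broadcast_all(2.0)  # inferred as Tensor
```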