pytorch/torch/distributions/fishersnedecor.py
Randolf Scholz 6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769
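
The alternatives named in this item can be pictured with a minimal sketch. The helper `as_float` is hypothetical and not part of `torch.distributions`; it only illustrates that `typing.SupportsFloat` is a structural, runtime-checkable protocol, whereas `numbers.Number` also admits values such as `complex` that make little sense for a real-valued parameter.

```python
import numbers
from fractions import Fraction
from typing import SupportsFloat


def as_float(value: SupportsFloat) -> float:
    """Hypothetical helper: coerce any float-convertible scalar to `float`."""
    # SupportsFloat matches structurally on `__float__`, so type checkers
    # can verify this call without relying on the `numbers` ABCs.
    return float(value)


# numbers.Number is broader than a real-valued parameter should allow:
print(isinstance(1j, numbers.Number))  # complex passes this runtime check

# SupportsFloat accepts anything that defines `__float__`:
print(isinstance(Fraction(1, 2), SupportsFloat))
print(as_float(Fraction(1, 2)))
```

Whether `float`, `SupportsFloat`, or a custom `Protocol` is the right choice for the actual signatures is left open here, as in the item above.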

## Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.support` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
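
The `else`-to-`if` conversion mentioned for `categorical.py`, `geometric.py`, and `multivariate_normal.py` can be sketched as follows. This is an illustration only: `normalized_probs` is a hypothetical stand-in written with plain lists, not the code from those modules. It assumes the usual pattern of two optional arguments of which exactly one must be given.

```python
import math
from typing import Optional


def normalized_probs(
    probs: Optional[list[float]] = None,
    logits: Optional[list[float]] = None,
) -> list[float]:
    """Hypothetical stand-in for an `__init__` taking exactly one of two arguments."""
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        # The explicit None check narrows `probs` to list[float].
        total = sum(probs)
        return [p / total for p in probs]
    if logits is not None:
        # With a bare `else:` the checker would still see Optional[list[float]],
        # since it cannot connect the two arguments through the guard above.
        shift = max(logits)
        weights = [math.exp(x - shift) for x in logits]
        total = sum(weights)
        return [w / total for w in weights]
    raise AssertionError("unreachable: exactly one argument is set")


print(normalized_probs(probs=[1.0, 3.0]))
print(normalized_probs(logits=[0.0, 0.0]))
```

The trailing `raise` exists only so that every path returns or raises from the checker's point of view; at runtime the first guard already rules that path out.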

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import (
    MultivariateNormal,
    Normal,
    StickBreakingTransform,
    TransformedDistribution,
)

b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This forces us to add `type: ignore` comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00


# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import nan, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size


__all__ = ["FisherSnedecor"]


class FisherSnedecor(Distribution):
    r"""
    Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
        tensor([ 0.2453])

    Args:
        df1 (float or Tensor): degrees of freedom parameter 1
        df2 (float or Tensor): degrees of freedom parameter 2
    """

    arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(
        self,
        df1: Union[Tensor, float],
        df2: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if isinstance(df1, _Number) and isinstance(df2, _Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.df1.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(FisherSnedecor, _instance)
        batch_shape = torch.Size(batch_shape)
        new.df1 = self.df1.expand(batch_shape)
        new.df2 = self.df2.expand(batch_shape)
        new._gamma1 = self._gamma1.expand(batch_shape)
        new._gamma2 = self._gamma2.expand(batch_shape)
        super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> Tensor:
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
        df2[df2 <= 2] = nan
        return df2 / (df2 - 2)

    @property
    def mode(self) -> Tensor:
        mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
        mode[self.df1 <= 2] = nan
        return mode

    @property
    def variance(self) -> Tensor:
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
        df2[df2 <= 4] = nan
        return (
            2
            * df2.pow(2)
            * (self.df1 + df2 - 2)
            / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
        )

    def rsample(self, sample_shape: _size = torch.Size(())) -> Tensor:
        shape = self._extended_shape(sample_shape)
        #   X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
        #   Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
        X1 = self._gamma1.rsample(sample_shape).view(shape)
        X2 = self._gamma2.rsample(sample_shape).view(shape)
        tiny = torch.finfo(X2.dtype).tiny
        X2.clamp_(min=tiny)
        Y = X1 / X2
        Y.clamp_(min=tiny)
        return Y

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        ct1 = self.df1 * 0.5
        ct2 = self.df2 * 0.5
        ct3 = self.df1 / self.df2
        t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
        t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
        t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
        return t1 + t2 - t3
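
The `log_prob` and `rsample` logic above can be cross-checked without tensors. The following is a minimal scalar sketch using only the standard library; `f_log_prob` and `f_sample` are hypothetical names introduced here, and the clamping to `tiny` and the batch handling from the class are deliberately left out. It relies on `random.gammavariate` taking a shape and a scale, so a scale of `1 / df` is used to mirror the rate `df` passed to `Gamma` in `__init__`.

```python
import math
import random


def f_log_prob(value: float, df1: float, df2: float) -> float:
    """Scalar version of the three-term log-density used in `log_prob`."""
    ct1, ct2, ct3 = df1 * 0.5, df2 * 0.5, df1 / df2
    t1 = math.lgamma(ct1 + ct2) - math.lgamma(ct1) - math.lgamma(ct2)
    t2 = ct1 * math.log(ct3) + (ct1 - 1) * math.log(value)
    t3 = (ct1 + ct2) * math.log1p(ct3 * value)
    return t1 + t2 - t3


def f_sample(df1: float, df2: float, rng: random.Random) -> float:
    """Draw one F(df1, df2) variate as a ratio of two gamma variates."""
    # gammavariate(shape, scale): scale 1/df corresponds to rate df.
    x1 = rng.gammavariate(df1 * 0.5, 1.0 / df1)
    x2 = rng.gammavariate(df2 * 0.5, 1.0 / df2)
    return x1 / x2


# For df1 = df2 = 2 the density is 1 / (1 + x)^2, so log f(1) = -2 * log(2).
print(f_log_prob(1.0, 2.0, 2.0), -2 * math.log(2.0))

# For df2 > 2 the mean is df2 / (df2 - 2); with df2 = 10 that is 1.25.
rng = random.Random(0)
draws = [f_sample(5.0, 10.0, rng) for _ in range(200_000)]
print(sum(draws) / len(draws))
```

The two printed log-density values should agree up to floating-point error, and the empirical mean should land near 1.25, which matches the formula in the `mean` property.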