mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the `project-excludes` field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: 0 errors (4,263 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748 Approved by: https://github.com/oulgen
89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.distributions import constraints
|
|
from torch.distributions.exp_family import ExponentialFamily
|
|
from torch.distributions.utils import broadcast_all
|
|
from torch.types import _Number, Number
|
|
|
|
|
|
# Public names exported by this module: only the Poisson distribution class.
__all__ = ["Poisson"]
|
|
|
|
|
|
class Poisson(ExponentialFamily):
|
|
r"""
|
|
Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
|
|
|
|
Samples are nonnegative integers, with a pmf given by
|
|
|
|
.. math::
|
|
\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
|
|
|
|
Example::
|
|
|
|
>>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
|
|
>>> m = Poisson(torch.tensor([4]))
|
|
>>> m.sample()
|
|
tensor([ 3.])
|
|
|
|
Args:
|
|
rate (Number, Tensor): the rate parameter
|
|
"""
|
|
|
|
# pyrefly: ignore # bad-override
|
|
arg_constraints = {"rate": constraints.nonnegative}
|
|
support = constraints.nonnegative_integer
|
|
|
|
@property
|
|
def mean(self) -> Tensor:
|
|
return self.rate
|
|
|
|
@property
|
|
def mode(self) -> Tensor:
|
|
return self.rate.floor()
|
|
|
|
@property
|
|
def variance(self) -> Tensor:
|
|
return self.rate
|
|
|
|
def __init__(
|
|
self,
|
|
rate: Union[Tensor, Number],
|
|
validate_args: Optional[bool] = None,
|
|
) -> None:
|
|
(self.rate,) = broadcast_all(rate)
|
|
if isinstance(rate, _Number):
|
|
batch_shape = torch.Size()
|
|
else:
|
|
batch_shape = self.rate.size()
|
|
super().__init__(batch_shape, validate_args=validate_args)
|
|
|
|
def expand(self, batch_shape, _instance=None):
|
|
new = self._get_checked_instance(Poisson, _instance)
|
|
batch_shape = torch.Size(batch_shape)
|
|
new.rate = self.rate.expand(batch_shape)
|
|
super(Poisson, new).__init__(batch_shape, validate_args=False)
|
|
new._validate_args = self._validate_args
|
|
return new
|
|
|
|
def sample(self, sample_shape=torch.Size()):
|
|
shape = self._extended_shape(sample_shape)
|
|
with torch.no_grad():
|
|
return torch.poisson(self.rate.expand(shape))
|
|
|
|
def log_prob(self, value):
|
|
if self._validate_args:
|
|
self._validate_sample(value)
|
|
rate, value = broadcast_all(self.rate, value)
|
|
return value.xlogy(rate) - rate - (value + 1).lgamma()
|
|
|
|
@property
|
|
def _natural_params(self) -> tuple[Tensor]:
|
|
return (torch.log(self.rate),)
|
|
|
|
# pyrefly: ignore # bad-override
|
|
def _log_normalizer(self, x):
|
|
return torch.exp(x)
|