Type hints for distributions/utils (#154712)

Fixes #144196
Part of #144219

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154712
Approved by: https://github.com/Skylion007
This commit is contained in:
Randolf Scholz
2025-05-30 15:50:27 +00:00
committed by PyTorch MergeBot
parent 0f81c7a28d
commit ba3f91af97

View File

@ -1,16 +1,15 @@
# mypy: allow-untyped-defs
from collections.abc import Sequence
from functools import update_wrapper
from typing import Any, Callable, Generic, overload, Union
from typing_extensions import TypeVar
from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import SymInt, Tensor
from torch.overrides import is_tensor_like
from torch.types import _Number, Number
from torch.types import _dtype, _Number, Device, Number
euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant
__all__ = [
"broadcast_all",
@ -59,7 +58,11 @@ def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
return torch.broadcast_tensors(*values)
def _standard_normal(shape, dtype, device):
def _standard_normal(
shape: Sequence[Union[int, SymInt]],
dtype: Optional[_dtype],
device: Optional[Device],
) -> Tensor:
if torch._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return torch.normal(
@ -69,7 +72,7 @@ def _standard_normal(shape, dtype, device):
return torch.empty(shape, dtype=dtype, device=device).normal_()
def _sum_rightmost(value, dim):
def _sum_rightmost(value: Tensor, dim: int) -> Tensor:
r"""
Sum out ``dim`` many rightmost dimensions of a given tensor.
@ -83,7 +86,7 @@ def _sum_rightmost(value, dim):
return value.reshape(required_shape).sum(-1)
def logits_to_probs(logits, is_binary=False):
def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor:
r"""
Converts a tensor of logits into probabilities. Note that for the
binary case, each value denotes log odds, whereas for the
@ -95,7 +98,7 @@ def logits_to_probs(logits, is_binary=False):
return F.softmax(logits, dim=-1)
def clamp_probs(probs):
def clamp_probs(probs: Tensor) -> Tensor:
"""Clamps the probabilities to be in the open interval `(0, 1)`.
The probabilities would be clamped between `eps` and `1 - eps`,
@ -121,7 +124,7 @@ def clamp_probs(probs):
return probs.clamp(min=eps, max=1 - eps)
def probs_to_logits(probs, is_binary=False):
def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor:
r"""
Converts a tensor of probabilities into logits. For the binary case,
this denotes the probability of occurrence of the event indexed by `1`.