mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Type hints for distributions/utils (#154712)
Fixes #144196 Part of #144219 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154712 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
0f81c7a28d
commit
ba3f91af97
@ -1,16 +1,15 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from collections.abc import Sequence
|
||||
from functools import update_wrapper
|
||||
from typing import Any, Callable, Generic, overload, Union
|
||||
from typing_extensions import TypeVar
|
||||
from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import SymInt, Tensor
|
||||
from torch.overrides import is_tensor_like
|
||||
from torch.types import _Number, Number
|
||||
from torch.types import _dtype, _Number, Device, Number
|
||||
|
||||
|
||||
euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
|
||||
euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant
|
||||
|
||||
__all__ = [
|
||||
"broadcast_all",
|
||||
@ -59,7 +58,11 @@ def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
|
||||
return torch.broadcast_tensors(*values)
|
||||
|
||||
|
||||
def _standard_normal(shape, dtype, device):
|
||||
def _standard_normal(
|
||||
shape: Sequence[Union[int, SymInt]],
|
||||
dtype: Optional[_dtype],
|
||||
device: Optional[Device],
|
||||
) -> Tensor:
|
||||
if torch._C._get_tracing_state():
|
||||
# [JIT WORKAROUND] lack of support for .normal_()
|
||||
return torch.normal(
|
||||
@ -69,7 +72,7 @@ def _standard_normal(shape, dtype, device):
|
||||
return torch.empty(shape, dtype=dtype, device=device).normal_()
|
||||
|
||||
|
||||
def _sum_rightmost(value, dim):
|
||||
def _sum_rightmost(value: Tensor, dim: int) -> Tensor:
|
||||
r"""
|
||||
Sum out ``dim`` many rightmost dimensions of a given tensor.
|
||||
|
||||
@ -83,7 +86,7 @@ def _sum_rightmost(value, dim):
|
||||
return value.reshape(required_shape).sum(-1)
|
||||
|
||||
|
||||
def logits_to_probs(logits, is_binary=False):
|
||||
def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor:
|
||||
r"""
|
||||
Converts a tensor of logits into probabilities. Note that for the
|
||||
binary case, each value denotes log odds, whereas for the
|
||||
@ -95,7 +98,7 @@ def logits_to_probs(logits, is_binary=False):
|
||||
return F.softmax(logits, dim=-1)
|
||||
|
||||
|
||||
def clamp_probs(probs):
|
||||
def clamp_probs(probs: Tensor) -> Tensor:
|
||||
"""Clamps the probabilities to be in the open interval `(0, 1)`.
|
||||
|
||||
The probabilities would be clamped between `eps` and `1 - eps`,
|
||||
@ -121,7 +124,7 @@ def clamp_probs(probs):
|
||||
return probs.clamp(min=eps, max=1 - eps)
|
||||
|
||||
|
||||
def probs_to_logits(probs, is_binary=False):
|
||||
def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor:
|
||||
r"""
|
||||
Converts a tensor of probabilities into logits. For the binary case,
|
||||
this denotes the probability of occurrence of the event indexed by `1`.
|
||||
|
Reference in New Issue
Block a user