Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007

Committed by: PyTorch MergeBot
Commit: 3bf922a6ce
Parent: a4ebc61f15
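UFMT pairs usort (import sorting) with black (code formatting), so the hunks below are mechanical reformatting with no intended behavior change. As a hedged sketch, not part of this commit, one way to check whether a file already matches black's output via black's Python API; the target path is an assumption, since this diff view does not name the file (it appears to be torch/distributions/utils.py), and usort's import ordering is not checked here:

import pathlib

import black  # requires black to be installed

path = pathlib.Path("torch/distributions/utils.py")  # assumed target of this diff
src = path.read_text()
formatted = black.format_str(src, mode=black.Mode())
print("already black-formatted" if formatted == src else "black would reformat this file")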
@@ -1,14 +1,23 @@
 from functools import update_wrapper
 from numbers import Number
+from typing import Any, Dict

 import torch
 import torch.nn.functional as F
-from typing import Dict, Any
 from torch.overrides import is_tensor_like

 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant

-__all__ = ["broadcast_all", "logits_to_probs", "clamp_probs", "probs_to_logits", "lazy_property",
-           "tril_matrix_to_vec", "vec_to_tril_matrix"]
+__all__ = [
+    "broadcast_all",
+    "logits_to_probs",
+    "clamp_probs",
+    "probs_to_logits",
+    "lazy_property",
+    "tril_matrix_to_vec",
+    "vec_to_tril_matrix",
+]


 def broadcast_all(*values):
     r"""
@@ -26,18 +35,20 @@ def broadcast_all(*values):
         ValueError: if any of the values is not a `numbers.Number` instance,
             a `torch.*Tensor` instance, or an instance implementing __torch_function__
     """
-    if not all(is_tensor_like(v) or isinstance(v, Number)
-               for v in values):
-        raise ValueError('Input arguments must all be instances of numbers.Number, '
-                         'torch.Tensor or objects implementing __torch_function__.')
+    if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
+        raise ValueError(
+            "Input arguments must all be instances of numbers.Number, "
+            "torch.Tensor or objects implementing __torch_function__."
+        )
     if not all(is_tensor_like(v) for v in values):
         options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
         for value in values:
             if isinstance(value, torch.Tensor):
                 options = dict(dtype=value.dtype, device=value.device)
                 break
-        new_values = [v if is_tensor_like(v) else torch.tensor(v, **options)
-                      for v in values]
+        new_values = [
+            v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
+        ]
         return torch.broadcast_tensors(*new_values)
     return torch.broadcast_tensors(*values)

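For context on the function reformatted above, a hedged usage sketch of broadcast_all (not part of the diff): number arguments are wrapped as tensors using the dtype and device of the first torch.Tensor argument (or the default dtype if there is none), and all inputs are then broadcast to a common shape.

import torch

from torch.distributions.utils import broadcast_all

loc, scale = broadcast_all(0.0, torch.ones(2, 3))
print(loc.shape, scale.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
print(loc.dtype == scale.dtype)  # True: the scalar adopted the tensor's dtype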
@@ -45,8 +56,10 @@ def broadcast_all(*values):
 def _standard_normal(shape, dtype, device):
     if torch._C._get_tracing_state():
         # [JIT WORKAROUND] lack of support for .normal_()
-        return torch.normal(torch.zeros(shape, dtype=dtype, device=device),
-                            torch.ones(shape, dtype=dtype, device=device))
+        return torch.normal(
+            torch.zeros(shape, dtype=dtype, device=device),
+            torch.ones(shape, dtype=dtype, device=device),
+        )
     return torch.empty(shape, dtype=dtype, device=device).normal_()


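The traced branch above exists because the JIT tracer does not support the in-place .normal_() call; both branches draw standard-normal samples. A hedged sketch of the two equivalent forms (not part of the diff):

import torch

shape, dtype, device = (4,), torch.float32, "cpu"

eager = torch.empty(shape, dtype=dtype, device=device).normal_()  # non-traced branch
traced_style = torch.normal(  # traced branch: out-of-place, tracer-friendly
    torch.zeros(shape, dtype=dtype, device=device),
    torch.ones(shape, dtype=dtype, device=device),
)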
@@ -101,6 +114,7 @@ class lazy_property:
     first call; thereafter replacing the wrapped method into an instance
     attribute.
     """
+
     def __init__(self, wrapped):
         self.wrapped = wrapped
         update_wrapper(self, wrapped)
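For context, a hedged sketch of how lazy_property behaves (the class below is illustrative, not from the diff): the wrapped method runs once on first access, and its result is then cached as an ordinary instance attribute.

from torch.distributions.utils import lazy_property


class Example:
    def __init__(self, x):
        self.x = x

    @lazy_property
    def doubled(self):
        print("computing")
        return self.x * 2


e = Example(3)
print(e.doubled)  # prints "computing", then 6
print(e.doubled)  # cached: prints only 6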
@@ -120,6 +134,7 @@ class _lazy_property_and_property(lazy_property, property):
     * property when Sphinx autodoc looks
     * lazy_property when Distribution validate_args looks
     """
+
     def __init__(self, wrapped):
         property.__init__(self, wrapped)

@@ -131,7 +146,7 @@ def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
     """
     n = mat.shape[-1]
     if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
-        raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].')
+        raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].")
     arange = torch.arange(n, device=mat.device)
     tril_mask = arange < arange.view(-1, 1) + (diag + 1)
     vec = mat[..., tril_mask]
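A hedged usage sketch of tril_matrix_to_vec as reformatted above (not part of the diff): with diag=0 it flattens the lower triangle, including the diagonal, in row order.

import torch

from torch.distributions.utils import tril_matrix_to_vec

mat = torch.arange(9.0).reshape(3, 3)
print(tril_matrix_to_vec(mat, diag=0))  # tensor([0., 3., 4., 6., 7., 8.])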
@@ -144,11 +159,16 @@ def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
         lower triangular matrix containing elements from the vector in row order.
     """
     # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
-    n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2
+    n = (
+        -(1 + 2 * diag)
+        + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
+    ) / 2
     eps = torch.finfo(vec.dtype).eps
     if not torch._C._get_tracing_state() and (round(n) - n > eps):
-        raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
-                         'the lower triangular part of a square D x D matrix.')
+        raise ValueError(
+            f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
+            + "the lower triangular part of a square D x D matrix."
+        )
     n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
     mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
     arange = torch.arange(n, device=vec.device)
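vec_to_tril_matrix inverts the helper above. The reformatted expression for n solves D**2 + (1 + 2*diag)*D - |diag|*(diag + 1) - 2*L = 0 for the matrix size D, where L is the length of the last dimension of vec; with diag=0 and L=6 this gives D = (-1 + (1 + 8*6) ** 0.5) / 2 = 3. A hedged usage sketch (not part of the diff):

import torch

from torch.distributions.utils import vec_to_tril_matrix

vec = torch.tensor([0.0, 3.0, 4.0, 6.0, 7.0, 8.0])
print(vec_to_tril_matrix(vec, diag=0))
# tensor([[0., 0., 0.],
#         [3., 4., 0.],
#         [6., 7., 8.]])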