Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang
2023-07-29 10:51:26 -04:00
committed by PyTorch MergeBot
parent a4ebc61f15
commit 3bf922a6ce
163 changed files with 8472 additions and 4412 deletions

View File

@ -1,14 +1,23 @@
from functools import update_wrapper
from numbers import Number
from typing import Any, Dict
import torch
import torch.nn.functional as F
from typing import Dict, Any
from torch.overrides import is_tensor_like
euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
__all__ = ["broadcast_all", "logits_to_probs", "clamp_probs", "probs_to_logits", "lazy_property",
"tril_matrix_to_vec", "vec_to_tril_matrix"]
__all__ = [
"broadcast_all",
"logits_to_probs",
"clamp_probs",
"probs_to_logits",
"lazy_property",
"tril_matrix_to_vec",
"vec_to_tril_matrix",
]
def broadcast_all(*values):
r"""
@ -26,18 +35,20 @@ def broadcast_all(*values):
ValueError: if any of the values is not a `numbers.Number` instance,
a `torch.*Tensor` instance, or an instance implementing __torch_function__
"""
if not all(is_tensor_like(v) or isinstance(v, Number)
for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number, '
'torch.Tensor or objects implementing __torch_function__.')
if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
raise ValueError(
"Input arguments must all be instances of numbers.Number, "
"torch.Tensor or objects implementing __torch_function__."
)
if not all(is_tensor_like(v) for v in values):
options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
for value in values:
if isinstance(value, torch.Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
new_values = [v if is_tensor_like(v) else torch.tensor(v, **options)
for v in values]
new_values = [
v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
]
return torch.broadcast_tensors(*new_values)
return torch.broadcast_tensors(*values)
@ -45,8 +56,10 @@ def broadcast_all(*values):
def _standard_normal(shape, dtype, device):
if torch._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return torch.normal(torch.zeros(shape, dtype=dtype, device=device),
torch.ones(shape, dtype=dtype, device=device))
return torch.normal(
torch.zeros(shape, dtype=dtype, device=device),
torch.ones(shape, dtype=dtype, device=device),
)
return torch.empty(shape, dtype=dtype, device=device).normal_()
@ -101,6 +114,7 @@ class lazy_property:
first call; thereafter replacing the wrapped method into an instance
attribute.
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)
@ -120,6 +134,7 @@ class _lazy_property_and_property(lazy_property, property):
* property when Sphinx autodoc looks
* lazy_property when Distribution validate_args looks
"""
def __init__(self, wrapped):
property.__init__(self, wrapped)
@ -131,7 +146,7 @@ def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
"""
n = mat.shape[-1]
if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].')
raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].")
arange = torch.arange(n, device=mat.device)
tril_mask = arange < arange.view(-1, 1) + (diag + 1)
vec = mat[..., tril_mask]
@ -144,11 +159,16 @@ def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
lower triangular matrix containing elements from the vector in row order.
"""
# +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2
n = (
-(1 + 2 * diag)
+ ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
) / 2
eps = torch.finfo(vec.dtype).eps
if not torch._C._get_tracing_state() and (round(n) - n > eps):
raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
'the lower triangular part of a square D x D matrix.')
raise ValueError(
f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
+ "the lower triangular part of a square D x D matrix."
)
n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
arange = torch.arange(n, device=vec.device)