mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Summary: Raise and assert used to have a hard-coded error message "Exception"; the user-provided error message was ignored. This PR adds support for representing the user's error message in TorchScript. This breaks backward compatibility because we now actually need to script the user's error message, which can potentially contain unscriptable expressions. Such programs can break when scripting, but saved models can still continue to work. Increased an op count in test_mobile_optimizer.py because we now need aten::format to form the actual exception message. This is built upon a WIP PR: https://github.com/pytorch/pytorch/pull/34112 by driazati Pull Request resolved: https://github.com/pytorch/pytorch/pull/41907 Reviewed By: ngimel Differential Revision: D22778301 Pulled By: gmagogsfm fbshipit-source-id: 2b94f0db4ae9fe70c4cd03f4048e519ea96323ad
108 lines
2.5 KiB
Python
108 lines
2.5 KiB
Python
"""Various linear algebra utility methods for internal use.
|
|
|
|
"""
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
|
|
def is_sparse(A):
    """Return whether tensor ``A`` uses the sparse COO layout.

    Only ``torch.sparse_coo`` counts as sparse here; any other layout
    (including strided/dense) yields ``False``.

    Raises:
        TypeError: if ``A`` is not a ``torch.Tensor``. Outside TorchScript the
            message also names the offending type.
    """
    if not isinstance(A, torch.Tensor):
        message = "expected Tensor"
        # type() is not scriptable, so only enrich the message in eager mode.
        if not torch.jit.is_scripting():
            message += " but got {}".format(type(A))
        raise TypeError(message)
    return A.layout == torch.sparse_coo
|
|
|
|
def get_floating_dtype(A):
    """Return the floating point dtype to use for computations on ``A``.

    ``float16``, ``float32`` and ``float64`` are passed through unchanged.
    Every other dtype (integers, bool, bfloat16, complex, ...) maps to
    ``torch.float32``.
    """
    candidate = A.dtype
    is_float = candidate in (torch.float16, torch.float32, torch.float64)
    return candidate if is_float else torch.float32
|
|
|
|
|
|
def matmul(A: Optional[Tensor], B: Tensor) -> Tensor:
    """Multiply two matrices, where the left operand is optional.

    Args:
        A: left operand. It may be ``None``, dense, or sparse COO.
        B: right operand, always dense.

    Returns:
        ``B`` itself (not a copy) when ``A`` is ``None``. Otherwise
        ``A @ B``, computed with ``torch.sparse.mm`` for a sparse ``A`` and
        with ``torch.matmul`` for a dense one.
    """
    if A is None:
        return B
    return torch.sparse.mm(A, B) if is_sparse(A) else torch.matmul(A, B)
|
|
|
|
|
|
def conjugate(A):
    """Return the complex conjugate of tensor ``A``.

    Non-complex tensors are returned as-is (the same object), since
    conjugation is the identity for them.
    """
    return A.conj() if A.is_complex() else A
|
|
|
|
|
|
def transpose(A):
    """Swap the last two dimensions of a matrix or a batch of matrices.

    Leading (batch) dimensions are left in place. For a 1-D input both
    indices resolve to dimension 0, so the result is an unchanged view.
    """
    last = A.dim() - 1
    return A.transpose(last, last - 1)
|
|
|
|
|
|
def transjugate(A):
    """Return the conjugate transpose of a matrix or a batch of matrices.

    The last two dimensions are swapped first. The result is then conjugated
    if it is complex; real inputs come back as the transposed view.
    """
    swapped = transpose(A)
    return conjugate(swapped)
|
|
|
|
|
|
def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor:
    """Compute the bilinear form :math:`X^T A Y`.

    ``A`` may be dense, sparse, or ``None``. A ``None`` ``A`` is skipped by
    ``matmul``, so the result is then :math:`X^T Y`. ``X`` is plainly
    transposed, not conjugate-transposed.
    """
    X_t = transpose(X)
    return matmul(X_t, matmul(A, Y))
|
|
|
|
|
|
def qform(A: Optional[Tensor], S: Tensor) -> Tensor:
    """Compute the quadratic form :math:`S^T A S`.

    This is the bilinear form with ``S`` on both sides. When ``A`` is
    ``None`` it is skipped, giving :math:`S^T S`.
    """
    left = S
    right = S
    return bform(left, A, right)
|
|
|
|
|
|
def basis(A):
    """Return an orthogonal basis for the columns of ``A``.

    On CPU the basis is built from the Householder factorization
    (``torch.geqrf`` followed by ``torch.orgqr``). On CUDA the reduced QR
    decomposition from ``torch.qr(..., some=True)`` is used instead.

    NOTE(review): ``orgqr`` appears to require at least as many rows as
    columns, so a wide ``A`` may fail on CPU while the CUDA path returns a
    square Q. Confirm whether callers ever pass wide matrices.
    """
    if not A.is_cuda:
        reflectors, tau = torch.geqrf(A)
        return torch.orgqr(reflectors, tau)
    # torch.orgqr is not available in CUDA, so fall back to a reduced QR.
    Q, _ = torch.qr(A, some=True)
    return Q
|
|
|
|
|
|
def symeig(
    A: Tensor,
    largest: Optional[bool] = False,
    eigenvectors: Optional[bool] = True,
) -> Tuple[Tensor, Tensor]:
    """Return the eigenpairs ``(E, Z)`` of ``A`` in the requested order.

    Args:
        A: input matrix. ``upper=True`` is passed to ``torch.symeig``, so per
            its documentation the upper triangle of ``A`` is used.
        largest: if true, reverse the order of the eigenvalues and of the
            eigenvector columns (the last dimension of ``Z``). ``None`` is
            treated as ``False``.
        eigenvectors: forwarded to ``torch.symeig``. ``None`` is treated as
            ``True``.

    Returns:
        ``(E, Z)``: eigenvalues and eigenvectors. Without ``largest`` they are
        in the order ``torch.symeig`` produces, which is assumed to be
        ascending.

    NOTE(review): later PyTorch releases deprecate ``torch.symeig`` in favour
    of ``torch.linalg.eigh``. Migrate when the minimum supported version
    allows.
    """
    descending = False
    if largest is not None:
        descending = largest
    want_vectors = True
    if eigenvectors is not None:
        want_vectors = eigenvectors

    evals, evecs = torch.symeig(A, want_vectors, upper=True)
    if descending:
        # Reversing an ascending sequence puts the largest eigenpair first.
        evals = torch.flip(evals, dims=(-1,))
        evecs = torch.flip(evecs, dims=(-1,))
    return evals, evecs
|