Add documentation for Cholesky lapack functions (#1816)

This commit is contained in:
Thomas Viehmann
2017-06-16 02:10:56 +02:00
committed by Adam Paszke
parent 86a96cd759
commit 97f50edf46

View File

@ -3018,17 +3018,152 @@ Example::
# """
# """)
# add_docstr(torch._C.potrf,
# """
# """)
add_docstr(torch._C.potrf,
"""
potrf(a, out=None)
potrf(a, upper, out=None)
# add_docstr(torch._C.potri,
# """
# """)
Computes the Cholesky decomposition of a positive semidefinite
matrix :attr:`a`: returns matrix `u`
If `upper` is True or not provided, `u` is upper triangular
such that :math:`a = u^T u`.
If `upper` is False, `u` is lower triangular
such that :math:`a = u u^T`.
# add_docstr(torch._C.potrs,
# """
# """)
Args:
a (Tensor): the input 2D `Tensor`, a symmetric positive semidefinite matrix
upper (bool, optional): Return upper (default) or lower triangular matrix
out (Tensor, optional): A Tensor for u
Example::
>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
2.3563 3.2318 -0.9406
3.2318 4.9557 -2.1618
-0.9406 -2.1618 2.2443
[torch.FloatTensor of size 3x3]
>>> u
1.5350 2.1054 -0.6127
0.0000 0.7233 -1.2053
0.0000 0.0000 0.6451
[torch.FloatTensor of size 3x3]
>>> torch.mm(u.t(),u)
2.3563 3.2318 -0.9406
3.2318 4.9557 -2.1618
-0.9406 -2.1618 2.2443
[torch.FloatTensor of size 3x3]
""")
add_docstr(torch._C.potri,
"""
potri(u, out=None)
potri(u, upper, out=None)
Computes the inverse of a positive semidefinite matrix given its
Cholesky factor :attr:`u`: returns matrix `inv`
If `upper` is True or not provided, `u` is upper triangular
such that :math:`inv = (u^T u)^{-1}`.
If `upper` is False, `u` is lower triangular
such that :math:`inv = (u u^T)^{-1}`.
Args:
u (Tensor): the input 2D `Tensor`, a upper or lower triangular Cholesky factor
upper (bool, optional): Flag if upper (default) or lower triangular matrix
out (Tensor, optional): A Tensor for inv
Example::
>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
2.3563 3.2318 -0.9406
3.2318 4.9557 -2.1618
-0.9406 -2.1618 2.2443
[torch.FloatTensor of size 3x3]
>>> torch.potri(u)
12.5724 -10.1765 -4.5333
-10.1765 8.5852 4.0047
-4.5333 4.0047 2.4031
[torch.FloatTensor of size 3x3]
>>> a.inverse()
12.5723 -10.1765 -4.5333
-10.1765 8.5852 4.0047
-4.5333 4.0047 2.4031
[torch.FloatTensor of size 3x3]
""")
add_docstr(torch._C.potrs,
"""
potrs(b, u, out=None)
potrs(b, u, upper, out=None)
Solves a linear system of equations with a positive semidefinite
matrix to be inverted given its given a Cholesky factor
matrix :attr:`u`: returns matrix `c`
If `upper` is True or not provided, `u` is and upper triangular
such that :math:`c = (u^T u)^{-1} b`.
If `upper` is False, `u` is and lower triangular
such that :math:`c = (u u^T)^{-1} b`.
.. note:: `b` is always a 2D `Tensor`, use `b.unsqueeze(1)` to convert a vector.
Args:
b (Tensor): the right hand side 2D `Tensor`
u (Tensor): the input 2D `Tensor`, a upper or lower triangular Cholesky factor
upper (bool, optional): Return upper (default) or lower triangular matrix
out (Tensor, optional): A Tensor for c
Example::
>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
2.3563 3.2318 -0.9406
3.2318 4.9557 -2.1618
-0.9406 -2.1618 2.2443
[torch.FloatTensor of size 3x3]
>>> b = torch.randn(3,2)
>>> b
-0.3119 -1.8224
-0.2798 0.1789
-0.3735 1.7451
[torch.FloatTensor of size 3x2]
>>> torch.potrs(b,u)
0.6187 -32.6438
-0.7234 27.0703
-0.6039 13.1717
[torch.FloatTensor of size 3x2]
>>> torch.mm(a.inverse(),b)
0.6187 -32.6436
-0.7234 27.0702
-0.6039 13.1717
[torch.FloatTensor of size 3x2]
""")
add_docstr(torch._C.pow,
"""
@ -3186,10 +3321,58 @@ Example::
""")
# TODO
# add_docstr(torch._C.pstrf,
# """
# """)
add_docstr(torch._C.pstrf,
"""
pstrf(a, out=None)
pstrf(a, upper, out=None)
Computes the pivoted Cholesky decomposition of a positive semidefinite
matrix :attr:`a`: returns matrices `u` and `piv`.
If `upper` is True or not provided, `u` is and upper triangular
such that :math:`a = p^T u^T u p`, with `p` the permutation given by `piv`.
If `upper` is False, `u` is and lower triangular
such that :math:`a = p^T u u^T p`.
Args:
a (Tensor): the input 2D `Tensor`
upper (bool, optional): Return upper (default) or lower triangular matrix
out (tuple, optional): A tuple of u and piv Tensors
Example::
>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> a
5.4417 -2.5280 1.3643
-2.5280 2.9689 -2.1368
1.3643 -2.1368 4.6116
[torch.FloatTensor of size 3x3]
>>> u,piv = torch.pstrf(a)
>>> u
2.3328 0.5848 -1.0837
0.0000 2.0663 -0.7274
0.0000 0.0000 1.1249
[torch.FloatTensor of size 3x3]
>>> piv
0
2
1
[torch.IntTensor of size 3]
>>> p = torch.eye(3).index_select(0,piv.long()).index_select(0,piv.long()).t() # make pivot permutation
>>> torch.mm(torch.mm(p.t(),torch.mm(u.t(),u)),p) # reconstruct
5.4417 1.3643 -2.5280
1.3643 4.6116 -2.1368
-2.5280 -2.1368 2.9689
[torch.FloatTensor of size 3x3]
""")
add_docstr(torch._C.qr,
"""