Add batched linear solver to torch.gesv() (#6100)

* Add batched linear solver to torch.gesv()

Fixes #3164
Picks up from #4502

Moves `gesv` to ATen.
Adds bindings for MAGMA's `gesv_batched` function for CUDA.
For CPU, runs `THLapack(gesv)` in a for loop over the batch.

The new function supports arbitrary batch dimensions (and broadcasting
of those dimensions). For example, a 4-D tensor of shape `A x B x M x M`
is treated as a batch of size `(A x B)` of `M x M` matrices.
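The flatten-and-loop strategy described above (used for the CPU path) can be sketched with NumPy standing in for `THLapack(gesv)`; this is an illustrative analogue of the approach, not the actual ATen code, and `batched_solve` is a hypothetical helper name:

```python
import numpy as np

def batched_solve(B, A):
    """Solve A X = B over arbitrary leading batch dimensions.

    Sketch of the CPU strategy: flatten the batch dimensions, solve each
    2-D system in a for loop (numpy.linalg.solve standing in for LAPACK's
    gesv), then restore the batch shape. Broadcasting of batch dimensions
    is omitted for brevity. Argument order (B, A) mirrors torch.gesv.
    """
    batch_shape = A.shape[:-2]
    m = A.shape[-1]
    k = B.shape[-1]
    A_flat = A.reshape(-1, m, m)
    B_flat = B.reshape(-1, m, k)
    X_flat = np.stack([np.linalg.solve(a, b) for a, b in zip(A_flat, B_flat)])
    return X_flat.reshape(*batch_shape, m, k)

# A batch of shape (2, 3) of 4x4 systems, each with 6 right-hand sides.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 4))
B = rng.standard_normal((2, 3, 4, 6))
X = batched_solve(B, A)
```

Here the 4-D inputs are viewed as a flat batch of 6 independent systems, each solved in turn.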

The overhead of creating the `magma_queue_t` is ~350000 microseconds the
first time it's called and ~6 microseconds on every call after that.

* Tests and docs

* Address comments

* Address comments

* Rebase

* Address comments

* Fix rebase

* Addressed comments

* Address comments

* Address comments

* Addressed comments
This commit is contained in:
Richard Zou
2018-05-08 17:06:27 -04:00
committed by Soumith Chintala
parent f598ef9102
commit 71626491c4
13 changed files with 510 additions and 14 deletions


@@ -1713,7 +1713,7 @@ Example::
add_docstr(torch.gesv,
r"""
gesv(B, A, out=None) -> (Tensor, Tensor)
torch.gesv(B, A) -> (Tensor, Tensor)
This function returns the solution to the system of linear
equations represented by :math:`AX = B` and the LU factorization of
@@ -1721,21 +1721,28 @@ A, in order as a tuple `X, LU`.
`LU` contains `L` and `U` factors for LU factorization of `A`.
:attr:`A` has to be a square and non-singular matrix (2-D tensor).
`torch.gesv(B, A)` can take in 2D inputs `B, A` or inputs that are
batches of 2D matrices. If the inputs are batches, it returns
batched outputs `X, LU`.
If `A` is an :math:`(m \times m)` matrix and `B` is :math:`(m \times k)`,
the result `LU` is :math:`(m \times m)` and `X` is :math:`(m \times k)`.
.. note::
The `out` keyword only supports 2D matrix inputs, that is,
`B, A` must be 2D matrices.
.. note::
Irrespective of the original strides, the returned matrices
`X` and `LU` will be transposed, i.e. with strides `(1, m)`
instead of `(m, 1)`.
`X` and `LU` will be transposed, i.e. with strides like
`B.contiguous().transpose(-1, -2).strides()` and
`A.contiguous().transpose(-1, -2).strides()` respectively.
Args:
B (Tensor): input matrix of :math:`(m \times k)` dimensions
A (Tensor): input square matrix of :math:`(m \times m)` dimensions
out (Tensor, optional): optional output matrix
B (Tensor): input matrix of size :math:`(*, m, k)`, where `*`
is zero or more batch dimensions.
A (Tensor): input square matrix of size :math:`(*, m, m)`, where
`*` is zero or more batch dimensions.
out ((Tensor, Tensor), optional): optional output tuple.
Example::
@@ -1751,6 +1758,15 @@ Example::
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
7.0977)
>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.gesv(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
3.6386)
""")
add_docstr(torch.get_default_dtype,
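The new batched example in the docstring above can be reproduced with `numpy.linalg.solve`, which follows the same batched semantics (leading dimensions are batch dimensions, the matrices occupy the last two). This is an illustrative analogue only; note NumPy's argument order is `(A, B)` rather than `torch.gesv`'s `(B, A)`:

```python
import numpy as np

# Same shapes as the docstring's batched example: batch dims (2, 3, 1),
# 4x4 coefficient matrices, 6 right-hand sides per system.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 1, 4, 4))
B = rng.standard_normal((2, 3, 1, 4, 6))

X = np.linalg.solve(A, B)  # numpy's argument order is (A, B)

# Analogue of torch.dist(B, A.matmul(X)): the residual should be tiny.
residual = np.linalg.norm(A @ X - B)
```

The result `X` has the same batch shape as the inputs, matching the `(*, m, k)` output shape documented for the batched case.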