Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 21:49:24 +08:00
Add batched linear solver to torch.gesv() (#6100)
* Add batched linear solver to torch.gesv()

Fixes #3164. Picks up from #4502.

I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop.

The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch size `(A x B)`.

The overhead of creating the `magma_queue_t` is ~350000 microseconds the first time it's called and ~6 microseconds every time after that.

* Tests and docs
* Address comments
* Address comments
* Rebase
* Address comments
* Fix rebase
* Addressed comments
* Address comments
* Address comments
* Addressed comments
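The CPU path described above, running the 2-D solver in a for loop over the flattened batch dimensions, can be sketched in NumPy. The `batched_solve` helper and its shapes are illustrative assumptions for this sketch, not PyTorch's actual `THLapack(gesv)` internals:

```python
import numpy as np

def batched_solve(B, A):
    """Sketch of the batched CPU path: flatten the batch dims of
    A (*, m, m) and B (*, m, k), solve each 2-D system in a loop,
    then restore the batch shape. Argument order mirrors torch.gesv(B, A).
    Illustration only, not PyTorch internals."""
    batch_shape = A.shape[:-2]      # e.g. a (2, 3, m, m) A has batch size 2*3
    m, k = A.shape[-1], B.shape[-1]
    A_flat = A.reshape(-1, m, m)    # collapse all batch dims into one
    B_flat = B.reshape(-1, m, k)
    X_flat = np.stack([np.linalg.solve(a, b)
                       for a, b in zip(A_flat, B_flat)])
    return X_flat.reshape(*batch_shape, m, k)

np.random.seed(0)
A = np.random.randn(2, 3, 4, 4)     # batch of 2*3 square systems
B = np.random.randn(2, 3, 4, 6)
X = batched_solve(B, A)             # per batch element, X solves A @ X = B
```

Random Gaussian matrices are almost surely non-singular, so each 2-D solve in the loop succeeds; the MAGMA-backed CUDA path dispatches the whole flattened batch in one `gesv_batched` call instead of looping.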
committed by Soumith Chintala
parent f598ef9102
commit 71626491c4
@@ -1713,7 +1713,7 @@ Example::
 
 add_docstr(torch.gesv,
            r"""
-gesv(B, A, out=None) -> (Tensor, Tensor)
+torch.gesv(B, A) -> (Tensor, Tensor)
 
 This function returns the solution to the system of linear
 equations represented by :math:`AX = B` and the LU factorization of
@@ -1721,21 +1721,28 @@ A, in order as a tuple `X, LU`.
 
 `LU` contains `L` and `U` factors for LU factorization of `A`.
 
-:attr:`A` has to be a square and non-singular matrix (2-D tensor).
+`torch.gesv(B, A)` can take in 2D inputs `B, A` or inputs that are
+batches of 2D matrices. If the inputs are batches, then returns
+batched outputs `X, LU`.
 
-If `A` is an :math:`(m \times m)` matrix and `B` is :math:`(m \times k)`,
-the result `LU` is :math:`(m \times m)` and `X` is :math:`(m \times k)`.
+.. note::
+
+    The `out` keyword only supports 2D matrix inputs, that is,
+    `B, A` must be 2D matrices.
 
 .. note::
 
     Irrespective of the original strides, the returned matrices
-    `X` and `LU` will be transposed, i.e. with strides `(1, m)`
-    instead of `(m, 1)`.
+    `X` and `LU` will be transposed, i.e. with strides like
+    `B.contiguous().transpose(-1, -2).strides()` and
+    `A.contiguous().transpose(-1, -2).strides()` respectively.
 
 Args:
-    B (Tensor): input matrix of :math:`(m \times k)` dimensions
-    A (Tensor): input square matrix of :math:`(m \times m)` dimensions
-    out (Tensor, optional): optional output matrix
+    B (Tensor): input matrix of size :math:`(*, m, k)`, where `*`
+                is zero or more batch dimensions.
+    A (Tensor): input square matrix of size :math:`(*, m, m)`, where
+                `*` is zero or more batch dimensions.
+    out ((Tensor, Tensor), optional): optional output tuple.
 
 Example::
 
@@ -1751,6 +1758,15 @@ Example::
 
     >>> torch.dist(B, torch.mm(A, X))
     tensor(1.00000e-06 *
            7.0977)
+
+    >>> # Batched solver example
+    >>> A = torch.randn(2, 3, 1, 4, 4)
+    >>> B = torch.randn(2, 3, 1, 4, 6)
+    >>> X, LU = torch.gesv(B, A)
+    >>> torch.dist(B, A.matmul(X))
+    tensor(1.00000e-06 *
+           3.6386)
+
 """)
 
 add_docstr(torch.get_default_dtype,