Add linalg.solve_triangular (#63568)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63568

This PR adds the first solver with structure to `linalg`. This solver
has an API compatible with that of `linalg.solve` preparing these for a
possible future merge of the APIs. The new API:
- Just returns the solution, rather than the solution and a copy of `A`
- Removes the confusing `transpose` argument and replaces it by a
correct handling of conj and strides within the call
- Adds a `left=True` kwarg. This can be achieved via transposes of the
inputs and the result, but it's exposed for convenience.

This PR also implements a dataflow that minimises the number of copies
needed before calling LAPACK / MAGMA / cuBLAS and takes advantage of the
conjugate and neg bits.

This algorithm is implemented for `solve_triangular` (which, for this, is
the most complex of all the solvers due to the `upper` parameters).
Once more solvers are added, we will factor out this calling algorithm,
so that all of them can take advantage of it.

Given the complexity of this algorithm, we implement some thorough
testing. We also added tests for all the backends, which was not done
before.

We also add forward AD support for `linalg.solve_triangular` and improve the
docs of `linalg.solve_triangular`. We also fix a few issues with those of
`torch.triangular_solve`.

Resolves https://github.com/pytorch/pytorch/issues/54258
Resolves https://github.com/pytorch/pytorch/issues/56327
Resolves https://github.com/pytorch/pytorch/issues/45734

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: zou3519, JacobSzwejbka

Differential Revision: D32283178

Pulled By: mruberry

fbshipit-source-id: deb672e6e52f58b76536ab4158073927a35e43a8
This commit is contained in:
lezcano
2021-11-18 09:44:07 -08:00
committed by Facebook GitHub Bot
parent f120335643
commit 0706607abc
17 changed files with 763 additions and 100 deletions

View File

@ -9924,11 +9924,11 @@ add_docstr(torch.triangular_solve,
r"""
triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor)
Solves a system of equations with a triangular coefficient matrix :math:`A`
Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A`
and multiple right-hand sides :math:`b`.
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
with the default keyword arguments.
In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular
(or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal.
`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are
batches of 2D matrices. If the inputs are batches, then returns
@ -9945,10 +9945,9 @@ Args:
:math:`*` is zero of more batch dimensions
A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
where :math:`*` is zero or more batch dimensions
upper (bool, optional): whether to solve the upper-triangular system
of equations (default) or the lower-triangular system of equations. Default: ``True``.
transpose (bool, optional): whether :math:`A` should be transposed before
being sent into the solver. Default: ``False``.
upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``.
transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``,
and `op(A) = A` if it is ``False``. Default: ``False``.
unitriangular (bool, optional): whether :math:`A` is unit triangular.
If True, the diagonal elements of :math:`A` are assumed to be
1 and not referenced from :math:`A`. Default: ``False``.