Add linalg.solve_triangular (#63568)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63568

This PR adds the first solver with structure to `linalg`. This solver
has an API compatible with that of `linalg.solve` preparing these for a
possible future merge of the APIs. The new API:
- Just returns the solution, rather than the solution and a copy of `A`
- Removes the confusing `transpose` argument and replaces it by a
correct handling of conj and strides within the call
- Adds a `left=True` kwarg. This can be achieved via transposes of the
inputs and the result, but it's exposed for convenience.

This PR also implements a dataflow that minimises the number of copies
needed before calling LAPACK / MAGMA / cuBLAS and takes advantage of the
conjugate and neg bits.

This algorithm is implemented for `solve_triangular` (which, for this, is
the most complex of all the solvers due to the `upper` parameters).
Once more solvers are added, we will factor out this calling algorithm,
so that all of them can take advantage of it.

Given the complexity of this algorithm, we implement some thorough
testing. We also added tests for all the backends, which was not done
before.

We also add forward AD support for `linalg.solve_triangular` and improve the
docs of `linalg.solve_triangular`. We also fix a few issues with those of
`torch.triangular_solve`.

Resolves https://github.com/pytorch/pytorch/issues/54258
Resolves https://github.com/pytorch/pytorch/issues/56327
Resolves https://github.com/pytorch/pytorch/issues/45734

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D32588230

Pulled By: mruberry

fbshipit-source-id: 69e484849deb9ad7bb992cc97905df29c8915910
This commit is contained in:
lezcano
2021-11-22 12:39:30 -08:00
committed by Facebook GitHub Bot
parent a2e35e167b
commit b46c89d950
17 changed files with 763 additions and 100 deletions

View File

@ -37,6 +37,8 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
m.impl("dot.out", torch::CppFunction::makeFallthrough());
m.impl("vdot.out", torch::CppFunction::makeFallthrough());
m.impl("mm", torch::CppFunction::makeFallthrough());
m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
m.impl("mm.out", torch::CppFunction::makeFallthrough());
m.impl("addmm", torch::CppFunction::makeFallthrough());
m.impl("addmm_", torch::CppFunction::makeFallthrough());

View File

@ -3801,4 +3801,260 @@ Tensor _det_lu_based_helper_backward_helper(
}
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace {
void checkIsMatrix(const Tensor& t,
const char* const f_name,
const char* const t_name) {
TORCH_CHECK(t.dim() >= 2, f_name, ": Expected ", t_name,
" to be a tensor of at least 2 dimensions.");
}
void checkIsSquareMatrix(const Tensor& t,
const char* const f_name,
const char* const t_name) {
checkIsMatrix(t, f_name, t_name);
TORCH_CHECK(t.size(-1) == t.size(-2),
f_name, ": Expected ", t_name,
" to be a square matrix or batch of square matrices. "
"Got matrices of size (", t.size(-2), ", ", t.size(-1), ").");
}
void checkInputsSolver(const Tensor& A,
const Tensor& B,
const Tensor& out,
const bool left,
const char* const f_name) {
checkIsSquareMatrix(A, f_name, "A");
checkIsMatrix(B, f_name, "B");
TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
f_name, ": Incompatible shapes of A and B for the equation ",
left ? "AX = B" : "XA = B",
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}
bool is_row_or_column_contiguous(const Tensor& t) {
// This could be made more general, similar to how it's checked in matmul, which would allow to
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
// We choose to be conservative for simplicity
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}
TransposeType to_transpose_type(const bool contig, const bool conj) {
if (conj) {
if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
else { return TransposeType::ConjTranspose; }
} else {
if (contig) { return TransposeType::NoTranspose; }
else { return TransposeType::Transpose; }
}
}
} // end of anonymous namespace
/*
Solves the matrix equation AX = B for A triangular.
'left' If true solves AX = B, if false solves XA = B
'upper' controls the portion of input matrix to consider in computations,
'unitriangular' if true then we assume diag(A) to be ones
'out' The tensor with the result. If A == out, A will be modified in place
*/
Tensor& linalg_solve_triangular_out(
const Tensor& A,
const Tensor& B,
bool upper,
bool left,
bool unitriangular,
Tensor& out) {
checkInputsSolver(A, B, out, left, "linalg.solve_triangular");
Tensor A_, B_;
std::tie(B_, A_) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
// We'll write F-contig / F-transpose for FORTRAN contiguous / FORTRAN transpose etc
// We say that a matrix is F-ready if it's F-contig OR F-transpose
// At this point, A, B have been broadcasted but may or may not be F-ready
// The following algorithm minimises copies and allocations. In pseudocode:
// if out is wrong size:
// resize_output(out)
// # Invariant: out is the right size
// Tensor out_f; # Tensor that we will pass to FORTRAN
// if out is F-ready:
// out_f = out;
// else:
// Allocate out_f F-ready
// if B != out_f:
// copy B into out_f
// # Invariant: out_f F-ready and has B copied into it
// if out_f is F-transposed:
// transpose equation
// if out_f is conj:
// conjugate equation
// # Invariant: out_f is not conjugated and F-contig
// Tensor A_f; # Tensor that will be sent to FORTRAN
// if A is F-ready:
// if A is conj and A is not transposed:
// # We need to clone A in this case. See [Cloning A]
// clone A F-contig into A_f
// else:
// A_f = A;
// else:
// clone A F-contig into A_f
// # Invariant: out_f is F-contig and A_f is F-ready
// # We pass FORTRAN the flags indicating if A_f is transposed and or conjugated
//
// # Here we undo the conjugations / transposes on out_f if needed
//
// if out_f not same out:
// copy out_f into out
// return out
//
// Note: The logic for the negative bit is the same as that for the conjugate bit
//
// Note: [Cloning A] If we are careful when allocating B when it needs to be allocated at the
// beginning of the algorithm, it is possible to always elide the copy of A here.
// Via this trick, the algorithm will copy at most one of A or B (never both) whenever A
// and B are F-ready and not A.is_neg() (which happens almost always in practice).
// When called as f(A, B, out=B) in most practical cases it'll perform no copies.
const bool avoid_copy_A = A_.transpose(-2, -1).is_contiguous() && A_.is_conj();
if (avoid_copy_A) {
// See Note: [Cloning A]
at::native::resize_output(out, B_.sizes());
}
else {
// poorman's reimplementation of resize_output with result F-contig
if (resize_output_check(out, B_.sizes())) {
out.resize_(B_.transpose(-2, -1).sizes(), MemoryFormat::Contiguous);
out.transpose_(-2, -1); // make 'out' have Fortran contiguous memory layout
}
}
// Invariant: out has the right size, so we'll be able to copy into it later on
Tensor out_f; // the out that will go into fortran
// We use C10_LIKELY mostly for documentation as it helps following what's the most likely path
if C10_LIKELY (is_row_or_column_contiguous(out)) {
out_f = out;
if C10_LIKELY (!out.is_same(B_)) {
out_f.copy_(B_);
}
} else {
if (avoid_copy_A) {
// See Note: [Cloning A]
out_f = B_.clone(at::MemoryFormat::Contiguous);
}
else {
out_f = cloneBatchedColumnMajor(B_);
}
}
// Invariant: out_f F-ready and has B copied into it
// out_f is F-transposed
bool transpose_A = false;
bool transpose_out_f = false;
if (out_f.stride(-1) == 1) {
left = !left;
transpose_A = true;
transpose_out_f = true;
out_f.transpose_(-2 ,-1);
}
// No need to conjugate anything if out_f is conj as AX = conj(B) <=> conj(A)conj(X) = B
// and X = B after the algortihm. We just anotate that A is conjugated later on
// The solution will be written into out_f, so it'll be conjugated already
Tensor A_f = A_; // The A that will go into fortran
bool A_is_conj = A_f.is_conj() != out_f.is_conj();
bool A_is_neg = A_f.is_neg() != out_f.is_neg();
bool A_is_f_contig = (A_f.stride(-1) == 1) == transpose_A;
if C10_UNLIKELY (!is_row_or_column_contiguous(A_f)) {
// We first anotate with flags on A_f all the conj / transpose / neg coming from out
// and then we clone the resulting tensor to resolve all of them in memory
if (out_f.is_conj()) {
A_f = A_f.conj();
}
A_is_conj = false;
if (out_f.is_neg()) {
A_f = A_f._neg_view();
}
A_is_neg = false;
// This choice is to be consistent with how we flip `upper` later on
// Note that this is the same reasoning we apply for neg and conj below
// If B has neg or out or transpose, then we need to resolve it in memory
A_f = transpose_A ? A_f.clone(at::MemoryFormat::Contiguous)
: cloneBatchedColumnMajor(A_f);
A_is_f_contig = true;
} else if C10_UNLIKELY (A_is_f_contig && A_is_conj) {
if C10_UNLIKELY (A_f.is_neg() || out_f.is_neg()) {
// Cases A_is_neg (remember that B.is_neg() iff out_f.is_same(B))
// -AX = -B => A(-X) = B. Swap neg of A_f. Nothing to do on X as X.is_same(B).
// -AX = B. We resolve the neg in memory
// AX = -B => -A -X = B. We resolve the neg in memory for A,
// Since X.is_same(B), we already have that X.is_neg() == true
// We do the neg with a view, as this will be resolved in the clone below
if (out_f.is_neg()) {
A_f = A_f._neg_view();
}
A_is_neg = false;
}
// We resolve the transpose if necessary and then leave A_f F-transposed,
// as BLAS can handle the case F-transposed and conjugated
A_f = at::clone(transpose_A ? A_f.mT() : A_f, at::MemoryFormat::Contiguous);
A_is_f_contig = false;
if (transpose_A) {
upper = !upper;
}
// As we've already resolved the conj of A in the clone
A_is_conj = out_f.is_conj();
} else if C10_UNLIKELY (A_is_neg) {
// We follow the same logic as above, only that in this case we need to perform the
// negation in memory
if (out_f.is_neg()) {
A_f = -A_f;
} else {
A_f = A_f.resolve_neg();
}
A_is_neg = false;
// As we've already resolved the conj of A in the negationa bove
A_is_conj = out_f.is_conj();
}
// Invariant: out_f is F-contig and A_f is F-ready
// neg has been resolved
// If we pass the matrix physically F-transposed, we need to change the parity of upper
if (A_f.stride(-1) == 1) {
upper = !upper;
}
triangular_solve_stub(
A_f.device().type(), A_f, out_f,
/*left=*/left,
/*upper=*/upper,
/*transpose*/to_transpose_type(A_is_f_contig, A_is_conj),
/*unitriangular=*/unitriangular);
if (transpose_out_f) {
out_f.transpose_(-2, -1);
}
if (!out_f.is_same(out)) {
out.copy_(out_f);
}
return out;
}
Tensor linalg_solve_triangular(
const Tensor& A,
const Tensor& B,
bool upper,
bool left,
bool unitriangular) {
Tensor out = at::empty({0}, A.options());
linalg_solve_triangular_out(A, B, upper, left, unitriangular, out);
return out;
}
}} // namespace at::native

View File

@ -323,13 +323,16 @@ static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_bro
}
static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
linearSolveCheckInputs(arg1, arg2, name);
// If there's no name we assume we don't want to check the errors
if (name != nullptr) {
linearSolveCheckInputs(arg1, arg2, name);
}
std::vector<int64_t> arg1_expand_size, arg2_expand_size;
std::tie(arg1_expand_size, arg2_expand_size) = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
Tensor arg1_broadcasted = arg1.expand(arg1_expand_size);
Tensor arg2_broadcasted = arg2.expand(arg2_expand_size);
auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}

View File

@ -29,6 +29,10 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
m.impl("resolve_neg", torch::CppFunction::makeFallthrough());
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
// linear algebra functions
m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
TORCH_VIEW_FNS(m)
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}

View File

@ -6842,6 +6842,17 @@
structured_delegate: triangular_solve.X
variants: method, function
- func: linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_solve_triangular_out
- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
python_module: linalg
variants: method, function
dispatch:
CPU, CUDA: linalg_solve_triangular
- func: symeig.e(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
dispatch:
CompositeExplicitAutograd: symeig_out

View File

@ -48,6 +48,7 @@ Solvers
:nosignatures:
solve
solve_triangular
lstsq
Inverses

View File

@ -12,7 +12,7 @@ from math import inf, nan, isnan
import random
from random import randrange
from itertools import product
from functools import reduce
from functools import reduce, partial
from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
@ -421,7 +421,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -467,7 +467,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_errors_and_warnings(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -571,7 +571,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_cholesky_batched(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -587,7 +587,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@tf32_on_and_off(0.01)
def test_old_cholesky(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -611,7 +611,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_cholesky_empty(self, device, dtype):
def run_test(upper):
A = torch.empty(0, 0, dtype=dtype, device=device)
@ -627,7 +627,7 @@ class TestLinalg(TestCase):
# it was using the lower triangular part instead of the upper one
@onlyCUDA
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_cholesky_batched_upper(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -642,7 +642,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_ex(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -673,7 +673,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_ex_non_pd(self, device, dtype):
# if the input matrix is not positive definite, info with positive integer is returned
A = torch.eye(3, 3, dtype=dtype, device=device)
@ -699,7 +699,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_ex_out_info_error(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -902,7 +902,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
def test_eigh(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -939,7 +939,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
def test_eigh_lower_uplo(self, device, dtype):
def run_test(shape, batch, uplo):
@ -957,7 +957,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_eigh_errors_and_warnings(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -1012,7 +1012,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
def test_eigh_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -1061,7 +1061,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
def test_eigvalsh(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -1086,7 +1086,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_eigvalsh_errors_and_warnings(self, device, dtype):
# eigvalsh requires a square matrix
t = torch.randn(2, 3, device=device, dtype=dtype)
@ -1125,7 +1125,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
def test_eigvalsh_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -1155,7 +1155,7 @@ class TestLinalg(TestCase):
run_test_permuted(shape, batch, uplo)
run_test_skipped_elements(shape, batch, uplo)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_kron(self, device, dtype):
def run_test_case(a_shape, b_shape):
@ -1176,7 +1176,7 @@ class TestLinalg(TestCase):
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
run_test_case(a_shape, b_shape)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_kron_non_contiguous(self, device, dtype):
def run_test_transposed(a_shape, b_shape):
@ -1232,7 +1232,7 @@ class TestLinalg(TestCase):
self.assertTrue(c.is_contiguous(memory_format=torch.contiguous_format))
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_kron_empty(self, device, dtype):
def run_test_case(empty_shape):
@ -1250,7 +1250,7 @@ class TestLinalg(TestCase):
for empty_shape in empty_shapes:
run_test_case(empty_shape)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_kron_errors_and_warnings(self, device, dtype):
# if non-empty out tensor with wrong shape is passed a warning is given
a = torch.eye(3, dtype=dtype, device=device)
@ -1598,7 +1598,7 @@ class TestLinalg(TestCase):
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3})
def test_cond(self, device, dtype):
def run_test_case(input, p):
@ -1658,7 +1658,7 @@ class TestLinalg(TestCase):
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3})
def test_cond_errors_and_warnings(self, device, dtype):
norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
@ -2868,7 +2868,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_svd_errors_and_warnings(self, device, dtype):
for svd in [torch.svd, torch.linalg.svd]:
# if non-empty out tensor with wrong shape is passed a warning is given
@ -3006,7 +3006,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_svdvals(self, device, dtype):
def run_test(shape):
@ -3026,7 +3026,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case
@skipCUDAIfRocm
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_svd_memory_allocation(self, device, dtype):
# test for https://github.com/pytorch/pytorch/issues/61949
# the problem was that tensors of incorrect size were allocated and then narrowed
@ -3053,7 +3053,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_cholesky_solve(self, device, dtype):
@ -3064,7 +3064,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_cholesky_solve_batched(self, device, dtype):
@ -3084,7 +3084,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_solve_batched_non_contiguous(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -3103,7 +3103,7 @@ class TestLinalg(TestCase):
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_cholesky_solve_batched_many_batches(self, device, dtype):
@ -3116,7 +3116,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_cholesky_solve_batched_broadcasting(self, device, dtype):
@ -3172,7 +3172,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
@ -3199,7 +3199,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_inverse(self, device, dtype):
@ -3270,7 +3270,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_inv_ex_info_device(self, device, dtype):
A = torch.eye(3, 3, dtype=dtype, device=device)
info = torch.linalg.inv_ex(A).info
@ -3278,7 +3278,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@skipCUDAIfRocm
def test_inv_ex_singular(self, device, dtype):
# if the input matrix is not invertible, info with positive integer is returned
@ -3306,7 +3306,7 @@ class TestLinalg(TestCase):
@slowTest
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
torch.float64: 1e-5, torch.complex128: 1e-5})
def test_inverse_many_batches(self, device, dtype):
@ -3327,7 +3327,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@onlyNativeDeviceTypes # TODO: XLA doesn't raise exception
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_inverse_errors(self, device, dtype):
# inverse expects batches of square matrices as input
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
@ -3348,7 +3348,7 @@ class TestLinalg(TestCase):
@onlyNativeDeviceTypes # TODO: XLA doesn't raise exception
@skipCUDAIfRocm
@skipCUDAVersionIn([(11, 3)]) # https://github.com/pytorch/pytorch/issues/57482
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_inverse_errors_large(self, device, dtype):
# Test batched inverse of singular matrices reports errors without crashing (gh-51930)
x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
@ -3360,7 +3360,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_pinv(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -3420,7 +3420,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_pinv_errors_and_warnings(self, device, dtype):
# pinv requires at least 2D tensor
a = torch.randn(1, device=device, dtype=dtype)
@ -3472,7 +3472,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_inv_errors_and_warnings(self, device, dtype):
# inv expects batches of square matrices as input
a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
@ -3539,7 +3539,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
def test_solve(self, device, dtype):
def run_test(n, batch, rhs):
@ -3586,7 +3586,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
def test_solve_batched_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@ -3600,7 +3600,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_solve_errors_and_warnings(self, device, dtype):
# solve expects batches of square matrices as input
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
@ -3666,7 +3666,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve(self, device, dtype):
for (k, n) in zip([2, 3, 5], [3, 5, 7]):
b, A = self.solve_test_helper((n,), (n, k), device, dtype)
@ -3675,7 +3675,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve_batched(self, device, dtype):
def solve_batch_helper(A_dims, b_dims):
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
@ -3693,7 +3693,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve_batched_non_contiguous(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@ -3706,7 +3706,7 @@ class TestLinalg(TestCase):
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve_batched_many_batches(self, device, dtype):
for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]):
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
@ -3716,7 +3716,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve_batched_broadcasting(self, device, dtype):
from numpy.linalg import solve
@ -3736,7 +3736,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_solve_errors_and_warnings(self, device, dtype):
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
@ -3877,7 +3877,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
def test_tensorinv(self, device, dtype):
@ -3907,7 +3907,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
def test_tensorinv_non_contiguous(self, device, dtype):
@ -3955,7 +3955,7 @@ class TestLinalg(TestCase):
@skipMeta # See https://github.com/pytorch/pytorch/issues/53739
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_tensorinv_empty(self, device, dtype):
for ind in range(1, 4):
# Check for empty inputs. NumPy does not work for these cases.
@ -3966,7 +3966,7 @@ class TestLinalg(TestCase):
@skipMeta # See https://github.com/pytorch/pytorch/issues/53739
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_tensorinv_errors_and_warnings(self, device, dtype):
def check_shape(a_shape, ind):
@ -4018,7 +4018,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_tensorinv_singular_input(self, device, dtype):
def check_singular_input(a_shape, ind):
@ -4111,7 +4111,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_matrix_rank(self, device, dtype):
matrix_rank = torch.linalg.matrix_rank
@ -4154,7 +4154,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_matrix_rank_atol(self, device, dtype):
def run_test_atol(shape0, shape1, batch):
@ -4208,7 +4208,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_matrix_rank_empty(self, device, dtype):
matrix_rank = torch.linalg.matrix_rank
@ -4245,7 +4245,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
@ -4271,7 +4271,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_matrix_rank_basic(self, device, dtype):
matrix_rank = torch.linalg.matrix_rank
@ -4285,7 +4285,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_old_matrix_rank(self, device, dtype):
a = torch.eye(10, dtype=dtype, device=device)
self.assertEqual(torch.matrix_rank(a).item(), 10)
@ -4399,7 +4399,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_qr(self, device, dtype):
def run_test(tensor_dims, some):
A = torch.randn(*tensor_dims, dtype=dtype, device=device)
@ -4852,6 +4852,138 @@ class TestLinalg(TestCase):
check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
make_arg = partial(make_tensor, dtype=dtype, device=device)
make_randn = partial(torch.randn, dtype=dtype, device=device)
b, n, k = shape
for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
# expand means that we generate a batch of matrices with a stride of zero in the batch dimension
if (conj_a or conj_b) and not dtype.is_complex:
continue
# We just expand on the batch size
if (expand_a or expand_b) and b == 1:
continue
size_a = (b, n, n) if left else (b, k, k)
size_b = (b, n, k) if not tr_b else (b, k, n)
# If expand_a or expand_b, we'll expand them to the correct size later
if b == 1 or expand_a:
size_a = size_a[1:]
if b == 1 or expand_b:
size_b = size_b[1:]
if well_conditioned:
PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
if uni:
# A = L from PLU
A = PLU[1].transpose(-2, -1).contiguous()
else:
# A = U from PLU
A = PLU[2].contiguous()
else:
A = make_arg(size_a)
A.triu_()
diag = A.diagonal(0, -2, -1)
if uni:
diag.fill_(1.)
else:
diag[diag.abs() < 1e-6] = 1.
B = make_arg(size_b)
if tr_a:
A.transpose_(-2, -1)
if tr_b:
B.transpose_(-2, -1)
if conj_a:
A = A.conj()
if conj_b:
B = B.conj()
if expand_a:
A = A.expand(b, *size_a)
if expand_b:
B = B.expand(b, n, k)
yield A, B, left, not tr_a, uni
def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
if left:
self.assertEqual(A @ X, B)
else:
self.assertEqual(X @ A, B)
out = B
# B may be expanded
if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
out = B.clone()
torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
self.assertEqual(X, out)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-1, torch.complex64: 1e-1,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_linalg_solve_triangular(self, device, dtype):
# This exercises the API + BLAS CPU + batched cuBLAS
ks = (3, 1, 0)
ns = (5, 0)
bs = (1, 2, 0)
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
for b, n, k in product(bs, ns, ks):
for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device):
self._test_linalg_solve_triangular(A, B, upper, left, uni)
@onlyCUDA
@skipCUDAIfNoMagma # Magma needed for the PLU decomposition
@skipCUDAIfRocm # There is a memory access bug in rocBLAS in the (non-batched) solve_triangular
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_linalg_solve_triangular_large(self, device, dtype):
# Exercises magma and cublas
magma = (9, 513, 1)
iterative_cublas = (2, 64, 1)
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
for shape in (magma, iterative_cublas):
for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
self._test_linalg_solve_triangular(A, B, upper, left, uni)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_linalg_solve_triangular_broadcasting(self, device, dtype):
make_arg = partial(make_tensor, dtype=dtype, device=device)
sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
((2, 1, 3, 4, 4), (4, 6)),
((4, 4), (2, 1, 3, 4, 2)),
((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
for size_A, size_B in sizes:
for left, upper, uni in itertools.product([True, False], repeat=3):
A = make_arg(size_A)
if upper:
A.triu_()
else:
A.tril_()
diag = A.diagonal(0, -2, -1)
if uni:
diag.fill_(1.)
else:
diag[diag.abs() < 1e-6] = 1.
B = make_arg(size_B)
if not left:
B.transpose_(-2, -1)
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
if left:
B_other = A @ X
else:
B_other = X @ A
self.assertEqual(*torch.broadcast_tensors(B, B_other))
def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
device, dtype):
triangle_function = torch.triu if upper else torch.tril
@ -4866,7 +4998,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_triangular_solve(self, device, dtype):
@ -4884,7 +5016,7 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_triangular_solve_batched(self, device, dtype):
@ -4935,7 +5067,7 @@ class TestLinalg(TestCase):
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_triangular_solve_batched_many_batches(self, device, dtype):
@ -4966,7 +5098,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_triangular_solve_batched_broadcasting(self, device, dtype):
from scipy.linalg import solve_triangular as tri_solve
@ -5001,7 +5133,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
@ -5302,7 +5434,7 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@skipCUDAIfNoCusolver
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_ormqr(self, device, dtype):
def run_test(batch, m, n, fortran_contiguous):
@ -5346,7 +5478,7 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@skipCUDAIfNoCusolver
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_ormqr_errors_and_warnings(self, device, dtype):
test_cases = [
# input1 size, input2 size, input3 size, error regex
@ -5536,7 +5668,7 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@skipCUDAIfNoCusolver
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_householder_product(self, device, dtype):
def generate_reflectors_and_tau(A):
"""
@ -6698,7 +6830,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_pinverse(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
@ -7070,7 +7202,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_slogdet(self, device, dtype):
@ -7122,7 +7254,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_slogdet_errors_and_warnings(self, device, dtype):
# slogdet requires the input to be a square matrix or batch of square matrices
a = torch.randn(2, 3, device=device, dtype=dtype)
@ -7394,7 +7526,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_inverse(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
@ -7442,7 +7574,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
# cholesky_inverse requires the input to be at least 2 dimensional tensor
a = torch.randn(2, device=device, dtype=dtype)
@ -7626,7 +7758,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@ -7651,7 +7783,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve(self, device, dtype):
@ -7667,7 +7799,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve_batched(self, device, dtype):
@ -7699,7 +7831,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_many_batches(self, device, dtype):
def run_test(A_dims, b_dims):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
@ -7712,7 +7844,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_broadcasting(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@ -7734,7 +7866,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@onlyCUDA
@skipCUDAIfNoMagma
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
# this tests https://github.com/pytorch/pytorch/issues/36921
def test_lu_solve_large_matrices(self, device, dtype):
def run_test(A_dims, b_dims):
@ -7747,7 +7879,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_lu_solve_out_errors_and_warnings(self, device, dtype):
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
@ -7776,7 +7908,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@precisionOverride({torch.float32: 1e-5, torch.complex64: 1e-5})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_symeig(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -7827,7 +7959,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_symeig_out_errors_and_warnings(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_matrix
@ -7965,7 +8097,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypes(*floating_and_complex_types())
def test_geqrf(self, device, dtype):
def run_test(shape):

View File

@ -571,6 +571,16 @@ def generate_tensor_like_override_tests(cls):
def instance_gen():
return TensorLike()
# FIXME The following code does not support kwonly args without defaults.
# The fix is easy, as one just needs to save these args when generating the variable
# annotated_args. The problem is that, if one does so, one finds a number
# of functions that have problematic signatures in native_functions.yaml.
# Fixing these would be BC breaking, so hence this terrible hack
# https://github.com/pytorch/pytorch/issues/67008
kwargs = {}
if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
kwargs = {"upper": True}
func_args = []
is_method = is_tensor_method_or_property(func)
if func in annotated_args:
@ -633,7 +643,7 @@ def generate_tensor_like_override_tests(cls):
func_args += [instance_gen(), instance_gen()]
def test(self):
ret = func(*func_args)
ret = func(*func_args, **kwargs)
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
# This is currently the best check but doesn't work for, for example,
# Tensor.__add__ because it redirects to Tensor.add.

View File

@ -1461,6 +1461,10 @@
solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular)
cloned_coefficient: A_t
- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
- name: tril(Tensor self, int diagonal=0) -> Tensor
self: grad.tril(diagonal)
result: auto_linear

View File

@ -107,7 +107,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical',
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
'linalg_solve_triangular', 'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {

View File

@ -9924,11 +9924,11 @@ add_docstr(torch.triangular_solve,
r"""
triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor)
Solves a system of equations with a triangular coefficient matrix :math:`A`
Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A`
and multiple right-hand sides :math:`b`.
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
with the default keyword arguments.
In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular
(or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal.
`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are
batches of 2D matrices. If the inputs are batches, then returns
@ -9945,10 +9945,9 @@ Args:
:math:`*` is zero of more batch dimensions
A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
where :math:`*` is zero or more batch dimensions
upper (bool, optional): whether to solve the upper-triangular system
of equations (default) or the lower-triangular system of equations. Default: ``True``.
transpose (bool, optional): whether :math:`A` should be transposed before
being sent into the solver. Default: ``False``.
upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``.
transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``,
and `op(A) = A` if it is ``False``. Default: ``False``.
unitriangular (bool, optional): whether :math:`A` is unit triangular.
If True, the diagonal elements of :math:`A` are assumed to be
1 and not referenced from :math:`A`. Default: ``False``.

View File

@ -164,6 +164,14 @@ inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& othe
return torch::linalg_solve_out(result, input, other);
}
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
return torch::linalg_solve_triangular(input, other, upper, left, unitriangular);
}
inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
return torch::linalg_solve_triangular_out(result, input, other, upper, left, unitriangular);
}
inline std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& input, bool full_matrices) {
return torch::linalg_svd(input, full_matrices);
}
@ -437,6 +445,18 @@ inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& othe
return detail::solve_out(result, input, other);
}
/// Computes a solution of a linear system AX = B for input = A and other = B whenever A is square
/// upper or lower triangular and does not have zeros in the diagonal
///
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
return detail::solve_triangular(input, other, upper, left, unitriangular);
}
inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
return detail::solve_triangular_out(result, input, other, upper, left, unitriangular);
}
/// Computes the singular values and singular vectors
///
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svd

View File

@ -2660,6 +2660,7 @@ Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const
// This function is used for both torch.symeig and torch.linalg.eigh.
// eigh (and torch.symeig) operates only on symmetric (resp. Hermitian) inputs.
// [Note: eigh backward]
// General considerations of the differential and adjoint
// Let U(n) = {U \in C^{n x n} | U^H U = I} by the unitary group and
// Her(n) = {A \in C^{n x n} | A^H = A} be the Hermitian matrices
@ -3248,6 +3249,78 @@ Tensor triangular_solve_jvp(
);
}
Tensor linalg_solve_triangular_forward_AD(
const Tensor& A_t,
const Tensor& B_t,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular) {
// The forward AD formula (for left = true) is A^{-1}(B_t - A_tX)
// For the derivation see:
// [Note: Forward / Backward AD solve_triangular]
const Tensor proj_A_t = upper ? A_t.triu(static_cast<int>(unitriangular))
: A_t.tril(- static_cast<int>(unitriangular));
const Tensor X_t = B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t));
return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular);
}
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular,
std::array<bool, 2> output_mask) {
const bool A_requires_grad = output_mask[0];
const bool B_requires_grad = output_mask[1];
// [Note: Forward / Backward AD solve_triangular]
// Assume left=true for simplicity.
// Remark: A solver computes A^{-1}B
//
// Forward AD:
// If f(A) = A^{-1}, differentiating the equation A^{-1}A = I_n gives
// (df)_A(E) = -A^{-1}EA^{-1}
// As such, if g(A,B) = A^{-1}B,
// (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
// = A^{-1}(E_B - E_AX)
// Backward AD:
// Denoting the gradients by G_A, G_B, we solve above to give
// G_B = A^{-H}G_X
// G_A = -A^{-H}G_XX^H = -G_B X^H
//
// Note that you don't need to store B for forward nor backward
//
// These formulas work for a general solver of linear equations.
// Let's prove now that when A is triangular, G_A is the projection onto the triangular matrices
// of the formula above, i.e. simply taking triu (resp. tril) in the formula above.
// This is because, since the triangular matrices form a vector space, the tangent space at any
// point is itself the space of triangular matrices. The result follows from a reasoning as that
// at the end of [Note: eigh backward]
// Something similar happens for `unitriangular`, only that int his case the tangent space is
// the set of lower-triangular matrices with zeros on the diagonal.
if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) {
return std::make_tuple(Tensor{}, Tensor{});
}
// We always need to comput G_B
const Tensor A_H = A.mH();
const Tensor G_B = at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular);
if (A_requires_grad) {
const Tensor X_H = X.mH();
Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
G_A = upper ? G_A.triu(static_cast<int>(unitriangular))
: G_A.tril(- static_cast<int>(unitriangular));
return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
} else {
return std::make_tuple(Tensor{}, G_B);
}
}
std::tuple<Tensor, Tensor> cholesky_solve_backward(
const Tensor& grad_x, const Tensor& self,
const Tensor& input2, const Tensor& result, const bool upper) {

View File

@ -194,6 +194,22 @@ Tensor triangular_solve_jvp(
const bool transpose,
const bool unitriangular
);
Tensor linalg_solve_triangular_forward_AD(
const Tensor& A_t,
const Tensor& B_t,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular);
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular,
std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
IntArrayRef sumdim, std::array<bool, 3> grad_mask);

View File

@ -1859,7 +1859,7 @@ Computes the solution of a square system of linear equations with a unique solut
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to
:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{m \times k}`, which is defined as
:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as
.. math:: AX = B
@ -1883,10 +1883,19 @@ Letting `*` be zero or more batch dimensions,
This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and
more numerically stable way than performing the computations separately.
.. note::
It is possible to compute the solution of the system :math:`XA = B` by passing the inputs
:attr:`A` and :attr:`B` transposed and transposing the output returned by this function.
""" + fr"""
.. note:: {common_notes["sync_note"]}
""" + r"""
.. seealso::
:func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear
equations with a unique solution.
Args:
A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions.
B (Tensor): right-hand side tensor of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)`
@ -1933,6 +1942,84 @@ Examples::
https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
""")
solve_triangular = _add_docstr(_linalg.linalg_solve_triangular, r"""
linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor
Computes the solution of a triangular system of linear equations with a unique solution.
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system**
associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal
(that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`,
which is defined as
.. math:: AX = B
The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular.
If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that
solves the system
.. math::
XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.}
If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A`
will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed.
If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed.
The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that
are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix
has very small eigenvalues.
Supports inputs of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices, and if the inputs are batches of matrices then
the output has the same batch dimensions.
.. seealso::
:func:`torch.linalg.solve` computes the solution of a general square system of linear
equations with a unique solution.
Args:
A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`)
where `*` is zero or more batch dimensions.
B (Tensor): right-hand side tensor of shape `(*, n, k)`.
Keyword args:
upper (bool): whether :attr:`A` is an upper or lower triangular matrix.
left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`.
unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be
all equal to `1`. Default: `False`.
out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`.
Ignored if `None`. Default: `None`.
Examples::
>>> A = torch.randn(3, 3).triu_()
>>> b = torch.randn(3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=True)
>>> torch.allclose(A @ X, B)
True
>>> A = torch.randn(2, 3, 3).tril_()
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=False)
>>> torch.allclose(A @ X, B)
True
>>> A = torch.randn(2, 4, 4).tril_()
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False)
>>> torch.allclose(X @ A, B)
True
.. _invertible:
https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
""")
tensorinv = _add_docstr(_linalg.linalg_tensorinv, r"""
linalg.tensorinv(A, ind=2, *, out=None) -> Tensor

View File

@ -984,6 +984,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.trapz: lambda y, x=None, dim=-1: -1,
torch.trapezoid: lambda y, x=None, dim=-1: -1,
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,

View File

@ -4862,6 +4862,44 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False, vect
out.append(SampleInput(a, args=(b,)))
return out
def sample_inputs_linalg_solve_triangular(op_info, device, dtype, requires_grad=False, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device)
bs = (1, 2, 0)
ns = (3, 0)
ks = (1, 3, 0)
def gen_inputs():
for b, n, k, (left, upper, uni) in product(bs, ns, ks, product((True, False), repeat=3)):
with torch.no_grad():
if b == 1:
A = make_arg((n, n)) if left else make_arg((k, k))
B = make_arg((n, k))
else:
A = make_arg((b, n, n)) if left else make_arg((b, k, k))
B = make_arg((b, n, k))
if uni:
# Not really necessary, but writing it for consistency
A.diagonal(0, -2, -1).fill_(1.)
else:
d = A.diagonal(0, -2, -1)
d[d.abs() < 1e-6] = 1.
if upper:
A.triu_()
else:
A.tril_()
kwargs = {"upper": upper, "left": left, "unitriangular": uni}
if requires_grad:
for grad_A, grad_B in product((True, False), repeat=2):
# Either A or B needs to have a gradient
if not grad_A and not grad_B:
continue
A.requires_grad_(grad_A)
B.requires_grad_(grad_B)
yield SampleInput(A, args=(B,), kwargs=kwargs)
else:
yield SampleInput(A, args=(B,), kwargs=kwargs)
return list(gen_inputs())
def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
"""
@ -7598,8 +7636,7 @@ def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs):
`idx` is used to specific which `args[idx]` is to be triangularized.
"""
triangular_arg = args[idx].triu() if upper else args[idx].tril()
modified_args = args[:idx] + (triangular_arg,) + args[idx + 1:]
return op(*modified_args, upper)
return op(*args[:idx], triangular_arg, *args[idx + 1:], upper, **kwargs)
def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs):
@ -11174,6 +11211,13 @@ op_db: List[OpInfo] = [
check_batched_gradgrad=False,
supports_forward_ad=True,
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
OpInfo('linalg.solve_triangular',
aten_name='linalg_solve_triangular',
op=torch.linalg.solve_triangular,
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_solve_triangular,
# linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
supports_forward_ad=True),
OpInfo('linalg.matrix_rank',
aten_name='linalg_matrix_rank',
dtypes=floating_and_complex_types(),