mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add linalg.solve_triangular (#63568)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63568 This PR adds the first solver with structure to `linalg`. This solver has an API compatible with that of `linalg.solve` preparing these for a possible future merge of the APIs. The new API: - Just returns the solution, rather than the solution and a copy of `A` - Removes the confusing `transpose` argument and replaces it by a correct handling of conj and strides within the call - Adds a `left=True` kwarg. This can be achieved via transposes of the inputs and the result, but it's exposed for convenience. This PR also implements a dataflow that minimises the number of copies needed before calling LAPACK / MAGMA / cuBLAS and takes advantage of the conjugate and neg bits. This algorithm is implemented for `solve_triangular` (which, for this, is the most complex of all the solvers due to the `upper` parameters). Once more solvers are added, we will factor out this calling algorithm, so that all of them can take advantage of it. Given the complexity of this algorithm, we implement some thorough testing. We also added tests for all the backends, which was not done before. We also add forward AD support for `linalg.solve_triangular` and improve the docs of `linalg.solve_triangular`. We also fix a few issues with those of `torch.triangular_solve`. Resolves https://github.com/pytorch/pytorch/issues/54258 Resolves https://github.com/pytorch/pytorch/issues/56327 Resolves https://github.com/pytorch/pytorch/issues/45734 cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D32588230 Pulled By: mruberry fbshipit-source-id: 69e484849deb9ad7bb992cc97905df29c8915910
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a2e35e167b
commit
b46c89d950
@ -37,6 +37,8 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
|
||||
m.impl("dot.out", torch::CppFunction::makeFallthrough());
|
||||
m.impl("vdot.out", torch::CppFunction::makeFallthrough());
|
||||
m.impl("mm", torch::CppFunction::makeFallthrough());
|
||||
m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
|
||||
m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
|
||||
m.impl("mm.out", torch::CppFunction::makeFallthrough());
|
||||
m.impl("addmm", torch::CppFunction::makeFallthrough());
|
||||
m.impl("addmm_", torch::CppFunction::makeFallthrough());
|
||||
|
@ -3801,4 +3801,260 @@ Tensor _det_lu_based_helper_backward_helper(
|
||||
}
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
namespace {
|
||||
void checkIsMatrix(const Tensor& t,
|
||||
const char* const f_name,
|
||||
const char* const t_name) {
|
||||
TORCH_CHECK(t.dim() >= 2, f_name, ": Expected ", t_name,
|
||||
" to be a tensor of at least 2 dimensions.");
|
||||
}
|
||||
|
||||
void checkIsSquareMatrix(const Tensor& t,
|
||||
const char* const f_name,
|
||||
const char* const t_name) {
|
||||
checkIsMatrix(t, f_name, t_name);
|
||||
TORCH_CHECK(t.size(-1) == t.size(-2),
|
||||
f_name, ": Expected ", t_name,
|
||||
" to be a square matrix or batch of square matrices. "
|
||||
"Got matrices of size (", t.size(-2), ", ", t.size(-1), ").");
|
||||
}
|
||||
|
||||
void checkInputsSolver(const Tensor& A,
|
||||
const Tensor& B,
|
||||
const Tensor& out,
|
||||
const bool left,
|
||||
const char* const f_name) {
|
||||
checkIsSquareMatrix(A, f_name, "A");
|
||||
checkIsMatrix(B, f_name, "B");
|
||||
TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
|
||||
f_name, ": Incompatible shapes of A and B for the equation ",
|
||||
left ? "AX = B" : "XA = B",
|
||||
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
|
||||
}
|
||||
|
||||
bool is_row_or_column_contiguous(const Tensor& t) {
|
||||
// This could be made more general, similar to how it's checked in matmul, which would allow to
|
||||
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
|
||||
// We choose to be conservative for simplicity
|
||||
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
|
||||
}
|
||||
|
||||
TransposeType to_transpose_type(const bool contig, const bool conj) {
|
||||
if (conj) {
|
||||
if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
|
||||
else { return TransposeType::ConjTranspose; }
|
||||
} else {
|
||||
if (contig) { return TransposeType::NoTranspose; }
|
||||
else { return TransposeType::Transpose; }
|
||||
}
|
||||
}
|
||||
} // end of anonymous namespace
|
||||
|
||||
/*
|
||||
Solves the matrix equation AX = B for A triangular.
|
||||
'left' If true solves AX = B, if false solves XA = B
|
||||
'upper' controls the portion of input matrix to consider in computations,
|
||||
'unitriangular' if true then we assume diag(A) to be ones
|
||||
'out' The tensor with the result. If A == out, A will be modified in place
|
||||
*/
|
||||
Tensor& linalg_solve_triangular_out(
|
||||
const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool left,
|
||||
bool unitriangular,
|
||||
Tensor& out) {
|
||||
checkInputsSolver(A, B, out, left, "linalg.solve_triangular");
|
||||
Tensor A_, B_;
|
||||
std::tie(B_, A_) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
|
||||
|
||||
// We'll write F-contig / F-transpose for FORTRAN contiguous / FORTRAN transpose etc
|
||||
// We say that a matrix is F-ready if it's F-contig OR F-transpose
|
||||
// At this point, A, B have been broadcasted but may or may not be F-ready
|
||||
|
||||
// The following algorithm minimises copies and allocations. In pseudocode:
|
||||
// if out is wrong size:
|
||||
// resize_output(out)
|
||||
// # Invariant: out is the right size
|
||||
// Tensor out_f; # Tensor that we will pass to FORTRAN
|
||||
// if out is F-ready:
|
||||
// out_f = out;
|
||||
// else:
|
||||
// Allocate out_f F-ready
|
||||
// if B != out_f:
|
||||
// copy B into out_f
|
||||
// # Invariant: out_f F-ready and has B copied into it
|
||||
// if out_f is F-transposed:
|
||||
// transpose equation
|
||||
// if out_f is conj:
|
||||
// conjugate equation
|
||||
// # Invariant: out_f is not conjugated and F-contig
|
||||
// Tensor A_f; # Tensor that will be sent to FORTRAN
|
||||
// if A is F-ready:
|
||||
// if A is conj and A is not transposed:
|
||||
// # We need to clone A in this case. See [Cloning A]
|
||||
// clone A F-contig into A_f
|
||||
// else:
|
||||
// A_f = A;
|
||||
// else:
|
||||
// clone A F-contig into A_f
|
||||
// # Invariant: out_f is F-contig and A_f is F-ready
|
||||
// # We pass FORTRAN the flags indicating if A_f is transposed and or conjugated
|
||||
//
|
||||
// # Here we undo the conjugations / transposes on out_f if needed
|
||||
//
|
||||
// if out_f not same out:
|
||||
// copy out_f into out
|
||||
// return out
|
||||
//
|
||||
// Note: The logic for the negative bit is the same as that for the conjugate bit
|
||||
//
|
||||
// Note: [Cloning A] If we are careful when allocating B when it needs to be allocated at the
|
||||
// beginning of the algorithm, it is possible to always elide the copy of A here.
|
||||
// Via this trick, the algorithm will copy at most one of A or B (never both) whenever A
|
||||
// and B are F-ready and not A.is_neg() (which happens almost always in practice).
|
||||
// When called as f(A, B, out=B) in most practical cases it'll perform no copies.
|
||||
|
||||
const bool avoid_copy_A = A_.transpose(-2, -1).is_contiguous() && A_.is_conj();
|
||||
if (avoid_copy_A) {
|
||||
// See Note: [Cloning A]
|
||||
at::native::resize_output(out, B_.sizes());
|
||||
}
|
||||
else {
|
||||
// poorman's reimplementation of resize_output with result F-contig
|
||||
if (resize_output_check(out, B_.sizes())) {
|
||||
out.resize_(B_.transpose(-2, -1).sizes(), MemoryFormat::Contiguous);
|
||||
out.transpose_(-2, -1); // make 'out' have Fortran contiguous memory layout
|
||||
}
|
||||
}
|
||||
// Invariant: out has the right size, so we'll be able to copy into it later on
|
||||
|
||||
Tensor out_f; // the out that will go into fortran
|
||||
// We use C10_LIKELY mostly for documentation as it helps following what's the most likely path
|
||||
if C10_LIKELY (is_row_or_column_contiguous(out)) {
|
||||
out_f = out;
|
||||
if C10_LIKELY (!out.is_same(B_)) {
|
||||
out_f.copy_(B_);
|
||||
}
|
||||
} else {
|
||||
if (avoid_copy_A) {
|
||||
// See Note: [Cloning A]
|
||||
out_f = B_.clone(at::MemoryFormat::Contiguous);
|
||||
}
|
||||
else {
|
||||
out_f = cloneBatchedColumnMajor(B_);
|
||||
}
|
||||
}
|
||||
// Invariant: out_f F-ready and has B copied into it
|
||||
|
||||
// out_f is F-transposed
|
||||
bool transpose_A = false;
|
||||
bool transpose_out_f = false;
|
||||
if (out_f.stride(-1) == 1) {
|
||||
left = !left;
|
||||
transpose_A = true;
|
||||
transpose_out_f = true;
|
||||
out_f.transpose_(-2 ,-1);
|
||||
}
|
||||
|
||||
// No need to conjugate anything if out_f is conj as AX = conj(B) <=> conj(A)conj(X) = B
|
||||
// and X = B after the algortihm. We just anotate that A is conjugated later on
|
||||
// The solution will be written into out_f, so it'll be conjugated already
|
||||
|
||||
Tensor A_f = A_; // The A that will go into fortran
|
||||
|
||||
bool A_is_conj = A_f.is_conj() != out_f.is_conj();
|
||||
bool A_is_neg = A_f.is_neg() != out_f.is_neg();
|
||||
bool A_is_f_contig = (A_f.stride(-1) == 1) == transpose_A;
|
||||
if C10_UNLIKELY (!is_row_or_column_contiguous(A_f)) {
|
||||
// We first anotate with flags on A_f all the conj / transpose / neg coming from out
|
||||
// and then we clone the resulting tensor to resolve all of them in memory
|
||||
if (out_f.is_conj()) {
|
||||
A_f = A_f.conj();
|
||||
}
|
||||
A_is_conj = false;
|
||||
|
||||
if (out_f.is_neg()) {
|
||||
A_f = A_f._neg_view();
|
||||
}
|
||||
A_is_neg = false;
|
||||
|
||||
// This choice is to be consistent with how we flip `upper` later on
|
||||
// Note that this is the same reasoning we apply for neg and conj below
|
||||
// If B has neg or out or transpose, then we need to resolve it in memory
|
||||
A_f = transpose_A ? A_f.clone(at::MemoryFormat::Contiguous)
|
||||
: cloneBatchedColumnMajor(A_f);
|
||||
A_is_f_contig = true;
|
||||
} else if C10_UNLIKELY (A_is_f_contig && A_is_conj) {
|
||||
if C10_UNLIKELY (A_f.is_neg() || out_f.is_neg()) {
|
||||
// Cases A_is_neg (remember that B.is_neg() iff out_f.is_same(B))
|
||||
// -AX = -B => A(-X) = B. Swap neg of A_f. Nothing to do on X as X.is_same(B).
|
||||
// -AX = B. We resolve the neg in memory
|
||||
// AX = -B => -A -X = B. We resolve the neg in memory for A,
|
||||
// Since X.is_same(B), we already have that X.is_neg() == true
|
||||
|
||||
// We do the neg with a view, as this will be resolved in the clone below
|
||||
if (out_f.is_neg()) {
|
||||
A_f = A_f._neg_view();
|
||||
}
|
||||
A_is_neg = false;
|
||||
}
|
||||
// We resolve the transpose if necessary and then leave A_f F-transposed,
|
||||
// as BLAS can handle the case F-transposed and conjugated
|
||||
A_f = at::clone(transpose_A ? A_f.mT() : A_f, at::MemoryFormat::Contiguous);
|
||||
A_is_f_contig = false;
|
||||
if (transpose_A) {
|
||||
upper = !upper;
|
||||
}
|
||||
// As we've already resolved the conj of A in the clone
|
||||
A_is_conj = out_f.is_conj();
|
||||
} else if C10_UNLIKELY (A_is_neg) {
|
||||
// We follow the same logic as above, only that in this case we need to perform the
|
||||
// negation in memory
|
||||
if (out_f.is_neg()) {
|
||||
A_f = -A_f;
|
||||
} else {
|
||||
A_f = A_f.resolve_neg();
|
||||
}
|
||||
A_is_neg = false;
|
||||
// As we've already resolved the conj of A in the negationa bove
|
||||
A_is_conj = out_f.is_conj();
|
||||
}
|
||||
// Invariant: out_f is F-contig and A_f is F-ready
|
||||
// neg has been resolved
|
||||
|
||||
// If we pass the matrix physically F-transposed, we need to change the parity of upper
|
||||
if (A_f.stride(-1) == 1) {
|
||||
upper = !upper;
|
||||
}
|
||||
|
||||
triangular_solve_stub(
|
||||
A_f.device().type(), A_f, out_f,
|
||||
/*left=*/left,
|
||||
/*upper=*/upper,
|
||||
/*transpose*/to_transpose_type(A_is_f_contig, A_is_conj),
|
||||
/*unitriangular=*/unitriangular);
|
||||
|
||||
if (transpose_out_f) {
|
||||
out_f.transpose_(-2, -1);
|
||||
}
|
||||
|
||||
if (!out_f.is_same(out)) {
|
||||
out.copy_(out_f);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor linalg_solve_triangular(
|
||||
const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool left,
|
||||
bool unitriangular) {
|
||||
Tensor out = at::empty({0}, A.options());
|
||||
linalg_solve_triangular_out(A, B, upper, left, unitriangular, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
||||
|
@ -323,13 +323,16 @@ static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_bro
|
||||
}
|
||||
|
||||
static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
|
||||
linearSolveCheckInputs(arg1, arg2, name);
|
||||
// If there's no name we assume we don't want to check the errors
|
||||
if (name != nullptr) {
|
||||
linearSolveCheckInputs(arg1, arg2, name);
|
||||
}
|
||||
|
||||
std::vector<int64_t> arg1_expand_size, arg2_expand_size;
|
||||
std::tie(arg1_expand_size, arg2_expand_size) = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
|
||||
|
||||
Tensor arg1_broadcasted = arg1.expand(arg1_expand_size);
|
||||
Tensor arg2_broadcasted = arg2.expand(arg2_expand_size);
|
||||
auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
|
||||
auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
|
||||
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,10 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
|
||||
m.impl("resolve_neg", torch::CppFunction::makeFallthrough());
|
||||
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
|
||||
|
||||
// linear algebra functions
|
||||
m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
|
||||
m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
|
||||
|
||||
TORCH_VIEW_FNS(m)
|
||||
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
|
||||
}
|
||||
|
@ -6842,6 +6842,17 @@
|
||||
structured_delegate: triangular_solve.X
|
||||
variants: method, function
|
||||
|
||||
- func: linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_solve_triangular_out
|
||||
|
||||
- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
|
||||
python_module: linalg
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_solve_triangular
|
||||
|
||||
- func: symeig.e(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: symeig_out
|
||||
|
@ -48,6 +48,7 @@ Solvers
|
||||
:nosignatures:
|
||||
|
||||
solve
|
||||
solve_triangular
|
||||
lstsq
|
||||
|
||||
Inverses
|
||||
|
@ -12,7 +12,7 @@ from math import inf, nan, isnan
|
||||
import random
|
||||
from random import randrange
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
from functools import reduce, partial
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
|
||||
@ -421,7 +421,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -467,7 +467,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_errors_and_warnings(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -571,7 +571,7 @@ class TestLinalg(TestCase):
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_cholesky_batched(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -587,7 +587,7 @@ class TestLinalg(TestCase):
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@tf32_on_and_off(0.01)
|
||||
def test_old_cholesky(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
@ -611,7 +611,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_cholesky_empty(self, device, dtype):
|
||||
def run_test(upper):
|
||||
A = torch.empty(0, 0, dtype=dtype, device=device)
|
||||
@ -627,7 +627,7 @@ class TestLinalg(TestCase):
|
||||
# it was using the lower triangular part instead of the upper one
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_cholesky_batched_upper(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -642,7 +642,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_ex(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -673,7 +673,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_ex_non_pd(self, device, dtype):
|
||||
# if the input matrix is not positive definite, info with positive integer is returned
|
||||
A = torch.eye(3, 3, dtype=dtype, device=device)
|
||||
@ -699,7 +699,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_ex_out_info_error(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -902,7 +902,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
def test_eigh(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
@ -939,7 +939,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
def test_eigh_lower_uplo(self, device, dtype):
|
||||
def run_test(shape, batch, uplo):
|
||||
@ -957,7 +957,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_eigh_errors_and_warnings(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
|
||||
@ -1012,7 +1012,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
def test_eigh_non_contiguous(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
@ -1061,7 +1061,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
def test_eigvalsh(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
@ -1086,7 +1086,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_eigvalsh_errors_and_warnings(self, device, dtype):
|
||||
# eigvalsh requires a square matrix
|
||||
t = torch.randn(2, 3, device=device, dtype=dtype)
|
||||
@ -1125,7 +1125,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
|
||||
def test_eigvalsh_non_contiguous(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
@ -1155,7 +1155,7 @@ class TestLinalg(TestCase):
|
||||
run_test_permuted(shape, batch, uplo)
|
||||
run_test_skipped_elements(shape, batch, uplo)
|
||||
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_kron(self, device, dtype):
|
||||
|
||||
def run_test_case(a_shape, b_shape):
|
||||
@ -1176,7 +1176,7 @@ class TestLinalg(TestCase):
|
||||
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
|
||||
run_test_case(a_shape, b_shape)
|
||||
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_kron_non_contiguous(self, device, dtype):
|
||||
|
||||
def run_test_transposed(a_shape, b_shape):
|
||||
@ -1232,7 +1232,7 @@ class TestLinalg(TestCase):
|
||||
self.assertTrue(c.is_contiguous(memory_format=torch.contiguous_format))
|
||||
|
||||
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_kron_empty(self, device, dtype):
|
||||
|
||||
def run_test_case(empty_shape):
|
||||
@ -1250,7 +1250,7 @@ class TestLinalg(TestCase):
|
||||
for empty_shape in empty_shapes:
|
||||
run_test_case(empty_shape)
|
||||
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_kron_errors_and_warnings(self, device, dtype):
|
||||
# if non-empty out tensor with wrong shape is passed a warning is given
|
||||
a = torch.eye(3, dtype=dtype, device=device)
|
||||
@ -1598,7 +1598,7 @@ class TestLinalg(TestCase):
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3})
|
||||
def test_cond(self, device, dtype):
|
||||
def run_test_case(input, p):
|
||||
@ -1658,7 +1658,7 @@ class TestLinalg(TestCase):
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3})
|
||||
def test_cond_errors_and_warnings(self, device, dtype):
|
||||
norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
|
||||
@ -2868,7 +2868,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_svd_errors_and_warnings(self, device, dtype):
|
||||
for svd in [torch.svd, torch.linalg.svd]:
|
||||
# if non-empty out tensor with wrong shape is passed a warning is given
|
||||
@ -3006,7 +3006,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_svdvals(self, device, dtype):
|
||||
|
||||
def run_test(shape):
|
||||
@ -3026,7 +3026,7 @@ class TestLinalg(TestCase):
|
||||
@skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case
|
||||
@skipCUDAIfRocm
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_svd_memory_allocation(self, device, dtype):
|
||||
# test for https://github.com/pytorch/pytorch/issues/61949
|
||||
# the problem was that tensors of incorrect size were allocated and then narrowed
|
||||
@ -3053,7 +3053,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_cholesky_solve(self, device, dtype):
|
||||
@ -3064,7 +3064,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_cholesky_solve_batched(self, device, dtype):
|
||||
@ -3084,7 +3084,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_solve_batched_non_contiguous(self, device, dtype):
|
||||
from numpy.linalg import solve
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
@ -3103,7 +3103,7 @@ class TestLinalg(TestCase):
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_cholesky_solve_batched_many_batches(self, device, dtype):
|
||||
@ -3116,7 +3116,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_cholesky_solve_batched_broadcasting(self, device, dtype):
|
||||
@ -3172,7 +3172,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
@ -3199,7 +3199,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_inverse(self, device, dtype):
|
||||
@ -3270,7 +3270,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_inv_ex_info_device(self, device, dtype):
|
||||
A = torch.eye(3, 3, dtype=dtype, device=device)
|
||||
info = torch.linalg.inv_ex(A).info
|
||||
@ -3278,7 +3278,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@skipCUDAIfRocm
|
||||
def test_inv_ex_singular(self, device, dtype):
|
||||
# if the input matrix is not invertible, info with positive integer is returned
|
||||
@ -3306,7 +3306,7 @@ class TestLinalg(TestCase):
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
|
||||
torch.float64: 1e-5, torch.complex128: 1e-5})
|
||||
def test_inverse_many_batches(self, device, dtype):
|
||||
@ -3327,7 +3327,7 @@ class TestLinalg(TestCase):
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@onlyNativeDeviceTypes # TODO: XLA doesn't raise exception
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_inverse_errors(self, device, dtype):
|
||||
# inverse expects batches of square matrices as input
|
||||
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
|
||||
@ -3348,7 +3348,7 @@ class TestLinalg(TestCase):
|
||||
@onlyNativeDeviceTypes # TODO: XLA doesn't raise exception
|
||||
@skipCUDAIfRocm
|
||||
@skipCUDAVersionIn([(11, 3)]) # https://github.com/pytorch/pytorch/issues/57482
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_inverse_errors_large(self, device, dtype):
|
||||
# Test batched inverse of singular matrices reports errors without crashing (gh-51930)
|
||||
x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
|
||||
@ -3360,7 +3360,7 @@ class TestLinalg(TestCase):
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_pinv(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -3420,7 +3420,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_pinv_errors_and_warnings(self, device, dtype):
|
||||
# pinv requires at least 2D tensor
|
||||
a = torch.randn(1, device=device, dtype=dtype)
|
||||
@ -3472,7 +3472,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_inv_errors_and_warnings(self, device, dtype):
|
||||
# inv expects batches of square matrices as input
|
||||
a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
|
||||
@ -3539,7 +3539,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
|
||||
def test_solve(self, device, dtype):
|
||||
def run_test(n, batch, rhs):
|
||||
@ -3586,7 +3586,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
|
||||
def test_solve_batched_non_contiguous(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
|
||||
@ -3600,7 +3600,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_solve_errors_and_warnings(self, device, dtype):
|
||||
# solve expects batches of square matrices as input
|
||||
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
|
||||
@ -3666,7 +3666,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve(self, device, dtype):
|
||||
for (k, n) in zip([2, 3, 5], [3, 5, 7]):
|
||||
b, A = self.solve_test_helper((n,), (n, k), device, dtype)
|
||||
@ -3675,7 +3675,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve_batched(self, device, dtype):
|
||||
def solve_batch_helper(A_dims, b_dims):
|
||||
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
|
||||
@ -3693,7 +3693,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve_batched_non_contiguous(self, device, dtype):
|
||||
from numpy.linalg import solve
|
||||
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
|
||||
@ -3706,7 +3706,7 @@ class TestLinalg(TestCase):
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve_batched_many_batches(self, device, dtype):
|
||||
for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]):
|
||||
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
|
||||
@ -3716,7 +3716,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve_batched_broadcasting(self, device, dtype):
|
||||
from numpy.linalg import solve
|
||||
|
||||
@ -3736,7 +3736,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_solve_errors_and_warnings(self, device, dtype):
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
@ -3877,7 +3877,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
|
||||
def test_tensorinv(self, device, dtype):
|
||||
|
||||
@ -3907,7 +3907,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
|
||||
def test_tensorinv_non_contiguous(self, device, dtype):
|
||||
|
||||
@ -3955,7 +3955,7 @@ class TestLinalg(TestCase):
|
||||
@skipMeta # See https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_tensorinv_empty(self, device, dtype):
|
||||
for ind in range(1, 4):
|
||||
# Check for empty inputs. NumPy does not work for these cases.
|
||||
@ -3966,7 +3966,7 @@ class TestLinalg(TestCase):
|
||||
@skipMeta # See https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_tensorinv_errors_and_warnings(self, device, dtype):
|
||||
|
||||
def check_shape(a_shape, ind):
|
||||
@ -4018,7 +4018,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_tensorinv_singular_input(self, device, dtype):
|
||||
|
||||
def check_singular_input(a_shape, ind):
|
||||
@ -4111,7 +4111,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_matrix_rank(self, device, dtype):
|
||||
matrix_rank = torch.linalg.matrix_rank
|
||||
|
||||
@ -4154,7 +4154,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_matrix_rank_atol(self, device, dtype):
|
||||
|
||||
def run_test_atol(shape0, shape1, batch):
|
||||
@ -4208,7 +4208,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_matrix_rank_empty(self, device, dtype):
|
||||
matrix_rank = torch.linalg.matrix_rank
|
||||
|
||||
@ -4245,7 +4245,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
@ -4271,7 +4271,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_matrix_rank_basic(self, device, dtype):
|
||||
matrix_rank = torch.linalg.matrix_rank
|
||||
|
||||
@ -4285,7 +4285,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_matrix_rank(self, device, dtype):
|
||||
a = torch.eye(10, dtype=dtype, device=device)
|
||||
self.assertEqual(torch.matrix_rank(a).item(), 10)
|
||||
@ -4399,7 +4399,7 @@ class TestLinalg(TestCase):
|
||||
@precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_qr(self, device, dtype):
|
||||
def run_test(tensor_dims, some):
|
||||
A = torch.randn(*tensor_dims, dtype=dtype, device=device)
|
||||
@ -4852,6 +4852,138 @@ class TestLinalg(TestCase):
|
||||
check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
|
||||
check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
|
||||
|
||||
def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
make_randn = partial(torch.randn, dtype=dtype, device=device)
|
||||
b, n, k = shape
|
||||
for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
|
||||
# expand means that we generate a batch of matrices with a stride of zero in the batch dimension
|
||||
if (conj_a or conj_b) and not dtype.is_complex:
|
||||
continue
|
||||
# We just expand on the batch size
|
||||
if (expand_a or expand_b) and b == 1:
|
||||
continue
|
||||
|
||||
size_a = (b, n, n) if left else (b, k, k)
|
||||
size_b = (b, n, k) if not tr_b else (b, k, n)
|
||||
|
||||
# If expand_a or expand_b, we'll expand them to the correct size later
|
||||
if b == 1 or expand_a:
|
||||
size_a = size_a[1:]
|
||||
if b == 1 or expand_b:
|
||||
size_b = size_b[1:]
|
||||
|
||||
if well_conditioned:
|
||||
PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
|
||||
if uni:
|
||||
# A = L from PLU
|
||||
A = PLU[1].transpose(-2, -1).contiguous()
|
||||
else:
|
||||
# A = U from PLU
|
||||
A = PLU[2].contiguous()
|
||||
else:
|
||||
A = make_arg(size_a)
|
||||
A.triu_()
|
||||
|
||||
diag = A.diagonal(0, -2, -1)
|
||||
if uni:
|
||||
diag.fill_(1.)
|
||||
else:
|
||||
diag[diag.abs() < 1e-6] = 1.
|
||||
|
||||
B = make_arg(size_b)
|
||||
|
||||
if tr_a:
|
||||
A.transpose_(-2, -1)
|
||||
if tr_b:
|
||||
B.transpose_(-2, -1)
|
||||
if conj_a:
|
||||
A = A.conj()
|
||||
if conj_b:
|
||||
B = B.conj()
|
||||
if expand_a:
|
||||
A = A.expand(b, *size_a)
|
||||
if expand_b:
|
||||
B = B.expand(b, n, k)
|
||||
yield A, B, left, not tr_a, uni
|
||||
|
||||
def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
|
||||
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
|
||||
if left:
|
||||
self.assertEqual(A @ X, B)
|
||||
else:
|
||||
self.assertEqual(X @ A, B)
|
||||
out = B
|
||||
# B may be expanded
|
||||
if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
|
||||
out = B.clone()
|
||||
torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
|
||||
self.assertEqual(X, out)
|
||||
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-1, torch.complex64: 1e-1,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_linalg_solve_triangular(self, device, dtype):
|
||||
# This exercises the API + BLAS CPU + batched cuBLAS
|
||||
ks = (3, 1, 0)
|
||||
ns = (5, 0)
|
||||
bs = (1, 2, 0)
|
||||
|
||||
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
|
||||
for b, n, k in product(bs, ns, ks):
|
||||
for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device):
|
||||
self._test_linalg_solve_triangular(A, B, upper, left, uni)
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNoMagma # Magma needed for the PLU decomposition
|
||||
@skipCUDAIfRocm # There is a memory access bug in rocBLAS in the (non-batched) solve_triangular
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_linalg_solve_triangular_large(self, device, dtype):
|
||||
# Exercises magma and cublas
|
||||
magma = (9, 513, 1)
|
||||
iterative_cublas = (2, 64, 1)
|
||||
|
||||
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
|
||||
for shape in (magma, iterative_cublas):
|
||||
for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
|
||||
self._test_linalg_solve_triangular(A, B, upper, left, uni)
|
||||
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_linalg_solve_triangular_broadcasting(self, device, dtype):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
|
||||
sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
|
||||
((2, 1, 3, 4, 4), (4, 6)),
|
||||
((4, 4), (2, 1, 3, 4, 2)),
|
||||
((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
|
||||
for size_A, size_B in sizes:
|
||||
for left, upper, uni in itertools.product([True, False], repeat=3):
|
||||
A = make_arg(size_A)
|
||||
if upper:
|
||||
A.triu_()
|
||||
else:
|
||||
A.tril_()
|
||||
diag = A.diagonal(0, -2, -1)
|
||||
if uni:
|
||||
diag.fill_(1.)
|
||||
else:
|
||||
diag[diag.abs() < 1e-6] = 1.
|
||||
B = make_arg(size_B)
|
||||
if not left:
|
||||
B.transpose_(-2, -1)
|
||||
|
||||
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
|
||||
if left:
|
||||
B_other = A @ X
|
||||
else:
|
||||
B_other = X @ A
|
||||
|
||||
self.assertEqual(*torch.broadcast_tensors(B, B_other))
|
||||
|
||||
def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
|
||||
device, dtype):
|
||||
triangle_function = torch.triu if upper else torch.tril
|
||||
@ -4866,7 +4998,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_triangular_solve(self, device, dtype):
|
||||
@ -4884,7 +5016,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_triangular_solve_batched(self, device, dtype):
|
||||
@ -4935,7 +5067,7 @@ class TestLinalg(TestCase):
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_triangular_solve_batched_many_batches(self, device, dtype):
|
||||
@ -4966,7 +5098,7 @@ class TestLinalg(TestCase):
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_triangular_solve_batched_broadcasting(self, device, dtype):
|
||||
from scipy.linalg import solve_triangular as tri_solve
|
||||
|
||||
@ -5001,7 +5133,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
@ -5302,7 +5434,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoCusolver
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_ormqr(self, device, dtype):
|
||||
|
||||
def run_test(batch, m, n, fortran_contiguous):
|
||||
@ -5346,7 +5478,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoCusolver
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_ormqr_errors_and_warnings(self, device, dtype):
|
||||
test_cases = [
|
||||
# input1 size, input2 size, input3 size, error regex
|
||||
@ -5536,7 +5668,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoCusolver
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_householder_product(self, device, dtype):
|
||||
def generate_reflectors_and_tau(A):
|
||||
"""
|
||||
@ -6698,7 +6830,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_pinverse(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
|
||||
|
||||
@ -7070,7 +7202,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_slogdet(self, device, dtype):
|
||||
@ -7122,7 +7254,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_slogdet_errors_and_warnings(self, device, dtype):
|
||||
# slogdet requires the input to be a square matrix or batch of square matrices
|
||||
a = torch.randn(2, 3, device=device, dtype=dtype)
|
||||
@ -7394,7 +7526,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_inverse(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||
|
||||
@ -7442,7 +7574,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
|
||||
# cholesky_inverse requires the input to be at least 2 dimensional tensor
|
||||
a = torch.randn(2, device=device, dtype=dtype)
|
||||
@ -7626,7 +7758,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_lu_solve_batched_non_contiguous(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
|
||||
|
||||
@ -7651,7 +7783,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_lu_solve(self, device, dtype):
|
||||
@ -7667,7 +7799,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
|
||||
torch.float64: 1e-8, torch.complex128: 1e-8})
|
||||
def test_lu_solve_batched(self, device, dtype):
|
||||
@ -7699,7 +7831,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_lu_solve_batched_many_batches(self, device, dtype):
|
||||
def run_test(A_dims, b_dims):
|
||||
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
|
||||
@ -7712,7 +7844,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_lu_solve_batched_broadcasting(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
|
||||
|
||||
@ -7734,7 +7866,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
# this tests https://github.com/pytorch/pytorch/issues/36921
|
||||
def test_lu_solve_large_matrices(self, device, dtype):
|
||||
def run_test(A_dims, b_dims):
|
||||
@ -7747,7 +7879,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_lu_solve_out_errors_and_warnings(self, device, dtype):
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
@ -7776,7 +7908,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
@precisionOverride({torch.float32: 1e-5, torch.complex64: 1e-5})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_symeig(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
|
||||
@ -7827,7 +7959,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_symeig_out_errors_and_warnings(self, device, dtype):
|
||||
from torch.testing._internal.common_utils import random_hermitian_matrix
|
||||
|
||||
@ -7965,7 +8097,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_geqrf(self, device, dtype):
|
||||
|
||||
def run_test(shape):
|
||||
|
@ -571,6 +571,16 @@ def generate_tensor_like_override_tests(cls):
|
||||
def instance_gen():
|
||||
return TensorLike()
|
||||
|
||||
# FIXME The following code does not support kwonly args without defaults.
|
||||
# The fix is easy, as one just needs to save these args when generating the variable
|
||||
# annotated_args. The problem is that, if one does so, one finds a number
|
||||
# of functions that have problematic signatures in native_functions.yaml.
|
||||
# Fixing these would be BC breaking, so hence this terrible hack
|
||||
# https://github.com/pytorch/pytorch/issues/67008
|
||||
kwargs = {}
|
||||
if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
|
||||
kwargs = {"upper": True}
|
||||
|
||||
func_args = []
|
||||
is_method = is_tensor_method_or_property(func)
|
||||
if func in annotated_args:
|
||||
@ -633,7 +643,7 @@ def generate_tensor_like_override_tests(cls):
|
||||
func_args += [instance_gen(), instance_gen()]
|
||||
|
||||
def test(self):
|
||||
ret = func(*func_args)
|
||||
ret = func(*func_args, **kwargs)
|
||||
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
|
||||
# This is currently the best check but doesn't work for, for example,
|
||||
# Tensor.__add__ because it redirects to Tensor.add.
|
||||
|
@ -1461,6 +1461,10 @@
|
||||
solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular)
|
||||
cloned_coefficient: A_t
|
||||
|
||||
- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
|
||||
self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
|
||||
result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
|
||||
|
||||
- name: tril(Tensor self, int diagonal=0) -> Tensor
|
||||
self: grad.tril(diagonal)
|
||||
result: auto_linear
|
||||
|
@ -107,7 +107,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
||||
'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical',
|
||||
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
|
||||
'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
|
||||
'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
|
||||
'linalg_solve_triangular', 'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
|
||||
}
|
||||
|
||||
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
|
||||
|
@ -9924,11 +9924,11 @@ add_docstr(torch.triangular_solve,
|
||||
r"""
|
||||
triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor)
|
||||
|
||||
Solves a system of equations with a triangular coefficient matrix :math:`A`
|
||||
Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A`
|
||||
and multiple right-hand sides :math:`b`.
|
||||
|
||||
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
|
||||
with the default keyword arguments.
|
||||
In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular
|
||||
(or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal.
|
||||
|
||||
`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are
|
||||
batches of 2D matrices. If the inputs are batches, then returns
|
||||
@ -9945,10 +9945,9 @@ Args:
|
||||
:math:`*` is zero of more batch dimensions
|
||||
A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
|
||||
where :math:`*` is zero or more batch dimensions
|
||||
upper (bool, optional): whether to solve the upper-triangular system
|
||||
of equations (default) or the lower-triangular system of equations. Default: ``True``.
|
||||
transpose (bool, optional): whether :math:`A` should be transposed before
|
||||
being sent into the solver. Default: ``False``.
|
||||
upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``.
|
||||
transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``,
|
||||
and `op(A) = A` if it is ``False``. Default: ``False``.
|
||||
unitriangular (bool, optional): whether :math:`A` is unit triangular.
|
||||
If True, the diagonal elements of :math:`A` are assumed to be
|
||||
1 and not referenced from :math:`A`. Default: ``False``.
|
||||
|
@ -164,6 +164,14 @@ inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& othe
|
||||
return torch::linalg_solve_out(result, input, other);
|
||||
}
|
||||
|
||||
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
|
||||
return torch::linalg_solve_triangular(input, other, upper, left, unitriangular);
|
||||
}
|
||||
|
||||
inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
|
||||
return torch::linalg_solve_triangular_out(result, input, other, upper, left, unitriangular);
|
||||
}
|
||||
|
||||
inline std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& input, bool full_matrices) {
|
||||
return torch::linalg_svd(input, full_matrices);
|
||||
}
|
||||
@ -437,6 +445,18 @@ inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& othe
|
||||
return detail::solve_out(result, input, other);
|
||||
}
|
||||
|
||||
/// Computes a solution of a linear system AX = B for input = A and other = B whenever A is square
|
||||
/// upper or lower triangular and does not have zeros in the diagonal
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular
|
||||
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
|
||||
return detail::solve_triangular(input, other, upper, left, unitriangular);
|
||||
}
|
||||
|
||||
inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
|
||||
return detail::solve_triangular_out(result, input, other, upper, left, unitriangular);
|
||||
}
|
||||
|
||||
/// Computes the singular values and singular vectors
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svd
|
||||
|
@ -2660,6 +2660,7 @@ Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const
|
||||
// This function is used for both torch.symeig and torch.linalg.eigh.
|
||||
// eigh (and torch.symeig) operates only on symmetric (resp. Hermitian) inputs.
|
||||
|
||||
// [Note: eigh backward]
|
||||
// General considerations of the differential and adjoint
|
||||
// Let U(n) = {U \in C^{n x n} | U^H U = I} by the unitary group and
|
||||
// Her(n) = {A \in C^{n x n} | A^H = A} be the Hermitian matrices
|
||||
@ -3248,6 +3249,78 @@ Tensor triangular_solve_jvp(
|
||||
);
|
||||
}
|
||||
|
||||
Tensor linalg_solve_triangular_forward_AD(
|
||||
const Tensor& A_t,
|
||||
const Tensor& B_t,
|
||||
const Tensor& A,
|
||||
const Tensor& X,
|
||||
const bool upper,
|
||||
const bool left,
|
||||
const bool unitriangular) {
|
||||
// The forward AD formula (for left = true) is A^{-1}(B_t - A_tX)
|
||||
// For the derivation see:
|
||||
// [Note: Forward / Backward AD solve_triangular]
|
||||
const Tensor proj_A_t = upper ? A_t.triu(static_cast<int>(unitriangular))
|
||||
: A_t.tril(- static_cast<int>(unitriangular));
|
||||
const Tensor X_t = B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t));
|
||||
return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& A,
|
||||
const Tensor& X,
|
||||
const bool upper,
|
||||
const bool left,
|
||||
const bool unitriangular,
|
||||
std::array<bool, 2> output_mask) {
|
||||
const bool A_requires_grad = output_mask[0];
|
||||
const bool B_requires_grad = output_mask[1];
|
||||
// [Note: Forward / Backward AD solve_triangular]
|
||||
// Assume left=true for simplicity.
|
||||
// Remark: A solver computes A^{-1}B
|
||||
//
|
||||
// Forward AD:
|
||||
// If f(A) = A^{-1}, differentiating the equation A^{-1}A = I_n gives
|
||||
// (df)_A(E) = -A^{-1}EA^{-1}
|
||||
// As such, if g(A,B) = A^{-1}B,
|
||||
// (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
|
||||
// = A^{-1}(E_B - E_AX)
|
||||
|
||||
// Backward AD:
|
||||
// Denoting the gradients by G_A, G_B, we solve above to give
|
||||
// G_B = A^{-H}G_X
|
||||
// G_A = -A^{-H}G_XX^H = -G_B X^H
|
||||
//
|
||||
// Note that you don't need to store B for forward nor backward
|
||||
//
|
||||
// These formulas work for a general solver of linear equations.
|
||||
// Let's prove now that when A is triangular, G_A is the projection onto the triangular matrices
|
||||
// of the formula above, i.e. simply taking triu (resp. tril) in the formula above.
|
||||
// This is because, since the triangular matrices form a vector space, the tangent space at any
|
||||
// point is itself the space of triangular matrices. The result follows from a reasoning as that
|
||||
// at the end of [Note: eigh backward]
|
||||
// Something similar happens for `unitriangular`, only that int his case the tangent space is
|
||||
// the set of lower-triangular matrices with zeros on the diagonal.
|
||||
|
||||
if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) {
|
||||
return std::make_tuple(Tensor{}, Tensor{});
|
||||
}
|
||||
// We always need to comput G_B
|
||||
const Tensor A_H = A.mH();
|
||||
const Tensor G_B = at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular);
|
||||
|
||||
if (A_requires_grad) {
|
||||
const Tensor X_H = X.mH();
|
||||
Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
|
||||
G_A = upper ? G_A.triu(static_cast<int>(unitriangular))
|
||||
: G_A.tril(- static_cast<int>(unitriangular));
|
||||
return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
|
||||
} else {
|
||||
return std::make_tuple(Tensor{}, G_B);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> cholesky_solve_backward(
|
||||
const Tensor& grad_x, const Tensor& self,
|
||||
const Tensor& input2, const Tensor& result, const bool upper) {
|
||||
|
@ -194,6 +194,22 @@ Tensor triangular_solve_jvp(
|
||||
const bool transpose,
|
||||
const bool unitriangular
|
||||
);
|
||||
Tensor linalg_solve_triangular_forward_AD(
|
||||
const Tensor& A_t,
|
||||
const Tensor& B_t,
|
||||
const Tensor& A,
|
||||
const Tensor& X,
|
||||
const bool upper,
|
||||
const bool left,
|
||||
const bool unitriangular);
|
||||
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& A,
|
||||
const Tensor& X,
|
||||
const bool upper,
|
||||
const bool left,
|
||||
const bool unitriangular,
|
||||
std::array<bool, 2> output_mask);
|
||||
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
|
||||
IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
|
||||
IntArrayRef sumdim, std::array<bool, 3> grad_mask);
|
||||
|
@ -1859,7 +1859,7 @@ Computes the solution of a square system of linear equations with a unique solut
|
||||
|
||||
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
|
||||
this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to
|
||||
:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{m \times k}`, which is defined as
|
||||
:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as
|
||||
|
||||
.. math:: AX = B
|
||||
|
||||
@ -1883,10 +1883,19 @@ Letting `*` be zero or more batch dimensions,
|
||||
This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and
|
||||
more numerically stable way than performing the computations separately.
|
||||
|
||||
.. note::
|
||||
It is possible to compute the solution of the system :math:`XA = B` by passing the inputs
|
||||
:attr:`A` and :attr:`B` transposed and transposing the output returned by this function.
|
||||
|
||||
""" + fr"""
|
||||
.. note:: {common_notes["sync_note"]}
|
||||
""" + r"""
|
||||
|
||||
.. seealso::
|
||||
|
||||
:func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear
|
||||
equations with a unique solution.
|
||||
|
||||
Args:
|
||||
A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions.
|
||||
B (Tensor): right-hand side tensor of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)`
|
||||
@ -1933,6 +1942,84 @@ Examples::
|
||||
https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
|
||||
""")
|
||||
|
||||
solve_triangular = _add_docstr(_linalg.linalg_solve_triangular, r"""
|
||||
linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor
|
||||
|
||||
Computes the solution of a triangular system of linear equations with a unique solution.
|
||||
|
||||
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
|
||||
this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system**
|
||||
associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal
|
||||
(that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`,
|
||||
which is defined as
|
||||
|
||||
.. math:: AX = B
|
||||
|
||||
The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular.
|
||||
|
||||
If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that
|
||||
solves the system
|
||||
|
||||
.. math::
|
||||
|
||||
XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.}
|
||||
|
||||
If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A`
|
||||
will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed.
|
||||
|
||||
If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed.
|
||||
|
||||
The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that
|
||||
are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix
|
||||
has very small eigenvalues.
|
||||
|
||||
Supports inputs of float, double, cfloat and cdouble dtypes.
|
||||
Also supports batches of matrices, and if the inputs are batches of matrices then
|
||||
the output has the same batch dimensions.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:func:`torch.linalg.solve` computes the solution of a general square system of linear
|
||||
equations with a unique solution.
|
||||
|
||||
Args:
|
||||
A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`)
|
||||
where `*` is zero or more batch dimensions.
|
||||
B (Tensor): right-hand side tensor of shape `(*, n, k)`.
|
||||
|
||||
Keyword args:
|
||||
upper (bool): whether :attr:`A` is an upper or lower triangular matrix.
|
||||
left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`.
|
||||
unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be
|
||||
all equal to `1`. Default: `False`.
|
||||
out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`.
|
||||
Ignored if `None`. Default: `None`.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> A = torch.randn(3, 3).triu_()
|
||||
>>> b = torch.randn(3, 4)
|
||||
>>> X = torch.linalg.solve_triangular(A, B, upper=True)
|
||||
>>> torch.allclose(A @ X, B)
|
||||
True
|
||||
|
||||
>>> A = torch.randn(2, 3, 3).tril_()
|
||||
>>> B = torch.randn(2, 3, 4)
|
||||
>>> X = torch.linalg.solve_triangular(A, B, upper=False)
|
||||
>>> torch.allclose(A @ X, B)
|
||||
True
|
||||
|
||||
>>> A = torch.randn(2, 4, 4).tril_()
|
||||
>>> B = torch.randn(2, 3, 4)
|
||||
>>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False)
|
||||
>>> torch.allclose(X @ A, B)
|
||||
True
|
||||
|
||||
.. _invertible:
|
||||
https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
|
||||
""")
|
||||
|
||||
|
||||
tensorinv = _add_docstr(_linalg.linalg_tensorinv, r"""
|
||||
linalg.tensorinv(A, ind=2, *, out=None) -> Tensor
|
||||
|
||||
|
@ -984,6 +984,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.trapz: lambda y, x=None, dim=-1: -1,
|
||||
torch.trapezoid: lambda y, x=None, dim=-1: -1,
|
||||
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
|
||||
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
|
||||
torch.tril: lambda input, diagonal=0, out=None: -1,
|
||||
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
|
||||
|
||||
|
@ -4862,6 +4862,44 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False, vect
|
||||
out.append(SampleInput(a, args=(b,)))
|
||||
return out
|
||||
|
||||
def sample_inputs_linalg_solve_triangular(op_info, device, dtype, requires_grad=False, **kwargs):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
bs = (1, 2, 0)
|
||||
ns = (3, 0)
|
||||
ks = (1, 3, 0)
|
||||
|
||||
def gen_inputs():
|
||||
for b, n, k, (left, upper, uni) in product(bs, ns, ks, product((True, False), repeat=3)):
|
||||
with torch.no_grad():
|
||||
if b == 1:
|
||||
A = make_arg((n, n)) if left else make_arg((k, k))
|
||||
B = make_arg((n, k))
|
||||
else:
|
||||
A = make_arg((b, n, n)) if left else make_arg((b, k, k))
|
||||
B = make_arg((b, n, k))
|
||||
if uni:
|
||||
# Not really necessary, but writing it for consistency
|
||||
A.diagonal(0, -2, -1).fill_(1.)
|
||||
else:
|
||||
d = A.diagonal(0, -2, -1)
|
||||
d[d.abs() < 1e-6] = 1.
|
||||
if upper:
|
||||
A.triu_()
|
||||
else:
|
||||
A.tril_()
|
||||
kwargs = {"upper": upper, "left": left, "unitriangular": uni}
|
||||
if requires_grad:
|
||||
for grad_A, grad_B in product((True, False), repeat=2):
|
||||
# Either A or B needs to have a gradient
|
||||
if not grad_A and not grad_B:
|
||||
continue
|
||||
A.requires_grad_(grad_A)
|
||||
B.requires_grad_(grad_B)
|
||||
yield SampleInput(A, args=(B,), kwargs=kwargs)
|
||||
else:
|
||||
yield SampleInput(A, args=(B,), kwargs=kwargs)
|
||||
|
||||
return list(gen_inputs())
|
||||
|
||||
def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
|
||||
"""
|
||||
@ -7598,8 +7636,7 @@ def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs):
|
||||
`idx` is used to specific which `args[idx]` is to be triangularized.
|
||||
"""
|
||||
triangular_arg = args[idx].triu() if upper else args[idx].tril()
|
||||
modified_args = args[:idx] + (triangular_arg,) + args[idx + 1:]
|
||||
return op(*modified_args, upper)
|
||||
return op(*args[:idx], triangular_arg, *args[idx + 1:], upper, **kwargs)
|
||||
|
||||
|
||||
def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs):
|
||||
@ -11174,6 +11211,13 @@ op_db: List[OpInfo] = [
|
||||
check_batched_gradgrad=False,
|
||||
supports_forward_ad=True,
|
||||
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
|
||||
OpInfo('linalg.solve_triangular',
|
||||
aten_name='linalg_solve_triangular',
|
||||
op=torch.linalg.solve_triangular,
|
||||
dtypes=floating_and_complex_types(),
|
||||
sample_inputs_func=sample_inputs_linalg_solve_triangular,
|
||||
# linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
|
||||
supports_forward_ad=True),
|
||||
OpInfo('linalg.matrix_rank',
|
||||
aten_name='linalg_matrix_rank',
|
||||
dtypes=floating_and_complex_types(),
|
||||
|
Reference in New Issue
Block a user