Simplify and optimize linalg.solve

This PR heavily simplifies the code of `linalg.solve`. At the same time,
the new implementation saves quite a few copies of the input data in some
cases (e.g. when A is contiguous).
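
For intuition, here is a minimal Python sketch of the identity the new kernel exploits when A is real and C-contiguous: `A.mT` is a zero-copy, Fortran-contiguous view, so we can LU-factor A^T as-is and undo the transposition with an adjoint solve. The copy saving itself happens inside the C++ implementation; the snippet below only verifies the identity.

    import torch

    torch.manual_seed(0)
    A = torch.randn(5, 5, dtype=torch.float64)  # C-contiguous (row-major)
    B = torch.randn(5, 3, dtype=torch.float64)

    # A.mT is a view with Fortran-contiguous strides: no transposed copy of A is made
    LU, pivots = torch.linalg.lu_factor(A.mT)

    # Solving the adjoint system of A^T recovers A X = B. This only holds for real A
    # (for complex A the adjoint of A^T is conj(A)), which is why the kernel guards
    # this path with !A.is_complex().
    X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
    torch.testing.assert_close(A @ X, B)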

We also implement it in such a way that the derivative goes from
computing two LU decompositions and two LU solves to no LU
decompositions and one LU solve. The backward also avoids a number of
copies it was previously performing unnecessarily (at least the copies
of two matrices).
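
For reference, the reverse-mode formulas the new backward implements (see `linalg_solve_backward` in this diff: gB = A^{-H} gX and gA = -gB X^H, evaluated with a single `linalg_lu_solve` on the cached LU factors when no higher-order gradients are needed) can be checked against autograd with a small sketch:

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
    B = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)

    X = torch.linalg.solve(A, B)
    gX = torch.randn_like(X)
    gA, gB = torch.autograd.grad(X, (A, B), gX)

    # Reverse-mode formulas: gB = A^{-H} gX, gA = -gB X^H
    gB_ref = torch.linalg.solve(A.detach().mH, gX)
    gA_ref = -gB_ref @ X.detach().mH
    torch.testing.assert_close(gB, gB_ref)
    torch.testing.assert_close(gA, gA_ref)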

On top of this, we add a `left` keyword-only argument that lets the user
solve `XA = B` concisely, as shown in the sketch below.
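
A quick usage sketch of the new keyword (assuming a build that already contains this change):

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64)
    B = torch.randn(2, 4, dtype=torch.float64)

    # Solve X A = B directly instead of transposing by hand
    X = torch.linalg.solve(A, B, left=False)
    torch.testing.assert_close(X @ A, B)

    # Equivalent to solving the transposed system A^T X^T = B^T
    X_ref = torch.linalg.solve(A.mT, B.mT).mT
    torch.testing.assert_close(X, X_ref)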

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74046

Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry
Author: lezcano
Date: 2022-06-10 19:24:28 +00:00
Committed by: PyTorch MergeBot
Commit: 54949a5abc (parent 65a37923f9)
20 changed files with 281 additions and 352 deletions

View File

@ -567,7 +567,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.tol_tensor", Tensor(const Tensor &, const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_tensor", Tensor(const Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_float", Tensor(const Tensor &, c10::optional<double>, c10::optional<double>, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_cholesky), "linalg_cholesky", Tensor(const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_svdvals), "linalg_svdvals", Tensor(const Tensor &, c10::optional<c10::string_view>), fp32)
KERNEL_CPU(ADD_NS(linalg_eigvals), "linalg_eigvals", Tensor(const Tensor &), fp32)

View File

@ -411,6 +411,45 @@ TORCH_META_FUNC(triangular_solve)(const Tensor& self, const Tensor& A, bool uppe
}
}
TORCH_META_FUNC(_linalg_solve)(const Tensor& A,
const Tensor& B,
bool left) {
// dtype
at::native::checkFloatingOrComplex(A, "linalg.solve");
TORCH_CHECK(A.scalar_type() == B.scalar_type(),
"linalg.solve: Expected A and B to have the same dtype, but found A of type ",
A.scalar_type(), " and B of type ", B.scalar_type(), " instead");
// NumPy compat: Two types of 'B' tensors are supported:
// - 1D tensor or batch of 1D tensors (vector case)
// - 2D tensor or batch of 2D tensors (matrix case)
const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B);
auto B_ = vector_case ? B.unsqueeze(-1) : B;
// matrix shapes
at::native::checkInputsSolver(A, B_, /*left=*/left, "linalg.solve");
// Check that B can be broadcasted to the shape of A
auto B_broad_shape = std::get<0>(at::native::_linalg_broadcast_batch_dims(B_, A));
// We disallow the broadcasting of B as a vector when left=False as, in that case, A.shape = (*, 1, 1)
TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
: B_broad_shape;
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*column_major=*/left);
set_output_strided(0, result_shape, result_strides, B.options(), {});
auto shape = A.sizes();
auto ndim = shape.size();
// LU
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true);
set_output_strided(1, shape, LU_strides, A.options(), {});
// Pivots
set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
}
TORCH_META_FUNC(linalg_lu_factor_ex)(const Tensor& A, bool pivot, bool check_errors) {
TORCH_CHECK(A.dim() >= 2, "torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
@ -1368,145 +1407,7 @@ bool _requires_fw_or_bw_grad(const Tensor& input) {
|| input._fw_grad(/*level */ 0).defined());
}
// Solves a system of linear equations matmul(input, x) = other in-place
// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor& input, const Tensor& other) {
checkSameDevice("linalg.solve", result, input);
checkSameDevice("linalg.solve", other, input, "other");
checkLinalgCompatibleDtype("linalg.solve", result, input);
TORCH_CHECK(input.scalar_type() == other.scalar_type(),
"input dtype ", input.scalar_type(), " does not match other dtype ", other.scalar_type());
squareCheckInputs(input, "linalg.solve");
TORCH_CHECK(other.dim() >= 1,
"other should have at least 1 dimension, but has ", other.dim(), " dimensions instead");
// Two types of 'other' tensors are supported:
// - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
// original torch.solve supported only the matrix case, while NumPy works for both cases
// for the batched input we need to be able to distinguish them
bool vector_case = linalg_solve_is_vector_rhs(input, other);
bool is_batched_column_major = false;
if (vector_case) {
is_batched_column_major = result.is_contiguous();
} else if (!vector_case && result.dim() >= 2) {
is_batched_column_major = result.mT().is_contiguous();
}
// if 'other' is a batch of 2D tensors, then 'input' can be non-batched and will be broadcasted
auto expected_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
if (!vector_case && other.dim() > 2) {
expected_shape = other.sizes();
}
bool result_equal_expected_shape = result.sizes().equals(expected_shape);
bool result_input_same_type = (result.scalar_type() == input.scalar_type());
// if result is not empty and not in batched column major format
bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
copy_needed |= !result_input_same_type; // or result does not have the same dtype as input
copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
// we have to allocate a temporary tensor
if (copy_needed) {
Tensor result_tmp = at::empty({0}, input.options());
result_tmp = linalg_solve_out_info(result_tmp, infos, input, other);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
return result;
}
// else use result's storage directly
// we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
Tensor other_ = vector_case ? other.unsqueeze(-1) : other;
// _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
// it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input'
Tensor other_broadcasted;
std::tie(other_broadcasted, std::ignore) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
auto squeezed_result_shape = squeezed_other_broadcasted.sizes();
// if result has no elements we can modify it
if (result.numel() == 0) {
if (vector_case) {
result.resize_(squeezed_result_shape);
} else {
at::native::resize_as_(result, other_broadcasted.mT(), MemoryFormat::Contiguous);
result.transpose_(-2, -1);
}
}
auto expected_result_shape = vector_case ? squeezed_result_shape : other_broadcasted.sizes();
TORCH_INTERNAL_ASSERT(result.sizes().equals(expected_result_shape));
TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
TORCH_INTERNAL_ASSERT(result.device() == input.device());
// result tensor must be in batched column major order (Fortran contiguous) for 2D inputs
// or C contiguous for 1D input
if (vector_case) {
TORCH_INTERNAL_ASSERT(result.is_contiguous());
} else {
TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
}
// for 1-dimensional 'other', we need to unsqueeze the result before passing to "apply_solve"
if (vector_case) {
result = result.unsqueeze_(-1);
}
// lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
result.copy_(other_broadcasted);
TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
TORCH_INTERNAL_ASSERT(infos.device() == input.device());
infos.resize_({std::max<int64_t>(1, batchCount(input))});
// if input is empty infos might not get filled; make sure infos doesn't contain garbage then
if (input.numel() == 0) {
infos.fill_(0);
}
// compute the LU factorization of 'input_working_copy'
auto input_working_copy = cloneBatchedColumnMajor(input);
auto pivots_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
// solve the linear system using the LU factorization
lu_solve_stub(input.device().type(), input_working_copy, pivots, result, TransposeType::NoTranspose);
// for 1-dimensional 'other', we need to squeeze the result after "apply_solve"
if (vector_case) {
result = result.squeeze_(-1);
}
return result;
}
// Solves a system of linear equations matmul(input, x) = other in-place
Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& result) {
auto infos = at::empty({0}, input.options().dtype(kInt));
result = linalg_solve_out_info(result, infos, input, other);
// Now check LAPACK/MAGMA error codes
// _linalg_check_errors calls 'infos = infos.to(kCPU)'
at::_linalg_check_errors(infos, "linalg.solve", input.dim() == 2);
return result;
}
// Solves a system of linear equations matmul(input, x) = other
Tensor linalg_solve(const Tensor& input, const Tensor& other) {
Tensor result = at::empty({0}, input.options());
result = at::linalg_solve_out(result, input, other);
return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/*
Computes the inverse of n-by-n matrix 'self'
This is an in-place routine, it overwrites the content of 'self'.
@ -2058,6 +1959,53 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Auxiliary function that returns the LU decomposition to use it in the backward
TORCH_IMPL_FUNC(_linalg_solve_out)(const Tensor& A,
const Tensor& B,
bool left,
const Tensor& result,
const Tensor& LU,
const Tensor& pivots) {
// Possible optimization: Compute the LU factorization of A^T if A is contiguous
// Then we solve A^T X = B with adjoint=True
// This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
const bool use_A_T = A.is_contiguous() && !A.is_complex();
auto info = at::empty({0}, A.options().dtype(kInt));
at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU),
const_cast<Tensor&>(pivots),
const_cast<Tensor&>(info),
use_A_T ? A.mT() : A,
/*pivot=*/true,
/*check_errors=*/false);
at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
// [numpy-compat] Handle vectors on the rhs
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B);
auto result_ = vector_case ? result.unsqueeze(-1) : result;
auto B_ = vector_case ? B.unsqueeze(-1) : B;
at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T);
}
Tensor& linalg_solve_out(const Tensor& A,
const Tensor& B,
bool left,
Tensor& result) {
auto LU = at::empty({0}, A.options());
auto pivots = at::empty({0}, A.options().dtype(kInt));
at::_linalg_solve_out(result, LU, pivots, A, B, left);
return result;
}
// We implement linalg_solve as a composite function of _linalg_solve
Tensor linalg_solve(const Tensor& A,
const Tensor& B,
bool left) {
return std::get<0>(at::_linalg_solve(A, B, left));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DEFINE_DISPATCH(lu_factor_stub);

View File

@ -329,7 +329,7 @@ static inline void singleCheckErrors(int64_t info, const c10::string_view name,
} else if (name.find("solve") != name.npos) {
// solve, linalg_solve, cholesky_solve, etc.
TORCH_CHECK_LINALG(false, name, batch_string,
": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular.");
": The solver failed because the input matrix is singular.");
} else if (name.find("cholesky") != name.npos) {
TORCH_CHECK_LINALG(false, name, batch_string,
": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");

View File

@ -11859,16 +11859,19 @@
python_module: linalg
variants: function
- func: linalg_solve(Tensor input, Tensor other) -> Tensor
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_solve
- func: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots)
structured_delegate: _linalg_solve.result
- func: linalg_solve.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
- func: _linalg_solve.result(Tensor A, Tensor B, *, bool left=True, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
structured: True
dispatch:
CPU, CUDA: linalg_solve_out
CPU, CUDA: _linalg_solve_out
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
python_module: linalg
- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor
python_module: linalg

View File

@ -2915,13 +2915,6 @@
"Generator"
],
"torch.return_types": [
"_det_lu_based_helper",
"_fake_quantize_per_tensor_affine_cachemask_tensor_qparams",
"_fused_moving_avg_obs_fq_helper",
"_linalg_svd",
"_linalg_svd_out",
"_lu_with_info",
"_unpack_dual",
"attr",
"pytree_register_structseq"
],

View File

@ -52,6 +52,8 @@ ALLOW_LIST = [
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
("aten::randperm", datetime.date(9999, 1, 1)),
("aten::linalg_solve", datetime.date(2022, 8, 31)),
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
("aten::gelu", datetime.date(2022, 3, 1)),
("aten::gelu_backward", datetime.date(2022, 3, 1)),
("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)),
@ -112,6 +114,8 @@ ALLOW_LIST = [
("q::_FloatToBfloat16Quantized", datetime.date(2021, 12, 21)),
("q::_Bfloat16QuantizedToFloat", datetime.date(2021, 12, 21)),
("aten::_inverse_helper", datetime.date(2021, 12, 31)),
("aten::linalg_solve", datetime.date(2022, 8, 31)),
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
("aten::softplus_backward", datetime.date(2022, 1, 31)),
("aten::softplus_backward.grad_input", datetime.date(2022, 1, 31)),
("aten::quantile", datetime.date(2022, 9, 30)),

View File

@ -3083,96 +3083,12 @@ class TestLinalg(TestCase):
expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
self.assertEqual(x, expected)
# Check out= variant
out = torch.empty_like(x)
ans = torch.linalg.solve(A, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(x, out)
# Check out= variant with complex128 out tensor
out = torch.empty_like(x).to(torch.complex128)
ans = torch.linalg.solve(A, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(x.to(torch.complex128), out)
# Check empty out
out = torch.empty(0, dtype=dtype, device=device)
ans = torch.linalg.solve(A, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(x, out)
batches = [(), (0, ), (3, ), (2, 3)]
ns = [0, 5, 32]
nrhs = [(), (1, ), (5, )]
for n, batch, rhs in itertools.product(ns, batches, nrhs):
run_test(n, batch, rhs)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_solve_errors_and_warnings(self, device, dtype):
# solve expects batches of square matrices as input
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
b = torch.randn(2, 3, 4, 1, dtype=dtype, device=device)
torch.linalg.solve(a, b)
# solve expects compatible shapes for A x = b
with self.assertRaisesRegex(RuntimeError, "Incompatible matrix sizes"):
a = torch.randn(2, 3, 3, 3, dtype=dtype, device=device)
b = torch.randn(2, 3, 2, 1, dtype=dtype, device=device)
torch.linalg.solve(a, b)
# if input is not solvable, RuntimeError is raised mentioning the first non-solvable batch
def run_test_singular_input(batch_dim, n):
a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
a[n, -1, -1] = 0
b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device)
with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
torch.linalg.solve(a, b)
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
run_test_singular_input(*params)
# if out tensor with wrong shape is passed a warning is given
# matrix 'b' case
with warnings.catch_warnings(record=True) as w:
A = torch.eye(2, dtype=dtype, device=device).reshape((1, 2, 2)).repeat(2, 1, 1)
b = torch.randn(2, 2, 2, dtype=dtype, device=device)
out = torch.zeros(1, dtype=dtype, device=device)
# Trigger warning
torch.linalg.solve(A, b, out=out)
# Check warning occurs
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
# if out tensor with wrong shape is passed a warning is given
# vector 'b' case
with warnings.catch_warnings(record=True) as w:
A = torch.eye(2, dtype=dtype, device=device)
b = torch.randn(2, dtype=dtype, device=device)
out = torch.zeros(1, dtype=dtype, device=device)
# Trigger warning
torch.linalg.solve(A, b, out=out)
# Check warning occurs
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
# dtypes should be safely castable
a = torch.eye(2, dtype=dtype, device=device)
b = torch.randn(2, 1, dtype=dtype, device=device)
out = torch.empty(0, dtype=torch.int, device=device)
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
torch.linalg.solve(a, b, out=out)
# device should match
if torch.cuda.is_available():
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
out = torch.empty(0, dtype=dtype, device=wrong_device)
clone_a = torch.empty_like(a)
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
torch.linalg.solve(a, b, out=out)
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
@ -5103,13 +5019,14 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_linalg_lu_factor_and_lu_and_lu_unpack(self, device, dtype):
def test_linalg_lu_family(self, device, dtype):
# Tests torch.lu
# torch.linalg.lu_factor
# torch.linalg.lu_factor_ex
# torch.lu_unpack
# torch.linalg.lu_solve
from torch.testing._internal.common_utils import random_matrix
# torch.linalg.solve
make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
make_arg = partial(make_tensor, device=device, dtype=dtype)
def run_test(A, pivot, singular, fn):
@ -5141,20 +5058,37 @@ class TestLinalg(TestCase):
self.assertEqual(L, PLU.L)
self.assertEqual(U, PLU.U)
rhs = 3
if not singular and A.size(-2) == A.size(-1):
for left in (True, False):
shape_B = list(A.shape)
dim = -1 if left else -2
shape_B[dim] = rhs
nrhs = ((), (1,), (3,))
for left, rhs in product((True, False), nrhs):
# Vector case when left = False is not allowed
if not left and rhs == ():
continue
if left:
shape_B = A.shape[:-1] + rhs
else:
shape_B = A.shape[:-2] + rhs + A.shape[-1:]
B = make_arg(shape_B)
for adjoint in (True, False):
X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
A_adj = A.mH if adjoint else A
if left:
self.assertEqual(B, A_adj @ X)
else:
self.assertEqual(B, X @ A_adj)
# Test linalg.lu_solve. It does not support vectors as rhs
# See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913
if rhs != ():
for adjoint in (True, False):
X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
A_adj = A.mH if adjoint else A
if left:
self.assertEqual(B, A_adj @ X)
else:
self.assertEqual(B, X @ A_adj)
# Test linalg.solve
X = torch.linalg.solve(A, B, left=left)
X_ = X.unsqueeze(-1) if rhs == () else X
B_ = B.unsqueeze(-1) if rhs == () else B
if left:
self.assertEqual(B_, A @ X_)
else:
self.assertEqual(B_, X_ @ A)
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
@ -5163,8 +5097,8 @@ class TestLinalg(TestCase):
pivots = (True, False) if self.device_type == "cuda" else (True,)
fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
m, n = ms
A = random_matrix(m, n, *batch, singular=singular, dtype=dtype, device=device)
shape = batch + ms
A = make_arg(shape) if singular else make_arg_full(*shape)
# Just do one of them on singular matrices
if A.numel() == 0 and not singular:
continue
@ -6159,38 +6093,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
for b1, b2, ref, out_tensor in generate_tensor():
self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
# TODO: update to compare against NumPy
@onlyCUDA
def test_solve_methods_arg_device(self, device):
for b_device, A_device in itertools.product(['cpu', device], repeat=2):
if b_device == A_device:
continue
b = torch.randn(3, 1, device=b_device)
A = torch.randn(3, 3, device=A_device)
# cholesky_solve goes through generic backend dispatch and hit kernel specific device check first
# triangular_solve goes through specific backend dispatch (CPU/CUDA) and hit auto-generated device check first
generic_backend_dispatch_err_str = "Expected b and A to be on the same device"
specific_backend_dispatch_err_str = "Expected all tensors to be on the same device"
with self.assertRaisesRegex(RuntimeError, generic_backend_dispatch_err_str):
torch.cholesky_solve(b, A)
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
torch.triangular_solve(b, A)
# b and A have to be modified to match accepted inputs sizes for lu_solve
b = b.unsqueeze(0)
A = A.unsqueeze(0)
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int())
# This checks if a suitable error message is thrown
# when LU output and pivots are not on the same device
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int())
@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack

View File

@ -466,9 +466,7 @@ meta_function_expected_failures = {
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product
torch.linalg.lstsq: {f32, f64}, # aten::linalg_lstsq.out
torch.linalg.slogdet: {f32, f64}, # aten::linalg_slogdet
torch.linalg.solve: {f32, f64}, # aten::linalg_solve, aten::linalg_solve.out
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular
torch.linalg.tensorsolve: {f32, f64}, # aten::linalg_solve
torch.logdet: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero
}
@ -719,13 +717,9 @@ meta_dispatch_expected_failures = {
aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out
aten.linalg_lstsq.default: {f32, f64}, # aten::linalg_lstsq.out
aten.linalg_slogdet.default: {f32, f64}, # aten::linalg_slogdet
aten.linalg_solve.default: {f32, f64}, # aten::linalg_solve
aten.linalg_solve.out: {f32, f64}, # aten::linalg_solve.out
aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular
aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
aten.logdet.default: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero
aten.lu_solve.default: {f32, f64}, # aten::lu_solve
aten.lu_solve.out: {f32, f64}, # aten::lu_solve.out
aten.ormqr.default: {f32, f64}, # aten::ormqr
aten.ormqr.out: {f32, f64}, # aten::ormqr.out
aten.symeig.default: {f32, f64}, # aten::_symeig_helper

View File

@ -21,7 +21,7 @@ all_operators_with_namedtuple_return = {
'frexp', 'lu_unpack', 'histogram', 'histogramdd',
'_fake_quantize_per_tensor_affine_cachemask_tensor_qparams',
'_fused_moving_avg_obs_fq_helper', 'linalg_lu_factor', 'linalg_lu_factor_ex', 'linalg_lu',
'_det_lu_based_helper', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor',
'_det_lu_based_helper', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor', '_linalg_solve'
}
@ -85,6 +85,7 @@ class TestNamedTupleAPI(TestCase):
op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True),
op(operators=['linalg_cholesky_ex'], input=(), names=('L', 'info'), hasout=True),
op(operators=['linalg_inv_ex'], input=(), names=('inverse', 'info'), hasout=True),
op(operators=['_linalg_solve'], input=(a,), names=('result', 'LU', 'pivots'), hasout=True),
op(operators=['linalg_lu_factor'], input=(), names=('LU', 'pivots'), hasout=True),
op(operators=['linalg_lu_factor_ex'], input=(), names=('LU', 'pivots', 'info'), hasout=True),
op(operators=['linalg_ldl_factor'], input=(), names=('LD', 'pivots'), hasout=True),

View File

@ -1437,10 +1437,10 @@
self: slogdet_backward(grad, self, sign, logabsdet)
output_differentiability: [false, true]
- name: linalg_solve(Tensor input, Tensor other) -> Tensor
input: solve_backward_A(grad, other, input, result)
other: solve_backward_self(grad, other, input)
result: solve_jvp(result, input_p, input_t, other_t)
- name: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots)
A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
output_differentiability: [True, False, False] # LU is an auxiliary tensor not exposed to the user
- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)

View File

@ -337,6 +337,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"pixel_shuffle",
"pixel_unshuffle",
"linalg_lu_solve",
"_linalg_solve",
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {

View File

@ -196,12 +196,12 @@ inline std::tuple<Tensor&, Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& i
return torch::linalg_qr_out(Q, R, input, mode);
}
inline Tensor solve(const Tensor& input, const Tensor& other) {
return torch::linalg_solve(input, other);
inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
return torch::linalg_solve(input, other, left);
}
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) {
return torch::linalg_solve_out(result, input, other);
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) {
return torch::linalg_solve_out(result, input, other, left);
}
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
@ -566,12 +566,12 @@ inline Tensor& ldl_solve_out(
/// Computes a tensor `x` such that `matmul(input, x) = other`.
///
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve
inline Tensor solve(const Tensor& input, const Tensor& other) {
return detail::solve(input, other);
inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
return detail::solve(input, other, left);
}
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) {
return detail::solve_out(result, input, other);
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) {
return detail::solve_out(result, input, other, left);
}
/// Computes a solution of a linear system AX = B for input = A and other = B whenever A is square

View File

@ -563,38 +563,6 @@ static Tensor generic_solve_jvp(
return solve(A, dB, dA_contrib);
}
Tensor solve_jvp(
const Tensor& X,
const Tensor& A,
const Tensor& dA,
const Tensor& dB
) {
return generic_solve_jvp(
[](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
return at::linalg_solve(A, dB - dA_contrib);
},
X, A, dA, dB
);
}
Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
return at::linalg_solve(A.mH(), grad);
}
Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
at::NoTF32Guard disable_tf32;
Tensor grad_self = solve_backward_self(grad, self, A);
if (self.ndimension() == 2 && A.ndimension() == 2) {
return -at::mm(grad_self, solution.mH());
}
// if self was unsqueezed from (..., M) to (..., M, 1)
bool vector_case = at::native::linalg_solve_is_vector_rhs(A, self);
if (vector_case) {
return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).mH());
}
return -at::matmul(grad_self, solution.mH());
}
Tensor cumsum_backward(const Tensor & grad, int64_t dim) {
// Trivial case
if (grad.numel() <= 1 || grad.size(dim) == 1) {
@ -4911,6 +4879,89 @@ Tensor linalg_lu_solve_jvp(
}
}
Tensor linalg_solve_jvp(
const Tensor& dA,
const Tensor& dB,
const Tensor& X,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool use_A_T) {
at::NoTF32Guard disable_tf32;
// For left=True (left=False is analogous)
// dX = A^{-1}(dB - dAX)
// [NumPy compat] Case where the rhs is a vector.
// We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)`
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; };
const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; };
// This case is disallowed in the primal operation as A.shape = (*, 1, 1)
TORCH_INTERNAL_ASSERT(left || !vector_case);
auto X_ = vector_to_matrix(X);
auto dB_ = vector_to_matrix(dB);
auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
auto dX_ = at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/use_A_T);
return matrix_to_vector(dX_);
}
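
A minimal sketch (not part of this file) that checks the JVP formula above, dX = A^{-1}(dB - dA X), against forward-mode AD, assuming a build that includes this change:

    import torch
    from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual

    torch.manual_seed(0)
    A = torch.randn(3, 3, dtype=torch.float64)
    B = torch.randn(3, 2, dtype=torch.float64)
    dA = torch.randn_like(A)
    dB = torch.randn_like(B)

    with dual_level():
        X_dual = torch.linalg.solve(make_dual(A, dA), make_dual(B, dB))
        X, dX = unpack_dual(X_dual)
        # JVP formula: dX = A^{-1} (dB - dA X)
        dX_ref = torch.linalg.solve(A, dB - dA @ X)
        torch.testing.assert_close(dX, dX_ref)
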
std::tuple<Tensor, Tensor> linalg_solve_backward(
const Tensor& gX,
const Tensor& X,
const Tensor& A,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool B_requires_grad) {
// for X = A^{-1}B
// gB = A^{-H}gX
// gA = -gB X^H
at::NoTF32Guard disable_tf32;
const auto A_requires_grad = A.requires_grad();
if (!gX.defined() || (!A_requires_grad && !B_requires_grad)) {
return {};
}
// [NumPy compat] Case where the rhs is a vector.
// We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)`
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; };
const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; };
// If the user is going to compute higher order gradients, then we need to recompute the LU and the pivots
Tensor gB_;
if (at::GradMode::is_enabled()) {
gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
} else {
const auto use_A_T = A.is_contiguous() && !A.is_complex();
gB_ = at::linalg_lu_solve(LU, pivots, vector_to_matrix(gX), left, /*adjoint*/!use_A_T);
}
Tensor gA_;
if (A_requires_grad) {
auto X_ = vector_to_matrix(X);
gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_);
}
return std::make_tuple(A_requires_grad ? matrix_to_vector(gA_) : Tensor{},
B_requires_grad ? matrix_to_vector(gB_) : Tensor{});
}
Tensor solve_jvp(
const Tensor& X,
const Tensor& A,
const Tensor& dA,
const Tensor& dB
) {
return generic_solve_jvp(
[](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
return at::linalg_solve(A, dB - dA_contrib);
},
X, A, dA, dB
);
}
Tensor lu_unpack_backward(
const Tensor& L_grad,
const Tensor& U_grad,

View File

@ -380,6 +380,22 @@ Tensor linalg_lu_solve_jvp(
const Tensor& dB,
const bool left,
const bool adjoint);
std::tuple<Tensor, Tensor> linalg_solve_backward(
const Tensor& gX,
const Tensor& X,
const Tensor& A,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool B_requires_grad);
Tensor linalg_solve_jvp(
const Tensor& dA,
const Tensor& dB,
const Tensor& X,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool use_A_T);
Tensor lu_unpack_backward(
const Tensor& L_grad,
const Tensor& U_grad,

View File

@ -4902,17 +4902,18 @@ REGISTER_OPERATOR_FUNCTOR(
aten_linalg_solve,
[](Node* n) -> SROperator {
if (n->matches(torch::schema(
"aten::linalg_solve(Tensor input, Tensor other) -> Tensor"))) {
"aten::linalg_solve(Tensor input, Tensor other, bool left) -> Tensor"))) {
return [](ProcessedNode* p_node) {
const auto& input = p_node->Input(0).toTensor();
const auto& other = p_node->Input(1).toTensor();
auto left = p_node->Input(2).toBool();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::native::linalg_solve(input, other);
p_node->Output(0) = at::native::linalg_solve(input, other, left);
return;
}
auto& out = p_node->Output(0).toTensor();
fastResizeToZero(out);
at::native::linalg_solve_out(input, other, out);
at::native::linalg_solve_out(input, other, left, out);
};
}
LogAndDumpSchema(n);

View File

@ -2882,7 +2882,7 @@ void nnc_aten_linalg_solve(
const at::Tensor& input = tensors[1];
const at::Tensor& other = tensors[2];
try {
at::linalg_solve_out(r, input, other);
at::linalg_solve_out(r, input, other, true);
} catch (...) {
}
}

View File

@ -2055,7 +2055,7 @@ Example::
solve = _add_docstr(_linalg.linalg_solve, r"""
linalg.solve(A, B, *, out=None) -> Tensor
linalg.solve(A, B, *, left=True, out=None) -> Tensor
Computes the solution of a square system of linear equations with a unique solution.
@ -2065,6 +2065,12 @@ this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the
.. math:: AX = B
If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system
.. math::
XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.}
This system of linear equations has one solution if and only if :math:`A` is `invertible`_.
This function assumes that :math:`A` is invertible.
@ -2104,6 +2110,7 @@ Args:
according to the rules described above
Keyword args:
left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`.
out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
Raises:
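
For illustration of the documented batching and NumPy-compatibility rules, a minimal sketch (not part of the docs change):

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64)

    # A is broadcast against the batch dimensions of B
    B = torch.randn(5, 4, 2, dtype=torch.float64)
    X = torch.linalg.solve(A, B)
    torch.testing.assert_close(A @ X, B)

    # NumPy compatibility: a 1-D B is treated as a vector right-hand side
    b = torch.randn(4, dtype=torch.float64)
    x = torch.linalg.solve(A, b)
    torch.testing.assert_close(A @ x, b)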

View File

@ -964,7 +964,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.smm: lambda input, mat2: -1,
torch.spmm: lambda input, mat2: -1,
torch.softmax: lambda input, dim, dtype=None: -1,
torch.linalg.solve: lambda input, other, out=None: -1,
torch.linalg.solve: lambda A, B, left=True, out=None: -1,
torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,

View File

@ -18,14 +18,17 @@ def pytree_register_structseq(cls):
for name in dir(return_types):
if name.startswith('__'):
continue
globals()[name] = getattr(return_types, name)
__all__.append(name)
attr = getattr(return_types, name)
globals()[name] = attr
if not name.startswith('_'):
__all__.append(name)
# Today everything in torch.return_types is a structseq, aka a "namedtuple"-like
# thing defined by the Python C-API. We're going to need to modify this when that
# is no longer the case.
# NB: I don't know how to check that something is a "structseq" so we do a fuzzy
# check for tuple
attr = globals()[name]
if inspect.isclass(attr) and issubclass(attr, tuple):
pytree_register_structseq(attr)

View File

@ -6898,9 +6898,14 @@ def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kw
op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
)
def out_fn(output):
return output[0]
# Reverses tensor order
for sample in out:
sample.input, sample.args = sample.args[0], (sample.input,)
if op_info.name == "solve":
sample.output_process_fn_grad = out_fn
yield sample
@ -15989,13 +15994,10 @@ op_db: List[OpInfo] = [
op=torch.linalg.solve,
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_solve,
check_batched_gradgrad=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=(
# AssertionError: Scalars are not equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@ -18351,7 +18353,10 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_tensorsolve,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma],
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack,
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
'TestCommon', 'test_noncontiguous_samples',
device_type='cuda')],
),
OpInfo(
"nn.functional.mse_loss",