Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
Simplify and optimize linalg.solve
This PR heavily simplifies the code of `linalg.solve`. At the same time, this implementation saves quite a few copies of the input data in some cases (e.g. when A is contiguous). We also implement it in such a way that the derivative goes from computing two LU decompositions and two LU solves to no LU decompositions and one LU solve. It also avoids a number of copies that the derivative was unnecessarily performing (at least the copy of two matrices). On top of this, we add a `left` kw-only arg that allows the user to solve `XA = B` concisely. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74046 Approved by: https://github.com/nikitaved, https://github.com/IvanYashchuk, https://github.com/mruberry
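A quick usage sketch of the new `left` keyword (illustrative only; the tensor sizes are made up, and it assumes a build that includes this PR):

import torch

A = torch.randn(3, 3, dtype=torch.float64)
B = torch.randn(2, 3, dtype=torch.float64)   # with left=False, B has shape (n, k) and A has shape (k, k)

# Solve X A = B instead of the default A X = B
X = torch.linalg.solve(A, B, left=False)

# The result satisfies X @ A == B up to numerical error
torch.testing.assert_close(X @ A, B)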
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 65a37923f9
Commit: 54949a5abc
@ -567,7 +567,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
|
||||
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.tol_tensor", Tensor(const Tensor &, const Tensor &, bool), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_tensor", Tensor(const Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &, bool), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_float", Tensor(const Tensor &, c10::optional<double>, c10::optional<double>, bool), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &, bool), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_cholesky), "linalg_cholesky", Tensor(const Tensor &, bool), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_svdvals), "linalg_svdvals", Tensor(const Tensor &, c10::optional<c10::string_view>), fp32)
|
||||
KERNEL_CPU(ADD_NS(linalg_eigvals), "linalg_eigvals", Tensor(const Tensor &), fp32)
|
||||
|
||||
@ -411,6 +411,45 @@ TORCH_META_FUNC(triangular_solve)(const Tensor& self, const Tensor& A, bool uppe
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(_linalg_solve)(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool left) {
|
||||
// dtype
|
||||
at::native::checkFloatingOrComplex(A, "linalg.solve");
|
||||
TORCH_CHECK(A.scalar_type() == B.scalar_type(),
|
||||
"linalg.solve: Expected A and B to have the same dtype, but found A of type ",
|
||||
A.scalar_type(), " and B of type ", B.scalar_type(), " instead");
|
||||
|
||||
// NumPy compat: Two types of 'B' tensors are supported:
|
||||
// - 1D tensor or batch of 1D tensors (vector case)
|
||||
// - 2D tensor or batch of 2D tensors (matrix case)
|
||||
const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B);
|
||||
auto B_ = vector_case ? B.unsqueeze(-1) : B;
|
||||
|
||||
// matrix shapes
|
||||
at::native::checkInputsSolver(A, B_, /*left=*/left, "linalg.solve");
|
||||
|
||||
// Check that B can be broadcasted to the shape of A
|
||||
auto B_broad_shape = std::get<0>(at::native::_linalg_broadcast_batch_dims(B_, A));
|
||||
// We disallow the broadcasting of B as a vector when left=False as, in that case, A.shape = (*, 1, 1)
|
||||
TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
|
||||
auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
|
||||
: B_broad_shape;
|
||||
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*column_major=*/left);
|
||||
|
||||
set_output_strided(0, result_shape, result_strides, B.options(), {});
|
||||
|
||||
auto shape = A.sizes();
|
||||
auto ndim = shape.size();
|
||||
|
||||
// LU
|
||||
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true);
|
||||
set_output_strided(1, shape, LU_strides, A.options(), {});
|
||||
|
||||
// Pivots
|
||||
set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
|
||||
}
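A rough Python rendering of the shape logic in the meta function above (a sketch only; the helper mirrors the C++ comments and is not a real torch API):

import torch

def solve_result_shape(A, B, left=True):
    # B may be a vector (batch of 1D) or a matrix (batch of 2D); this mirrors
    # the check done by linalg_solve_is_vector_rhs in the hunk above.
    vector_case = B.dim() == 1 or (A.dim() - 1 == B.dim() and B.shape == A.shape[:-1])
    B_ = B.unsqueeze(-1) if vector_case else B
    # Vector broadcasting is rejected when left=False (see the TORCH_CHECK above).
    assert left or not vector_case
    # Broadcast the batch dimensions of B against those of A.
    batch = torch.broadcast_shapes(A.shape[:-2], B_.shape[:-2])
    B_broad_shape = tuple(batch) + tuple(B_.shape[-2:])
    # The trailing dimension added for the vector case is dropped again.
    return B_broad_shape[:-1] if vector_case else B_broad_shape

A = torch.randn(2, 3, 3)
b = torch.randn(3)                             # vector right-hand side, broadcast over A's batch
print(solve_result_shape(A, b))                # (2, 3)
print(tuple(torch.linalg.solve(A, b).shape))   # (2, 3) as well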
|
||||
|
||||
TORCH_META_FUNC(linalg_lu_factor_ex)(const Tensor& A, bool pivot, bool check_errors) {
|
||||
TORCH_CHECK(A.dim() >= 2, "torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
|
||||
|
||||
@ -1368,145 +1407,7 @@ bool _requires_fw_or_bw_grad(const Tensor& input) {
|
||||
|| input._fw_grad(/*level */ 0).defined());
|
||||
}
|
||||
|
||||
// Solves a system of linear equations matmul(input, x) = other in-place
|
||||
// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
|
||||
static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor& input, const Tensor& other) {
|
||||
checkSameDevice("linalg.solve", result, input);
|
||||
checkSameDevice("linalg.solve", other, input, "other");
|
||||
checkLinalgCompatibleDtype("linalg.solve", result, input);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == other.scalar_type(),
|
||||
"input dtype ", input.scalar_type(), " does not match other dtype ", other.scalar_type());
|
||||
|
||||
squareCheckInputs(input, "linalg.solve");
|
||||
TORCH_CHECK(other.dim() >= 1,
|
||||
"other should have at least 1 dimension, but has ", other.dim(), " dimensions instead");
|
||||
|
||||
// Two types of 'other' tensors are supported:
|
||||
// - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
|
||||
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
|
||||
// original torch.solve supported only the matrix case, while NumPy works for both cases
|
||||
// for the batched input we need to be able to distinguish them
|
||||
bool vector_case = linalg_solve_is_vector_rhs(input, other);
|
||||
|
||||
bool is_batched_column_major = false;
|
||||
if (vector_case) {
|
||||
is_batched_column_major = result.is_contiguous();
|
||||
} else if (!vector_case && result.dim() >= 2) {
|
||||
is_batched_column_major = result.mT().is_contiguous();
|
||||
}
|
||||
|
||||
// if 'other' is a batch of 2D tensors, then 'input' can be non-batched and will be broadcasted
|
||||
auto expected_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
|
||||
if (!vector_case && other.dim() > 2) {
|
||||
expected_shape = other.sizes();
|
||||
}
|
||||
|
||||
bool result_equal_expected_shape = result.sizes().equals(expected_shape);
|
||||
bool result_input_same_type = (result.scalar_type() == input.scalar_type());
|
||||
|
||||
// if result is not empty and not in batched column major format
|
||||
bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
|
||||
copy_needed |= !result_input_same_type; // or result does not have the same dtype as input
|
||||
copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
|
||||
// we have to allocate a temporary tensor
|
||||
if (copy_needed) {
|
||||
Tensor result_tmp = at::empty({0}, input.options());
|
||||
result_tmp = linalg_solve_out_info(result_tmp, infos, input, other);
|
||||
at::native::resize_output(result, result_tmp.sizes());
|
||||
result.copy_(result_tmp);
|
||||
return result;
|
||||
}
|
||||
// else use result's storage directly
|
||||
|
||||
// we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
|
||||
Tensor other_ = vector_case ? other.unsqueeze(-1) : other;
|
||||
|
||||
// _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
|
||||
// it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input'
|
||||
Tensor other_broadcasted;
|
||||
std::tie(other_broadcasted, std::ignore) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
|
||||
|
||||
auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
|
||||
auto squeezed_result_shape = squeezed_other_broadcasted.sizes();
|
||||
|
||||
// if result has no elements we can modify it
|
||||
if (result.numel() == 0) {
|
||||
if (vector_case) {
|
||||
result.resize_(squeezed_result_shape);
|
||||
} else {
|
||||
at::native::resize_as_(result, other_broadcasted.mT(), MemoryFormat::Contiguous);
|
||||
result.transpose_(-2, -1);
|
||||
}
|
||||
}
|
||||
|
||||
auto expected_result_shape = vector_case ? squeezed_result_shape : other_broadcasted.sizes();
|
||||
TORCH_INTERNAL_ASSERT(result.sizes().equals(expected_result_shape));
|
||||
TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
|
||||
TORCH_INTERNAL_ASSERT(result.device() == input.device());
|
||||
|
||||
// result tensor must be in batched column major order (Fortran contiguous) for 2D inputs
|
||||
// or C contiguous for 1D input
|
||||
if (vector_case) {
|
||||
TORCH_INTERNAL_ASSERT(result.is_contiguous());
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
|
||||
}
|
||||
|
||||
// for 1-dimensional 'other', we need to unsqueeze the result before passing to "apply_solve"
|
||||
if (vector_case) {
|
||||
result = result.unsqueeze_(-1);
|
||||
}
|
||||
|
||||
// lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
|
||||
result.copy_(other_broadcasted);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
|
||||
TORCH_INTERNAL_ASSERT(infos.device() == input.device());
|
||||
infos.resize_({std::max<int64_t>(1, batchCount(input))});
|
||||
// if input is empty infos might not get filled; make sure infos doesn't contain garbage then
|
||||
if (input.numel() == 0) {
|
||||
infos.fill_(0);
|
||||
}
|
||||
|
||||
// compute the LU factorization of 'input_working_copy'
|
||||
auto input_working_copy = cloneBatchedColumnMajor(input);
|
||||
auto pivots_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
|
||||
pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
|
||||
Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
|
||||
lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
|
||||
|
||||
// solve the linear system using the LU factorization
|
||||
lu_solve_stub(input.device().type(), input_working_copy, pivots, result, TransposeType::NoTranspose);
|
||||
|
||||
// for 1-dimensional 'other', we need to squeeze the result after "apply_solve"
|
||||
if (vector_case) {
|
||||
result = result.squeeze_(-1);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
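The removed helper above boils down to one LU factorization followed by one LU solve. A rough equivalent in terms of public ops (a sketch, assuming torch.linalg.lu_factor and torch.linalg.lu_solve are available; it is not the removed code itself):

import torch

def lu_based_solve(A, B):
    # Factor once, then solve with the factors; the removed code did the same
    # in-place via lu_factor_stub and lu_solve_stub.
    LU, pivots = torch.linalg.lu_factor(A)
    return torch.linalg.lu_solve(LU, pivots, B)

A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 2, dtype=torch.float64)
torch.testing.assert_close(A @ lu_based_solve(A, B), B)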
|
||||
|
||||
// Solves a system of linear equations matmul(input, x) = other in-place
|
||||
Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& result) {
|
||||
auto infos = at::empty({0}, input.options().dtype(kInt));
|
||||
result = linalg_solve_out_info(result, infos, input, other);
|
||||
|
||||
// Now check LAPACK/MAGMA error codes
|
||||
// _linalg_check_errors calls 'infos = infos.to(kCPU)'
|
||||
at::_linalg_check_errors(infos, "linalg.solve", input.dim() == 2);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Solves a system of linear equations matmul(input, x) = other
|
||||
Tensor linalg_solve(const Tensor& input, const Tensor& other) {
|
||||
Tensor result = at::empty({0}, input.options());
|
||||
result = at::linalg_solve_out(result, input, other);
|
||||
return result;
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
/*
|
||||
Computes the inverse of n-by-n matrix 'self'
|
||||
This is an in-place routine, it overwrites the content of 'self'.
|
||||
@ -2058,6 +1959,53 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
// Auxiliary function that returns the LU decomposition to use it in the backward
|
||||
TORCH_IMPL_FUNC(_linalg_solve_out)(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool left,
|
||||
const Tensor& result,
|
||||
const Tensor& LU,
|
||||
const Tensor& pivots) {
|
||||
// Possible optimization: Compute the LU factorization of A^T if A is contiguous
|
||||
// Then we solve A^T X = B with adjoint=True
|
||||
// This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
|
||||
const bool use_A_T = A.is_contiguous() && !A.is_complex();
|
||||
auto info = at::empty({0}, A.options().dtype(kInt));
|
||||
at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU),
|
||||
const_cast<Tensor&>(pivots),
|
||||
const_cast<Tensor&>(info),
|
||||
use_A_T ? A.mT() : A,
|
||||
/*pivot=*/true,
|
||||
/*check_errors=*/false);
|
||||
at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
|
||||
|
||||
// [numpy-compat] Handle vectors on the rhs
|
||||
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B);
|
||||
auto result_ = vector_case ? result.unsqueeze(-1) : result;
|
||||
auto B_ = vector_case ? B.unsqueeze(-1) : B;
|
||||
at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T);
|
||||
}
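The A^T optimization described in the comment can be checked numerically. A sketch (real dtype only, since the trick is disabled for complex inputs; it assumes torch.linalg.lu_factor and torch.linalg.lu_solve):

import torch

A = torch.randn(5, 5, dtype=torch.float64)   # C-contiguous, real
B = torch.randn(5, 3, dtype=torch.float64)

# Factor A.mT instead of A: A.mT is F-contiguous when A is C-contiguous,
# so lu_factor does not need to copy it into column-major storage.
LU, pivots = torch.linalg.lu_factor(A.mT)

# Solving with adjoint=True solves with the transpose of the factored matrix,
# i.e. it solves (A.mT).mT X = A X = B.
X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
torch.testing.assert_close(A @ X, B)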
|
||||
|
||||
Tensor& linalg_solve_out(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool left,
|
||||
Tensor& result) {
|
||||
|
||||
auto LU = at::empty({0}, A.options());
|
||||
auto pivots = at::empty({0}, A.options().dtype(kInt));
|
||||
at::_linalg_solve_out(result, LU, pivots, A, B, left);
|
||||
return result;
|
||||
}
|
||||
|
||||
// We implement linalg_solve as a composite function of _linalg_solve
|
||||
Tensor linalg_solve(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool left) {
|
||||
return std::get<0>(at::_linalg_solve(A, B, left));
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
DEFINE_DISPATCH(lu_factor_stub);
|
||||
|
||||
@ -329,7 +329,7 @@ static inline void singleCheckErrors(int64_t info, const c10::string_view name,
|
||||
} else if (name.find("solve") != name.npos) {
|
||||
// solve, linalg_solve, cholesky_solve, etc.
|
||||
TORCH_CHECK_LINALG(false, name, batch_string,
|
||||
": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular.");
|
||||
": The solver failed because the input matrix is singular.");
|
||||
} else if (name.find("cholesky") != name.npos) {
|
||||
TORCH_CHECK_LINALG(false, name, batch_string,
|
||||
": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
|
||||
|
||||
@ -11859,16 +11859,19 @@
|
||||
python_module: linalg
|
||||
variants: function
|
||||
|
||||
- func: linalg_solve(Tensor input, Tensor other) -> Tensor
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_solve
|
||||
- func: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots)
|
||||
structured_delegate: _linalg_solve.result
|
||||
|
||||
- func: linalg_solve.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
- func: _linalg_solve.result(Tensor A, Tensor B, *, bool left=True, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
|
||||
structured: True
|
||||
dispatch:
|
||||
CPU, CUDA: linalg_solve_out
|
||||
CPU, CUDA: _linalg_solve_out
|
||||
|
||||
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
|
||||
python_module: linalg
|
||||
|
||||
- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
|
||||
- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor
|
||||
python_module: linalg
|
||||
|
||||
@ -2915,13 +2915,6 @@
|
||||
"Generator"
|
||||
],
|
||||
"torch.return_types": [
|
||||
"_det_lu_based_helper",
|
||||
"_fake_quantize_per_tensor_affine_cachemask_tensor_qparams",
|
||||
"_fused_moving_avg_obs_fq_helper",
|
||||
"_linalg_svd",
|
||||
"_linalg_svd_out",
|
||||
"_lu_with_info",
|
||||
"_unpack_dual",
|
||||
"attr",
|
||||
"pytree_register_structseq"
|
||||
],
|
||||
|
||||
@ -52,6 +52,8 @@ ALLOW_LIST = [
|
||||
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
|
||||
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
|
||||
("aten::randperm", datetime.date(9999, 1, 1)),
|
||||
("aten::linalg_solve", datetime.date(2022, 8, 31)),
|
||||
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
|
||||
("aten::gelu", datetime.date(2022, 3, 1)),
|
||||
("aten::gelu_backward", datetime.date(2022, 3, 1)),
|
||||
("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)),
|
||||
@ -112,6 +114,8 @@ ALLOW_LIST = [
|
||||
("q::_FloatToBfloat16Quantized", datetime.date(2021, 12, 21)),
|
||||
("q::_Bfloat16QuantizedToFloat", datetime.date(2021, 12, 21)),
|
||||
("aten::_inverse_helper", datetime.date(2021, 12, 31)),
|
||||
("aten::linalg_solve", datetime.date(2022, 8, 31)),
|
||||
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
|
||||
("aten::softplus_backward", datetime.date(2022, 1, 31)),
|
||||
("aten::softplus_backward.grad_input", datetime.date(2022, 1, 31)),
|
||||
("aten::quantile", datetime.date(2022, 9, 30)),
|
||||
|
||||
@ -3083,96 +3083,12 @@ class TestLinalg(TestCase):
|
||||
expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
|
||||
self.assertEqual(x, expected)
|
||||
|
||||
# Check out= variant
|
||||
out = torch.empty_like(x)
|
||||
ans = torch.linalg.solve(A, b, out=out)
|
||||
self.assertEqual(ans, out)
|
||||
self.assertEqual(x, out)
|
||||
|
||||
# Check out= variant with complex128 out tensor
|
||||
out = torch.empty_like(x).to(torch.complex128)
|
||||
ans = torch.linalg.solve(A, b, out=out)
|
||||
self.assertEqual(ans, out)
|
||||
self.assertEqual(x.to(torch.complex128), out)
|
||||
|
||||
# Check empty out
|
||||
out = torch.empty(0, dtype=dtype, device=device)
|
||||
ans = torch.linalg.solve(A, b, out=out)
|
||||
self.assertEqual(ans, out)
|
||||
self.assertEqual(x, out)
|
||||
|
||||
batches = [(), (0, ), (3, ), (2, 3)]
|
||||
ns = [0, 5, 32]
|
||||
nrhs = [(), (1, ), (5, )]
|
||||
for n, batch, rhs in itertools.product(ns, batches, nrhs):
|
||||
run_test(n, batch, rhs)
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_solve_errors_and_warnings(self, device, dtype):
|
||||
# solve expects batches of square matrices as input
|
||||
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
|
||||
a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
|
||||
b = torch.randn(2, 3, 4, 1, dtype=dtype, device=device)
|
||||
torch.linalg.solve(a, b)
|
||||
|
||||
# solve expects compatible shapes for A x = b
|
||||
with self.assertRaisesRegex(RuntimeError, "Incompatible matrix sizes"):
|
||||
a = torch.randn(2, 3, 3, 3, dtype=dtype, device=device)
|
||||
b = torch.randn(2, 3, 2, 1, dtype=dtype, device=device)
|
||||
torch.linalg.solve(a, b)
|
||||
|
||||
# if input is not solvable, RuntimeError is raised mentioning the first non-solvable batch
|
||||
def run_test_singular_input(batch_dim, n):
|
||||
a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
|
||||
a[n, -1, -1] = 0
|
||||
b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device)
|
||||
with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
|
||||
torch.linalg.solve(a, b)
|
||||
|
||||
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
|
||||
run_test_singular_input(*params)
|
||||
|
||||
# if out tensor with wrong shape is passed a warning is given
|
||||
# matrix 'b' case
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
A = torch.eye(2, dtype=dtype, device=device).reshape((1, 2, 2)).repeat(2, 1, 1)
|
||||
b = torch.randn(2, 2, 2, dtype=dtype, device=device)
|
||||
out = torch.zeros(1, dtype=dtype, device=device)
|
||||
# Trigger warning
|
||||
torch.linalg.solve(A, b, out=out)
|
||||
# Check warning occurs
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
|
||||
|
||||
# if out tensor with wrong shape is passed a warning is given
|
||||
# vector 'b' case
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
A = torch.eye(2, dtype=dtype, device=device)
|
||||
b = torch.randn(2, dtype=dtype, device=device)
|
||||
out = torch.zeros(1, dtype=dtype, device=device)
|
||||
# Trigger warning
|
||||
torch.linalg.solve(A, b, out=out)
|
||||
# Check warning occurs
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
|
||||
|
||||
# dtypes should be safely castable
|
||||
a = torch.eye(2, dtype=dtype, device=device)
|
||||
b = torch.randn(2, 1, dtype=dtype, device=device)
|
||||
out = torch.empty(0, dtype=torch.int, device=device)
|
||||
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
|
||||
torch.linalg.solve(a, b, out=out)
|
||||
|
||||
# device should match
|
||||
if torch.cuda.is_available():
|
||||
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
|
||||
out = torch.empty(0, dtype=dtype, device=wrong_device)
|
||||
clone_a = torch.empty_like(a)
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
|
||||
torch.linalg.solve(a, b, out=out)
|
||||
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@ -5103,13 +5019,14 @@ class TestLinalg(TestCase):
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_linalg_lu_factor_and_lu_and_lu_unpack(self, device, dtype):
|
||||
def test_linalg_lu_family(self, device, dtype):
|
||||
# Tests torch.lu
|
||||
# torch.linalg.lu_factor
|
||||
# torch.linalg.lu_factor_ex
|
||||
# torch.lu_unpack
|
||||
# torch.linalg.lu_solve
|
||||
from torch.testing._internal.common_utils import random_matrix
|
||||
# torch.linalg.solve
|
||||
make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
|
||||
def run_test(A, pivot, singular, fn):
|
||||
@ -5141,20 +5058,37 @@ class TestLinalg(TestCase):
|
||||
self.assertEqual(L, PLU.L)
|
||||
self.assertEqual(U, PLU.U)
|
||||
|
||||
rhs = 3
|
||||
if not singular and A.size(-2) == A.size(-1):
|
||||
for left in (True, False):
|
||||
shape_B = list(A.shape)
|
||||
dim = -1 if left else -2
|
||||
shape_B[dim] = rhs
|
||||
nrhs = ((), (1,), (3,))
|
||||
for left, rhs in product((True, False), nrhs):
|
||||
# Vector case when left = False is not allowed
|
||||
if not left and rhs == ():
|
||||
continue
|
||||
if left:
|
||||
shape_B = A.shape[:-1] + rhs
|
||||
else:
|
||||
shape_B = A.shape[:-2] + rhs + A.shape[-1:]
|
||||
B = make_arg(shape_B)
|
||||
for adjoint in (True, False):
|
||||
X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
|
||||
A_adj = A.mH if adjoint else A
|
||||
if left:
|
||||
self.assertEqual(B, A_adj @ X)
|
||||
else:
|
||||
self.assertEqual(B, X @ A_adj)
|
||||
|
||||
# Test linalg.lu_solve. It does not support vectors as rhs
|
||||
# See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913
|
||||
if rhs != ():
|
||||
for adjoint in (True, False):
|
||||
X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
|
||||
A_adj = A.mH if adjoint else A
|
||||
if left:
|
||||
self.assertEqual(B, A_adj @ X)
|
||||
else:
|
||||
self.assertEqual(B, X @ A_adj)
|
||||
|
||||
# Test linalg.solve
|
||||
X = torch.linalg.solve(A, B, left=left)
|
||||
X_ = X.unsqueeze(-1) if rhs == () else X
|
||||
B_ = B.unsqueeze(-1) if rhs == () else B
|
||||
if left:
|
||||
self.assertEqual(B_, A @ X_)
|
||||
else:
|
||||
self.assertEqual(B_, X_ @ A)
|
||||
|
||||
|
||||
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
|
||||
@ -5163,8 +5097,8 @@ class TestLinalg(TestCase):
|
||||
pivots = (True, False) if self.device_type == "cuda" else (True,)
|
||||
fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
|
||||
for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
|
||||
m, n = ms
|
||||
A = random_matrix(m, n, *batch, singular=singular, dtype=dtype, device=device)
|
||||
shape = batch + ms
|
||||
A = make_arg(shape) if singular else make_arg_full(*shape)
|
||||
# Just do one of them on singular matrices
|
||||
if A.numel() == 0 and not singular:
|
||||
continue
|
||||
@ -6159,38 +6093,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||
for b1, b2, ref, out_tensor in generate_tensor():
|
||||
self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
|
||||
|
||||
# TODO: update to compare against NumPy
|
||||
@onlyCUDA
|
||||
def test_solve_methods_arg_device(self, device):
|
||||
for b_device, A_device in itertools.product(['cpu', device], repeat=2):
|
||||
if b_device == A_device:
|
||||
continue
|
||||
|
||||
b = torch.randn(3, 1, device=b_device)
|
||||
A = torch.randn(3, 3, device=A_device)
|
||||
|
||||
# cholesky_solve goes through generic backend dispatch and hit kernel specific device check first
|
||||
# triangular_solve goes through specific backend dispatch (CPU/CUDA) and hit auto-generated device check first
|
||||
generic_backend_dispatch_err_str = "Expected b and A to be on the same device"
|
||||
specific_backend_dispatch_err_str = "Expected all tensors to be on the same device"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, generic_backend_dispatch_err_str):
|
||||
torch.cholesky_solve(b, A)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
|
||||
torch.triangular_solve(b, A)
|
||||
|
||||
# b and A have to be modified to match accepted inputs sizes for lu_solve
|
||||
b = b.unsqueeze(0)
|
||||
A = A.unsqueeze(0)
|
||||
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
|
||||
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int())
|
||||
|
||||
# This checks if a suitable error message is thrown
|
||||
# when LU output and pivots are not on the same device
|
||||
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
|
||||
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int())
|
||||
|
||||
@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
|
||||
@ -466,9 +466,7 @@ meta_function_expected_failures = {
|
||||
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product
|
||||
torch.linalg.lstsq: {f32, f64}, # aten::linalg_lstsq.out
|
||||
torch.linalg.slogdet: {f32, f64}, # aten::linalg_slogdet
|
||||
torch.linalg.solve: {f32, f64}, # aten::linalg_solve, aten::linalg_solve.out
|
||||
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular
|
||||
torch.linalg.tensorsolve: {f32, f64}, # aten::linalg_solve
|
||||
torch.logdet: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero
|
||||
}
|
||||
|
||||
@ -719,13 +717,9 @@ meta_dispatch_expected_failures = {
|
||||
aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out
|
||||
aten.linalg_lstsq.default: {f32, f64}, # aten::linalg_lstsq.out
|
||||
aten.linalg_slogdet.default: {f32, f64}, # aten::linalg_slogdet
|
||||
aten.linalg_solve.default: {f32, f64}, # aten::linalg_solve
|
||||
aten.linalg_solve.out: {f32, f64}, # aten::linalg_solve.out
|
||||
aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular
|
||||
aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
|
||||
aten.logdet.default: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero
|
||||
aten.lu_solve.default: {f32, f64}, # aten::lu_solve
|
||||
aten.lu_solve.out: {f32, f64}, # aten::lu_solve.out
|
||||
aten.ormqr.default: {f32, f64}, # aten::ormqr
|
||||
aten.ormqr.out: {f32, f64}, # aten::ormqr.out
|
||||
aten.symeig.default: {f32, f64}, # aten::_symeig_helper
|
||||
|
||||
@ -21,7 +21,7 @@ all_operators_with_namedtuple_return = {
|
||||
'frexp', 'lu_unpack', 'histogram', 'histogramdd',
|
||||
'_fake_quantize_per_tensor_affine_cachemask_tensor_qparams',
|
||||
'_fused_moving_avg_obs_fq_helper', 'linalg_lu_factor', 'linalg_lu_factor_ex', 'linalg_lu',
|
||||
'_det_lu_based_helper', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor',
|
||||
'_det_lu_based_helper', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor', '_linalg_solve'
|
||||
}
|
||||
|
||||
|
||||
@ -85,6 +85,7 @@ class TestNamedTupleAPI(TestCase):
|
||||
op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True),
|
||||
op(operators=['linalg_cholesky_ex'], input=(), names=('L', 'info'), hasout=True),
|
||||
op(operators=['linalg_inv_ex'], input=(), names=('inverse', 'info'), hasout=True),
|
||||
op(operators=['_linalg_solve'], input=(a,), names=('result', 'LU', 'pivots'), hasout=True),
|
||||
op(operators=['linalg_lu_factor'], input=(), names=('LU', 'pivots'), hasout=True),
|
||||
op(operators=['linalg_lu_factor_ex'], input=(), names=('LU', 'pivots', 'info'), hasout=True),
|
||||
op(operators=['linalg_ldl_factor'], input=(), names=('LD', 'pivots'), hasout=True),
|
||||
|
||||
@ -1437,10 +1437,10 @@
|
||||
self: slogdet_backward(grad, self, sign, logabsdet)
|
||||
output_differentiability: [false, true]
|
||||
|
||||
- name: linalg_solve(Tensor input, Tensor other) -> Tensor
|
||||
input: solve_backward_A(grad, other, input, result)
|
||||
other: solve_backward_self(grad, other, input)
|
||||
result: solve_jvp(result, input_p, input_t, other_t)
|
||||
- name: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots)
|
||||
A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
|
||||
result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
|
||||
output_differentiability: [True, False, False] # LU is an auxiliary tensor not exposed to the user
|
||||
|
||||
- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
|
||||
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
|
||||
|
||||
@ -337,6 +337,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
||||
"pixel_shuffle",
|
||||
"pixel_unshuffle",
|
||||
"linalg_lu_solve",
|
||||
"_linalg_solve",
|
||||
}
|
||||
|
||||
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
|
||||
|
||||
@ -196,12 +196,12 @@ inline std::tuple<Tensor&, Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& i
|
||||
return torch::linalg_qr_out(Q, R, input, mode);
|
||||
}
|
||||
|
||||
inline Tensor solve(const Tensor& input, const Tensor& other) {
|
||||
return torch::linalg_solve(input, other);
|
||||
inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
|
||||
return torch::linalg_solve(input, other, left);
|
||||
}
|
||||
|
||||
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) {
|
||||
return torch::linalg_solve_out(result, input, other);
|
||||
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) {
|
||||
return torch::linalg_solve_out(result, input, other, left);
|
||||
}
|
||||
|
||||
inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) {
|
||||
@ -566,12 +566,12 @@ inline Tensor& ldl_solve_out(
|
||||
/// Computes a tensor `x` such that `matmul(input, x) = other`.
|
||||
///
|
||||
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve
|
||||
inline Tensor solve(const Tensor& input, const Tensor& other) {
|
||||
return detail::solve(input, other);
|
||||
inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
|
||||
return detail::solve(input, other, left);
|
||||
}
|
||||
|
||||
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) {
|
||||
return detail::solve_out(result, input, other);
|
||||
inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) {
|
||||
return detail::solve_out(result, input, other, left);
|
||||
}
|
||||
|
||||
/// Computes a solution of a linear system AX = B for input = A and other = B whenever A is square
|
||||
|
||||
@ -563,38 +563,6 @@ static Tensor generic_solve_jvp(
|
||||
return solve(A, dB, dA_contrib);
|
||||
}
|
||||
|
||||
Tensor solve_jvp(
|
||||
const Tensor& X,
|
||||
const Tensor& A,
|
||||
const Tensor& dA,
|
||||
const Tensor& dB
|
||||
) {
|
||||
return generic_solve_jvp(
|
||||
[](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
|
||||
return at::linalg_solve(A, dB - dA_contrib);
|
||||
},
|
||||
X, A, dA, dB
|
||||
);
|
||||
}
|
||||
|
||||
Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
|
||||
return at::linalg_solve(A.mH(), grad);
|
||||
}
|
||||
|
||||
Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
|
||||
at::NoTF32Guard disable_tf32;
|
||||
Tensor grad_self = solve_backward_self(grad, self, A);
|
||||
if (self.ndimension() == 2 && A.ndimension() == 2) {
|
||||
return -at::mm(grad_self, solution.mH());
|
||||
}
|
||||
// if self was unsqueezed from (..., M) to (..., M, 1)
|
||||
bool vector_case = at::native::linalg_solve_is_vector_rhs(A, self);
|
||||
if (vector_case) {
|
||||
return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).mH());
|
||||
}
|
||||
return -at::matmul(grad_self, solution.mH());
|
||||
}
|
||||
|
||||
Tensor cumsum_backward(const Tensor & grad, int64_t dim) {
|
||||
// Trivial case
|
||||
if (grad.numel() <= 1 || grad.size(dim) == 1) {
|
||||
@ -4911,6 +4879,89 @@ Tensor linalg_lu_solve_jvp(
|
||||
}
|
||||
}
|
||||
|
||||
Tensor linalg_solve_jvp(
|
||||
const Tensor& dA,
|
||||
const Tensor& dB,
|
||||
const Tensor& X,
|
||||
const Tensor& LU,
|
||||
const Tensor& pivots,
|
||||
const bool left,
|
||||
const bool use_A_T) {
|
||||
at::NoTF32Guard disable_tf32;
|
||||
// For left=True (left=False is analogous)
|
||||
// dX = A^{-1}(dB - dAX)
|
||||
|
||||
// [NumPy compat] Case where the rhs is a vector.
|
||||
// We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)`
|
||||
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
|
||||
const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; };
|
||||
const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; };
|
||||
|
||||
// This case is disallowed in the primal operation as A.shape = (*, 1, 1)
|
||||
TORCH_INTERNAL_ASSERT(left || !vector_case);
|
||||
|
||||
auto X_ = vector_to_matrix(X);
|
||||
auto dB_ = vector_to_matrix(dB);
|
||||
auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
|
||||
auto dX_ = at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/use_A_T);
|
||||
return matrix_to_vector(dX_);
|
||||
}
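The identity dX = A^{-1}(dB - dA X) used above follows from differentiating A X = B. A small numerical check (a sketch in double precision; the shapes are made up):

import torch

torch.manual_seed(0)
A  = torch.randn(4, 4, dtype=torch.float64)
B  = torch.randn(4, 2, dtype=torch.float64)
dA = torch.randn(4, 4, dtype=torch.float64)
dB = torch.randn(4, 2, dtype=torch.float64)

X = torch.linalg.solve(A, B)
# Closed-form directional derivative from the comment: dX = A^{-1} (dB - dA X)
dX = torch.linalg.solve(A, dB - dA @ X)

# Finite-difference check of d/dt solve(A + t*dA, B + t*dB) at t = 0
h = 1e-6
dX_fd = (torch.linalg.solve(A + h * dA, B + h * dB) - X) / h
torch.testing.assert_close(dX, dX_fd, rtol=1e-4, atol=1e-4)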
|
||||
|
||||
std::tuple<Tensor, Tensor> linalg_solve_backward(
|
||||
const Tensor& gX,
|
||||
const Tensor& X,
|
||||
const Tensor& A,
|
||||
const Tensor& LU,
|
||||
const Tensor& pivots,
|
||||
const bool left,
|
||||
const bool B_requires_grad) {
|
||||
// for X = A^{-1}B
|
||||
// gB = A^{-H}gX
|
||||
// gA = -gB X^H
|
||||
at::NoTF32Guard disable_tf32;
|
||||
const auto A_requires_grad = A.requires_grad();
|
||||
if (!gX.defined() || (!A_requires_grad && !B_requires_grad)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// [NumPy compat] Case where the rhs is a vector.
|
||||
// We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)`
|
||||
const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
|
||||
const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; };
|
||||
const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; };
|
||||
|
||||
// If the user is going to compute higher order gradients, then we need to recompute the LU and the pivots
|
||||
Tensor gB_;
|
||||
if (at::GradMode::is_enabled()) {
|
||||
gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
|
||||
} else {
|
||||
const auto use_A_T = A.is_contiguous() && !A.is_complex();
|
||||
gB_ = at::linalg_lu_solve(LU, pivots, vector_to_matrix(gX), left, /*adjoint*/!use_A_T);
|
||||
}
|
||||
|
||||
Tensor gA_;
|
||||
if (A_requires_grad) {
|
||||
auto X_ = vector_to_matrix(X);
|
||||
gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_);
|
||||
}
|
||||
return std::make_tuple(A_requires_grad ? matrix_to_vector(gA_) : Tensor{},
|
||||
B_requires_grad ? matrix_to_vector(gB_) : Tensor{});
|
||||
}
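The backward formulas gB = A^{-H} gX and gA = -gB X^H can likewise be checked against autograd. A sketch (real dtype, so A^{-H} is just the inverse transpose):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
B = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)

X = torch.linalg.solve(A, B)
gX = torch.randn_like(X)
gA, gB = torch.autograd.grad(X, (A, B), grad_outputs=gX)

with torch.no_grad():
    # Formulas from the comment: gB = A^{-H} gX and gA = -gB X^H
    gB_manual = torch.linalg.solve(A.mH, gX)
    gA_manual = -gB_manual @ X.mH

torch.testing.assert_close(gB, gB_manual)
torch.testing.assert_close(gA, gA_manual)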
|
||||
|
||||
Tensor solve_jvp(
|
||||
const Tensor& X,
|
||||
const Tensor& A,
|
||||
const Tensor& dA,
|
||||
const Tensor& dB
|
||||
) {
|
||||
return generic_solve_jvp(
|
||||
[](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
|
||||
return at::linalg_solve(A, dB - dA_contrib);
|
||||
},
|
||||
X, A, dA, dB
|
||||
);
|
||||
}
|
||||
|
||||
Tensor lu_unpack_backward(
|
||||
const Tensor& L_grad,
|
||||
const Tensor& U_grad,
|
||||
|
||||
@ -380,6 +380,22 @@ Tensor linalg_lu_solve_jvp(
|
||||
const Tensor& dB,
|
||||
const bool left,
|
||||
const bool adjoint);
|
||||
std::tuple<Tensor, Tensor> linalg_solve_backward(
|
||||
const Tensor& gX,
|
||||
const Tensor& X,
|
||||
const Tensor& A,
|
||||
const Tensor& LU,
|
||||
const Tensor& pivots,
|
||||
const bool left,
|
||||
const bool B_requires_grad);
|
||||
Tensor linalg_solve_jvp(
|
||||
const Tensor& dA,
|
||||
const Tensor& dB,
|
||||
const Tensor& X,
|
||||
const Tensor& LU,
|
||||
const Tensor& pivots,
|
||||
const bool left,
|
||||
const bool use_A_T);
|
||||
Tensor lu_unpack_backward(
|
||||
const Tensor& L_grad,
|
||||
const Tensor& U_grad,
|
||||
|
||||
@ -4902,17 +4902,18 @@ REGISTER_OPERATOR_FUNCTOR(
|
||||
aten_linalg_solve,
|
||||
[](Node* n) -> SROperator {
|
||||
if (n->matches(torch::schema(
|
||||
"aten::linalg_solve(Tensor input, Tensor other) -> Tensor"))) {
|
||||
"aten::linalg_solve(Tensor input, Tensor other, bool left) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& input = p_node->Input(0).toTensor();
|
||||
const auto& other = p_node->Input(1).toTensor();
|
||||
auto left = p_node->Input(2).toBool();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = at::native::linalg_solve(input, other);
|
||||
p_node->Output(0) = at::native::linalg_solve(input, other, left);
|
||||
return;
|
||||
}
|
||||
auto& out = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(out);
|
||||
at::native::linalg_solve_out(input, other, out);
|
||||
at::native::linalg_solve_out(input, other, left, out);
|
||||
};
|
||||
}
|
||||
LogAndDumpSchema(n);
|
||||
|
||||
@ -2882,7 +2882,7 @@ void nnc_aten_linalg_solve(
|
||||
const at::Tensor& input = tensors[1];
|
||||
const at::Tensor& other = tensors[2];
|
||||
try {
|
||||
at::linalg_solve_out(r, input, other);
|
||||
at::linalg_solve_out(r, input, other, true);
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
|
||||
@ -2055,7 +2055,7 @@ Example::
|
||||
|
||||
|
||||
solve = _add_docstr(_linalg.linalg_solve, r"""
|
||||
linalg.solve(A, B, *, out=None) -> Tensor
|
||||
linalg.solve(A, B, *, left=True, out=None) -> Tensor
|
||||
|
||||
Computes the solution of a square system of linear equations with a unique solution.
|
||||
|
||||
@ -2065,6 +2065,12 @@ this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the
|
||||
|
||||
.. math:: AX = B
|
||||
|
||||
If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system
|
||||
|
||||
.. math::
|
||||
|
||||
XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.}
|
||||
|
||||
This system of linear equations has one solution if and only if :math:`A` is `invertible`_.
|
||||
This function assumes that :math:`A` is invertible.
|
||||
|
||||
@ -2104,6 +2110,7 @@ Args:
|
||||
according to the rules described above
|
||||
|
||||
Keyword args:
|
||||
left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`.
|
||||
out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
|
||||
|
||||
Raises:
|
||||
|
||||
@ -964,7 +964,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.smm: lambda input, mat2: -1,
|
||||
torch.spmm: lambda input, mat2: -1,
|
||||
torch.softmax: lambda input, dim, dtype=None: -1,
|
||||
torch.linalg.solve: lambda input, other, out=None: -1,
|
||||
torch.linalg.solve: lambda A, B, left=True, out=None: -1,
|
||||
torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
|
||||
torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
|
||||
torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
|
||||
|
||||
@ -18,14 +18,17 @@ def pytree_register_structseq(cls):
|
||||
for name in dir(return_types):
|
||||
if name.startswith('__'):
|
||||
continue
|
||||
globals()[name] = getattr(return_types, name)
|
||||
__all__.append(name)
|
||||
|
||||
attr = getattr(return_types, name)
|
||||
globals()[name] = attr
|
||||
|
||||
if not name.startswith('_'):
|
||||
__all__.append(name)
|
||||
|
||||
# Today everything in torch.return_types is a structseq, aka a "namedtuple"-like
|
||||
# thing defined by the Python C-API. We're going to need to modify this when that
|
||||
# is no longer the case.
|
||||
# NB: I don't know how to check that something is a "structseq" so we do a fuzzy
|
||||
# check for tuple
|
||||
attr = globals()[name]
|
||||
if inspect.isclass(attr) and issubclass(attr, tuple):
|
||||
pytree_register_structseq(attr)
|
||||
|
||||
@ -6898,9 +6898,14 @@ def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kw
|
||||
op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
|
||||
)
|
||||
|
||||
def out_fn(output):
|
||||
return output[0]
|
||||
|
||||
# Reverses tensor order
|
||||
for sample in out:
|
||||
sample.input, sample.args = sample.args[0], (sample.input,)
|
||||
if op_info.name == "solve":
|
||||
sample.output_process_fn_grad = out_fn
|
||||
yield sample
|
||||
|
||||
|
||||
@ -15989,13 +15994,10 @@ op_db: List[OpInfo] = [
|
||||
op=torch.linalg.solve,
|
||||
dtypes=floating_and_complex_types(),
|
||||
sample_inputs_func=sample_inputs_linalg_solve,
|
||||
check_batched_gradgrad=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
|
||||
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
|
||||
skips=(
|
||||
# AssertionError: Scalars are not equal!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
|
||||
device_type='mps', dtypes=[torch.float32]),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
|
||||
@ -18351,7 +18353,10 @@ op_db: List[OpInfo] = [
|
||||
sample_inputs_func=sample_inputs_tensorsolve,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagma],
|
||||
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack,
|
||||
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
|
||||
'TestCommon', 'test_noncontiguous_samples',
|
||||
device_type='cuda')],
|
||||
),
|
||||
OpInfo(
|
||||
"nn.functional.mse_loss",
|
||||
|
||||