Compare commits

2 Commits

SHA1 Message Date
80295cdec1 Adjust derivative calculation for removed linalg.solve optimization 2025-10-09 09:24:32 +00:00
a475f1ac24 Avoid differing results in linalg.(tensor_)solve
Remove an optimization that potentially passed a transposed matrix as input
to `linalg_lu_factor_ex_out`.

Depending on whether or not the input memory layout is contiguous, this
may lead to slightly different results, which can cause larger
differences in subsequent steps and ultimately test failures in
e.g. `test_vmapvjp_linalg_tensorsolve_cpu_float32` and `test_vmapvjpvjp_linalg_tensorsolve_cpu_float32`.

The intended optimization no longer applies after 59bc76f, so this code
can be removed as well, resolving the accuracy issues observed in those tests.

Fixes #151440
2025-10-09 09:24:32 +00:00
5 changed files with 7 additions and 24 deletions
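
A note on the behavior described in the commit message above: before this change, torch.linalg.solve could route a contiguous, non-complex A through a transposed-LU fast path while a non-contiguous A took the regular path, so the same system solved from two memory layouts could differ by small rounding amounts. A minimal Python sketch of that comparison (shapes, seed and the printed magnitude are illustrative only; after this change both layouts go through the same path):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float32)
B = torch.randn(4, dtype=torch.float32)

# Same values as A, but laid out non-contiguously in memory.
A_noncontig = A.t().contiguous().t()
assert not A_noncontig.is_contiguous() and torch.equal(A, A_noncontig)

X_contig = torch.linalg.solve(A, B)
X_noncontig = torch.linalg.solve(A_noncontig, B)

# Before this change the contiguous input could take the transposed-LU
# fast path, so the two results could differ at float32 rounding level.
print((X_contig - X_noncontig).abs().max())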

View File

@@ -382,14 +382,6 @@ fourOutputs solve_ex_batch_rule(
   A_ = ensure_has_bdim(A_, A_bdim.has_value(), batch_size);
   B_ = ensure_has_bdim(B_, B_bdim.has_value(), batch_size);
-  // NOTE [ solve_ex Batch Rule Contiguity ]
-  // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
-  // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
-  // differs based on whether or not the optimized path was taken
-  const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
-  if (batched_A_was_contiguous && !A.is_complex()) {
-    A_ = A_.contiguous();
-  }
   auto res = _linalg_solve_ex(A_, B_, left, check_errors);
   return std::make_tuple(std::move(std::get<0>(res)), 0, std::move(std::get<1>(res)), 0, std::move(std::get<2>(res)), 0, std::move(std::get<3>(res)), 0);
 }

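The NOTE removed in this hunk was about keeping the batch rule's contiguity check in sync with the one in _linalg_solve_ex_out, because functorch composes vmap with autograd (as in the test_vmapvjp_* tests named in the commit message). A simplified sketch of that composition pattern using torch.func, with linalg.solve instead of tensorsolve and with illustrative shapes, dtype and tolerances: the batched vjp is expected to match the per-sample vjp.

import torch
from torch.func import vmap, vjp

torch.manual_seed(0)
A = torch.randn(3, 4, 4, dtype=torch.float64)   # batch of coefficient matrices
B = torch.randn(3, 4, dtype=torch.float64)      # batch of right-hand sides
cot = torch.randn(3, 4, dtype=torch.float64)    # cotangents for the solve outputs

def vjp_of_solve(a, b, v):
    # vjp of X = solve(a, b) at cotangent v, returning (grad_a, grad_b)
    _, pullback = vjp(torch.linalg.solve, a, b)
    return pullback(v)

# vmap pushes the batch dimension through the vjp; the _linalg_solve_ex
# batch rule must make this agree with running each sample separately.
gA_batched, gB_batched = vmap(vjp_of_solve)(A, B, cot)

for i in range(A.shape[0]):
    gA_i, gB_i = vjp_of_solve(A[i], B[i], cot[i])
    assert torch.allclose(gA_batched[i], gA_i, atol=1e-8)
    assert torch.allclose(gB_batched[i], gB_i, atol=1e-8)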
View File

@@ -1957,15 +1957,10 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out)(const Tensor& A,
                                       const Tensor& LU,
                                       const Tensor& pivots,
                                       const Tensor& info) {
-  // Possible optimization: Compute the LU factorization of A^T if A is contiguous
-  // Then we solve A^T X = B with adjoint=True
-  // This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
-  // This optimization makes functorch's batching rule difficult. See NOTE [ solve_ex Batch Rule Contiguity ]
-  const bool use_A_T = A.is_contiguous() && !A.is_complex();
   at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU),
                               const_cast<Tensor&>(pivots),
                               const_cast<Tensor&>(info),
-                              use_A_T ? A.mT() : A);
+                              A);
   if (check_errors) {
     at::_linalg_check_errors(info, "torch.linalg.solve_ex", A.dim() == 2);
   }
@@ -1974,7 +1969,7 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out)(const Tensor& A,
   const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B);
   auto result_ = vector_case ? result.unsqueeze(-1) : result;
   auto B_ = vector_case ? B.unsqueeze(-1) : B;
-  at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T);
+  at::linalg_lu_solve_out(result_, LU, pivots, B_, left);
 }
 std::tuple<Tensor&, Tensor&> linalg_solve_ex_out(const Tensor& A,

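The removed comment in the hunk above described the dropped optimization: factor A^T instead of A (avoiding the copy into an F-contiguous matrix inside lu_factor) and then solve the adjoint system. For real A this is equivalent because (A^T)^H = A, which is also why the path was guarded by !A.is_complex(). A small sketch of that equivalence at the Python level (double precision and shapes are illustrative):

import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
B = torch.randn(5, 3, dtype=torch.float64)

# Factor the transpose, then ask lu_solve for the adjoint system:
# it solves (A^T)^H X = B, which for real A is just A X = B.
LU, pivots = torch.linalg.lu_factor(A.mT)
X_via_transpose = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)

X_direct = torch.linalg.solve(A, B)
print(torch.allclose(X_via_transpose, X_direct))  # True up to rounding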
View File

@@ -1586,7 +1586,7 @@
 - name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
   A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
-  result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
+  result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left)"
   output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user
 - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)

View File

@@ -6033,8 +6033,7 @@ Tensor linalg_solve_jvp(
     const Tensor& X,
     const Tensor& LU,
     const Tensor& pivots,
-    const bool left,
-    const bool use_A_T) {
+    const bool left) {
   at::NoTF32Guard disable_tf32;
   // For left=True (left=False is analogous)
   // dX = A^{-1}(dB - dAX)
@@ -6056,8 +6055,7 @@
   auto X_ = vector_to_matrix(X);
   auto dB_ = vector_to_matrix(dB);
   auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
-  auto dX_ =
-      at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/ use_A_T);
+  auto dX_ = at::linalg_lu_solve(LU, pivots, dB_ - R_, left);
   return matrix_to_vector(dX_);
 }
@@ -6095,9 +6093,8 @@ std::tuple<Tensor, Tensor> linalg_solve_backward(
     if (at::GradMode::is_enabled()) {
      gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
     } else {
-      const auto use_A_T = A.is_contiguous() && !A.is_complex();
       gB_ = at::linalg_lu_solve(
-          LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T);
+          LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ true);
     }
     Tensor gA_;

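With the LU factors now always computed from A itself, the jvp no longer needs the use_A_T flag. The forward-mode formula mentioned in the comment above, dX = A^{-1}(dB - dA X) for left=True, can be checked numerically against torch.func.jvp; the sketch below uses illustrative shapes and double precision.

import torch
from torch.func import jvp

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 2, dtype=torch.float64)
dA = torch.randn_like(A)
dB = torch.randn_like(B)

# Hand-rolled jvp for X = A^{-1} B (left=True):  dX = A^{-1} (dB - dA @ X)
X = torch.linalg.solve(A, B)
dX_manual = torch.linalg.solve(A, dB - dA @ X)

_, dX_auto = jvp(torch.linalg.solve, (A, B), (dA, dB))
print(torch.allclose(dX_manual, dX_auto))  # True up to rounding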
View File

@@ -896,8 +896,7 @@ Tensor linalg_solve_jvp(
     const Tensor& X,
     const Tensor& LU,
     const Tensor& pivots,
-    const bool left,
-    const bool use_A_T);
+    const bool left);
 Tensor lu_unpack_backward(
     const Tensor& L_grad,
     const Tensor& U_grad,