[MPS] Fix internal assertion in torch.linalg.solve for singular matrices (#165254)

Fixes #163962 by special casing MPS in the negative status code branch in `_linalg_check_errors`.

Checks if info is [`MPSMatrixDecompositionStatus.singular`](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus/singular) (which has a raw value of -2). I didn't find an official Apple source with this raw value (besides printing the enum value), so I'm not sure if we can (or should) depend on it? Is there a way to directly get the Objective-C enum value in C++?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165254
Approved by: https://github.com/malfet
This commit is contained in:
inventshah
2025-10-17 15:35:49 +00:00
committed by PyTorch MergeBot
parent 3af2f0c12a
commit 935ccdbe75
2 changed files with 35 additions and 0 deletions

View File

@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
other.size(0) > max_stride_size || other.size(1) > max_stride_size);
}
void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
const auto& status_flat = status.view(-1);
for (const auto i : c10::irange(status_flat.size(0))) {
int code = status_flat[i].item<int>();
switch (code) {
case MPSMatrixDecompositionStatusSuccess:
status_flat[i] = 0;
break;
case MPSMatrixDecompositionStatusNonPositiveDefinite:
case MPSMatrixDecompositionStatusSingular:
status_flat[i] = 2;
break;
case MPSMatrixDecompositionStatusFailure:
status_flat[i] = -1;
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
}
}
}
} // anonymous namespace
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
"mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
if (!left) {
// If this was a right solve, transpose the result back
result.copy_(result_t.transpose(-2, -1).contiguous());

View File

@ -1978,6 +1978,16 @@ class TestMPS(TestCaseMPS):
run_linalg_solve_test(32, 10, 10)
run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
def test_linalg_solve_singular(self):
# Regression test for https://github.com/pytorch/pytorch/issues/163962
# Explicit singular matrix
A = torch.tensor([[1.0, 2.0], [2.0, 4.0]], device="mps")
b = torch.rand_like(A)
with self.assertRaisesRegex(RuntimeError, "input matrix is singular"):
torch.linalg.solve(A, b)
def test_linalg_solve_with_broadcasting(self):
from functools import partial
import torch