mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix internal assertion in torch.linalg.solve for singular matrices (#165254)
Fixes #163962 by special casing MPS in the negative status code branch in `_linalg_check_errors`. Checks if info is [`MPSMatrixDecompositionStatus.singular`](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus/singular) (which has a raw value of -2). I didn't find an official Apple source with this raw value (besides printing the enum value), so I'm not sure if we can (or should) depend on it? Is there a way to directly get the Objective-C enum value in C++? Pull Request resolved: https://github.com/pytorch/pytorch/pull/165254 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
3af2f0c12a
commit
935ccdbe75
@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
|
|||||||
other.size(0) > max_stride_size || other.size(1) > max_stride_size);
|
other.size(0) > max_stride_size || other.size(1) > max_stride_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
|
||||||
|
const auto& status_flat = status.view(-1);
|
||||||
|
|
||||||
|
for (const auto i : c10::irange(status_flat.size(0))) {
|
||||||
|
int code = status_flat[i].item<int>();
|
||||||
|
switch (code) {
|
||||||
|
case MPSMatrixDecompositionStatusSuccess:
|
||||||
|
status_flat[i] = 0;
|
||||||
|
break;
|
||||||
|
case MPSMatrixDecompositionStatusNonPositiveDefinite:
|
||||||
|
case MPSMatrixDecompositionStatusSingular:
|
||||||
|
status_flat[i] = 2;
|
||||||
|
break;
|
||||||
|
case MPSMatrixDecompositionStatusFailure:
|
||||||
|
status_flat[i] = -1;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
||||||
@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
|
|||||||
"mpsmatrixdecompositionstatus for details.");
|
"mpsmatrixdecompositionstatus for details.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
map_mps_decomposition_error_code_to_blas(info);
|
||||||
|
|
||||||
if (!left) {
|
if (!left) {
|
||||||
// If this was a right solve, transpose the result back
|
// If this was a right solve, transpose the result back
|
||||||
result.copy_(result_t.transpose(-2, -1).contiguous());
|
result.copy_(result_t.transpose(-2, -1).contiguous());
|
||||||
|
@ -1978,6 +1978,16 @@ class TestMPS(TestCaseMPS):
|
|||||||
run_linalg_solve_test(32, 10, 10)
|
run_linalg_solve_test(32, 10, 10)
|
||||||
run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
|
run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
|
||||||
|
|
||||||
|
def test_linalg_solve_singular(self):
|
||||||
|
# Regression test for https://github.com/pytorch/pytorch/issues/163962
|
||||||
|
|
||||||
|
# Explicit singular matrix
|
||||||
|
A = torch.tensor([[1.0, 2.0], [2.0, 4.0]], device="mps")
|
||||||
|
b = torch.rand_like(A)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "input matrix is singular"):
|
||||||
|
torch.linalg.solve(A, b)
|
||||||
|
|
||||||
def test_linalg_solve_with_broadcasting(self):
|
def test_linalg_solve_with_broadcasting(self):
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import torch
|
import torch
|
||||||
|
Reference in New Issue
Block a user