Adding check for square matrix for input tensor in matrix_exp backwar… (#163357)

…d op.

Fixes #146796

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163357
Approved by: https://github.com/lezcano
This commit is contained in:
mansiag05
2025-10-01 03:12:30 +00:00
committed by PyTorch MergeBot
parent 2a5ce2feb4
commit 531f3bf5e1
2 changed files with 16 additions and 0 deletions

View File

@ -2801,6 +2801,7 @@ Tensor matrix_exp(const Tensor& a) {
// TODO This should be deprecated in favor of linalg_matrix_exp_differential
// in FunctionsManual.cpp
Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
squareCheckInputs(self, "matrix_exp_backward");
NoTF32Guard disable_tf32;
return backward_analytic_function_of_a_matrix(
self, grad,

View File

@ -8418,6 +8418,21 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
x = torch.randn(3, 3, 1, 1)
self.assertEqual(expm(x), x.exp())
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_matrix_exp_backward_input_validation(self, device, dtype):
scalar_tensor = torch.tensor(1.0, dtype=dtype, device=device)
grad_1d = torch.randn(1, dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
torch.ops.aten.matrix_exp_backward(scalar_tensor, grad_1d)
non_square = torch.randn(2, 3, dtype=dtype, device=device)
grad_non_square = torch.randn(2, 3, dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
torch.ops.aten.matrix_exp_backward(non_square, grad_non_square)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)