Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Adding check for square matrix for input tensor in matrix_exp backward op (#163357)

Fixes #146796
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163357
Approved by: https://github.com/lezcano
Committed by: PyTorch MergeBot
Parent: 2a5ce2feb4
Commit: 531f3bf5e1
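For illustration, the sketch below shows the behavior this change enforces when the (deprecated) aten op is called directly, which is how the new test exercises it. The error-message pattern is taken from the test added in this commit; the tensor shapes and the surrounding script are illustrative assumptions, not part of the commit.

import torch

# Batches of square matrices pass the check: the input must have at least
# 2 dimensions and its last two dimensions must be equal.
a = torch.randn(2, 3, 3)
g = torch.randn(2, 3, 3)
out = torch.ops.aten.matrix_exp_backward(a, g)  # runs; gradient has the same shape as a

# A non-square input is now rejected up front with a RuntimeError.
bad = torch.randn(2, 3)
try:
    torch.ops.aten.matrix_exp_backward(bad, torch.randn(2, 3))
except RuntimeError as e:
    print(e)  # expected to match "must be batches of square matrices"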
@@ -2801,6 +2801,7 @@ Tensor matrix_exp(const Tensor& a) {
 // TODO This should be deprecated in favor of linalg_matrix_exp_differential
 // in FunctionsManual.cpp
 Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
+  squareCheckInputs(self, "matrix_exp_backward");
   NoTF32Guard disable_tf32;
   return backward_analytic_function_of_a_matrix(
       self, grad,
@@ -8418,6 +8418,21 @@ scipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e} | {iters2:
         x = torch.randn(3, 3, 1, 1)
         self.assertEqual(expm(x), x.exp())
 
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    def test_matrix_exp_backward_input_validation(self, device, dtype):
+
+        scalar_tensor = torch.tensor(1.0, dtype=dtype, device=device)
+        grad_1d = torch.randn(1, dtype=dtype, device=device)
+        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
+            torch.ops.aten.matrix_exp_backward(scalar_tensor, grad_1d)
+
+        non_square = torch.randn(2, 3, dtype=dtype, device=device)
+        grad_non_square = torch.randn(2, 3, dtype=dtype, device=device)
+        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
+            torch.ops.aten.matrix_exp_backward(non_square, grad_non_square)
+
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)