[torch.distributions] Implement positive-semidefinite constraint (#71375)

Summary:
While implementing https://github.com/pytorch/pytorch/issues/70275, I thought that it will be useful if there is a `torch.distributions.constraints` to check the positive-semidefiniteness of matrix random variables.
This PR implements it with `torch.linalg.eigvalsh`, different from `torch.distributions.constraints.positive_definite` implemented with `torch.linalg.cholesky_ex`.
Currently, `torch.linalg.cholesky_ex` returns only the order of the leading minor that is not positive-definite in symmetric matrices and we can't check positive semi-definiteness by the mechanism.
cc neerajprad

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71375

Reviewed By: H-Huang

Differential Revision: D33663990

Pulled By: neerajprad

fbshipit-source-id: 02cefbb595a1da5e54a239d4f17b33c619416518
(cherry picked from commit 43eaea5bd861714f234e9efc1a7fb571631298f4)
This commit is contained in:
Kim Juhyeong
2022-01-20 09:14:19 -08:00
committed by PyTorch MergeBot
parent 640bfa7e6f
commit 89c844db9b
2 changed files with 25 additions and 2 deletions

View File

@ -9,12 +9,22 @@ from torch.testing._internal.common_cuda import TEST_CUDA
EXAMPLES = [
(constraints.symmetric, False, [[2., 0], [2., 2]]),
(constraints.positive_semidefinite, False, [[2., 0], [2., 2]]),
(constraints.positive_definite, False, [[2., 0], [2., 2]]),
(constraints.symmetric, True, [[3., -5], [-5., 3]]),
(constraints.positive_semidefinite, False, [[3., -5], [-5., 3]]),
(constraints.positive_definite, False, [[3., -5], [-5., 3]]),
(constraints.symmetric, True, [[1., 2], [2., 4]]),
(constraints.positive_semidefinite, True, [[1., 2], [2., 4]]),
(constraints.positive_definite, False, [[1., 2], [2., 4]]),
(constraints.symmetric, True, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.positive_semidefinite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.positive_definite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.symmetric, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.positive_semidefinite, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.positive_definite, False, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.symmetric, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.positive_semidefinite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.positive_definite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
]

View File

@ -16,9 +16,10 @@ The following constraints are implemented:
- ``constraints.multinomial``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
@ -51,6 +52,7 @@ __all__ = [
'multinomial',
'nonnegative_integer',
'positive',
'positive_semidefinite',
'positive_definite',
'positive_integer',
'real',
@ -488,11 +490,21 @@ class _Symmetric(_Square):
return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
class _PositiveSemidefinite(_Symmetric):
"""
Constrain to positive-semidefinite matrices.
"""
def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
return sym_check
return torch.linalg.eigvalsh(value).ge(0).all(-1)
class _PositiveDefinite(_Symmetric):
"""
Constrain to positive-definite matrices.
"""
def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
@ -591,6 +603,7 @@ lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack