mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torch.distributions] Implement positive-semidefinite constraint (#71375)
Summary: While implementing https://github.com/pytorch/pytorch/issues/70275, I thought that it will be useful if there is a `torch.distributions.constraints` to check the positive-semidefiniteness of matrix random variables. This PR implements it with `torch.linalg.eigvalsh`, different from `torch.distributions.constraints.positive_definite` implemented with `torch.linalg.cholesky_ex`. Currently, `torch.linalg.cholesky_ex` returns only the order of the leading minor that is not positive-definite in symmetric matrices and we can't check positive semi-definiteness by the mechanism. cc neerajprad Pull Request resolved: https://github.com/pytorch/pytorch/pull/71375 Reviewed By: H-Huang Differential Revision: D33663990 Pulled By: neerajprad fbshipit-source-id: 02cefbb595a1da5e54a239d4f17b33c619416518 (cherry picked from commit 43eaea5bd861714f234e9efc1a7fb571631298f4)
This commit is contained in:
committed by
PyTorch MergeBot
parent
640bfa7e6f
commit
89c844db9b
@ -9,12 +9,22 @@ from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
|
||||
EXAMPLES = [
|
||||
(constraints.symmetric, False, [[2., 0], [2., 2]]),
|
||||
(constraints.positive_semidefinite, False, [[2., 0], [2., 2]]),
|
||||
(constraints.positive_definite, False, [[2., 0], [2., 2]]),
|
||||
(constraints.symmetric, True, [[3., -5], [-5., 3]]),
|
||||
(constraints.positive_semidefinite, False, [[3., -5], [-5., 3]]),
|
||||
(constraints.positive_definite, False, [[3., -5], [-5., 3]]),
|
||||
(constraints.symmetric, True, [[1., 2], [2., 4]]),
|
||||
(constraints.positive_semidefinite, True, [[1., 2], [2., 4]]),
|
||||
(constraints.positive_definite, False, [[1., 2], [2., 4]]),
|
||||
(constraints.symmetric, True, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
|
||||
(constraints.positive_semidefinite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
|
||||
(constraints.positive_definite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
|
||||
(constraints.symmetric, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
|
||||
(constraints.positive_semidefinite, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
|
||||
(constraints.positive_definite, False, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
|
||||
(constraints.symmetric, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
|
||||
(constraints.positive_semidefinite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
|
||||
(constraints.positive_definite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
|
||||
]
|
||||
|
||||
|
@ -16,9 +16,10 @@ The following constraints are implemented:
|
||||
- ``constraints.multinomial``
|
||||
- ``constraints.nonnegative_integer``
|
||||
- ``constraints.one_hot``
|
||||
- ``constraints.positive_definite``
|
||||
- ``constraints.positive_integer``
|
||||
- ``constraints.positive``
|
||||
- ``constraints.positive_semidefinite``
|
||||
- ``constraints.positive_definite``
|
||||
- ``constraints.real_vector``
|
||||
- ``constraints.real``
|
||||
- ``constraints.simplex``
|
||||
@ -51,6 +52,7 @@ __all__ = [
|
||||
'multinomial',
|
||||
'nonnegative_integer',
|
||||
'positive',
|
||||
'positive_semidefinite',
|
||||
'positive_definite',
|
||||
'positive_integer',
|
||||
'real',
|
||||
@ -488,11 +490,21 @@ class _Symmetric(_Square):
|
||||
return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
|
||||
|
||||
|
||||
class _PositiveSemidefinite(_Symmetric):
|
||||
"""
|
||||
Constrain to positive-semidefinite matrices.
|
||||
"""
|
||||
def check(self, value):
|
||||
sym_check = super().check(value)
|
||||
if not sym_check.all():
|
||||
return sym_check
|
||||
return torch.linalg.eigvalsh(value).ge(0).all(-1)
|
||||
|
||||
|
||||
class _PositiveDefinite(_Symmetric):
|
||||
"""
|
||||
Constrain to positive-definite matrices.
|
||||
"""
|
||||
|
||||
def check(self, value):
|
||||
sym_check = super().check(value)
|
||||
if not sym_check.all():
|
||||
@ -591,6 +603,7 @@ lower_cholesky = _LowerCholesky()
|
||||
corr_cholesky = _CorrCholesky()
|
||||
square = _Square()
|
||||
symmetric = _Symmetric()
|
||||
positive_semidefinite = _PositiveSemidefinite()
|
||||
positive_definite = _PositiveDefinite()
|
||||
cat = _Cat
|
||||
stack = _Stack
|
||||
|
Reference in New Issue
Block a user