Files
pytorch/test/distributions/test_constraints.py
Arun Pa 266e278ccf UFMT formatting on test/distributions, test/error_messages, test/forward_backward_compatability (#123527)
Partially addresses #123062

UFMT formatting on
- test/distributions
- test/error_messages, test/forward_backward_compatability

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123527
Approved by: https://github.com/huydhn
2024-04-09 16:03:46 +00:00

179 lines
5.6 KiB
Python

# Owner(s): ["module: distributions"]
import pytest
import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests
# Matrix constraints exercised by EXAMPLES, in the order their expected
# verdicts appear in each row of _MATRIX_CASES.
_MATRIX_CONSTRAINTS = (
    constraints.symmetric,
    constraints.positive_semidefinite,
    constraints.positive_definite,
)

# Rows of (value, (symmetric?, positive_semidefinite?, positive_definite?)).
# A value is either a single 2x2 matrix or a batch of two 2x2 matrices.
_MATRIX_CASES = [
    # Not symmetric, so none of the three constraints is expected to hold.
    ([[2.0, 0], [2.0, 2]], (False, False, False)),
    # Symmetric but indefinite (eigenvalues 8 and -2).
    ([[3.0, -5], [-5.0, 3]], (True, False, False)),
    # Symmetric and singular (eigenvalues 0 and 5): semidefinite only.
    ([[1.0, 2], [2.0, 4]], (True, True, False)),
    # Batch of symmetric matrices that each have a negative eigenvalue.
    ([[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], (True, False, False)),
    # Batch of symmetric, singular matrices: semidefinite only.
    ([[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], (True, True, False)),
    # Batch of symmetric matrices with strictly positive eigenvalues.
    ([[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], (True, True, True)),
]

# Flat list of (constraint, expected_result, value) triples: three entries per
# row of _MATRIX_CASES, one for each constraint in _MATRIX_CONSTRAINTS.
# The three entries generated from one row share the same ``value`` list
# object, so consumers must treat the values as read-only.
EXAMPLES = [
    (constraint, expected, value)
    for value, verdicts in _MATRIX_CASES
    for constraint, expected in zip(_MATRIX_CONSTRAINTS, verdicts)
]
# Scalar bounds tried with each one-sided comparison constraint.
_SCALAR_BOUNDS = (0, 2, -2)

# (lower, upper) argument pairs tried with both interval flavours. The list
# pair is shared by the generated entries, so it must be treated as read-only.
_INTERVAL_BOUNDS = (
    ([-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    (-2, -1),
    (1, 2),
)

# Each entry is (constraint_or_factory, *args). Entries with no args name a
# ready-made constraint object; the others carry the positional arguments the
# factory is to be called with (list arguments stand for per-element bounds).
CONSTRAINTS = [
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    (constraints.greater_than, [-10.0, -2, 0, 2, 10]),
    *((constraints.greater_than, bound) for bound in _SCALAR_BOUNDS),
    *((constraints.greater_than_eq, bound) for bound in _SCALAR_BOUNDS),
    (constraints.less_than, [-10.0, -2, 0, 2, 10]),
    *((constraints.less_than, bound) for bound in _SCALAR_BOUNDS),
    (constraints.unit_interval,),
    *((constraints.interval, *bounds) for bounds in _INTERVAL_BOUNDS),
    *((constraints.half_open_interval, *bounds) for bounds in _INTERVAL_BOUNDS),
    (constraints.simplex,),
    (constraints.corr_cholesky,),
    (constraints.lower_cholesky,),
    (constraints.positive_definite,),
]
def build_constraint(constraint_fn, args, is_cuda=False):
    """Instantiate ``constraint_fn`` with ``args``, converting lists to tensors.

    Args:
        constraint_fn: Either a ready-made constraint object or a callable
            that builds one from positional arguments.
        args: Positional arguments for ``constraint_fn``. When empty,
            ``constraint_fn`` is returned as-is without being called.
        is_cuda: If true, list arguments become CUDA tensors; otherwise they
            become CPU tensors.

    Returns:
        ``constraint_fn`` itself when ``args`` is empty, otherwise the result
        of calling ``constraint_fn`` with every ``list`` argument replaced by
        a double-precision tensor and every other argument passed through
        unchanged.
    """
    if not args:
        return constraint_fn
    # Use torch.tensor with an explicit dtype/device rather than the legacy
    # typed constructors (torch.DoubleTensor / torch.cuda.DoubleTensor).
    device = "cuda" if is_cuda else "cpu"
    converted = [
        torch.tensor(arg, dtype=torch.double, device=device)
        if isinstance(arg, list)
        else arg
        for arg in args
    ]
    return constraint_fn(*converted)
@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_constraint(constraint_fn, result, value, is_cuda):
    """Check that ``constraint_fn.check`` gives the expected overall verdict.

    Args:
        constraint_fn: Constraint object whose ``check`` method is tested.
        result: Expected boolean once ``check``'s output is reduced with
            ``.all()``.
        value: Nested list converted to a double-precision tensor.
        is_cuda: Whether the tensor is created on the CUDA device.
    """
    # torch.tensor with an explicit dtype/device replaces the legacy typed
    # constructors (torch.DoubleTensor / torch.cuda.DoubleTensor).
    device = "cuda" if is_cuda else "cpu"
    tensor = torch.tensor(value, dtype=torch.double, device=device)
    # Reduce to a plain bool so the comparison is not a tensor expression.
    satisfied = bool(constraint_fn.check(tensor).all())
    assert satisfied == result, (
        f"{constraint_fn}.check(...).all() is {satisfied}, "
        f"expected {result} for value = {value}"
    )
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_biject_to(constraint_fn, args, is_cuda):
    """Exercise the transform returned by ``biject_to`` for one constraint.

    The test is skipped when ``biject_to`` raises ``NotImplementedError``.
    Otherwise it checks that the transform reports itself as bijective, that
    it maps random input into the constraint, that ``inv`` recovers the input,
    and that ``log_abs_det_jacobian`` has the shape of ``x`` with the trailing
    ``domain.event_dim`` dimensions removed.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        transform = biject_to(constraint)
    except NotImplementedError:
        pytest.skip("`biject_to` not implemented.")
    assert transform.bijective, f"biject_to({constraint}) is not bijective"

    # corr_cholesky gets a 6x6 input:
    # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
    side = 6 if constraint_fn is constraints.corr_cholesky else 5
    x = torch.randn(side, side, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    assert constraint.check(y).all(), "\n".join(
        (
            f"Failed to biject_to({constraint})",
            f"x = {x}",
            f"biject_to(...)(x) = {y}",
        )
    )

    # Round trip through the inverse must reproduce the original input.
    x_recovered = transform.inv(y)
    assert torch.allclose(x, x_recovered), f"Error in biject_to({constraint}) inverse"

    # One log-determinant per batch element, i.e. event dims are dropped.
    log_det = transform.log_abs_det_jacobian(x, y)
    batch_ndim = x.dim() - transform.domain.event_dim
    assert log_det.shape == x.shape[:batch_ndim]
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_transform_to(constraint_fn, args, is_cuda):
    """Exercise the transform returned by ``transform_to`` for one constraint.

    Checks that random input is mapped into the constraint and that applying
    ``inv`` followed by the transform again reproduces the transformed value
    (a pseudoinverse round trip).
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    transform = transform_to(constraint)

    # corr_cholesky gets a 6x6 input:
    # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
    side = 6 if constraint_fn is constraints.corr_cholesky else 5
    x = torch.randn(side, side, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    assert constraint.check(y).all(), f"Failed to transform_to({constraint})"

    # inv need not recover x itself, but mapping its output forward again
    # must land back on y.
    y_again = transform(transform.inv(y))
    assert torch.allclose(
        y, y_again
    ), f"Error in transform_to({constraint}) pseudoinverse"
# When this file is executed as a script, hand control to run_tests
# (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()