mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds some basic flake8-pytest-style rules from ruff with their autofixes. I just picked a couple of uncontroversial changes about having a consistent pytest style that we were already following. We should consider enabling some more in the future, but this is a good start. I also upgraded ruff to the latest version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110362 Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/kit1980
130 lines
5.5 KiB
Python
130 lines
5.5 KiB
Python
# Owner(s): ["module: distributions"]
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
from torch.distributions import biject_to, constraints, transform_to
|
|
from torch.testing._internal.common_cuda import TEST_CUDA
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
|
|
# Cases for ``test_constraint`` as ``(constraint_fn, result, value)``.
# Every matrix (or batch of matrices) is checked against the symmetric,
# positive-semidefinite and positive-definite constraints, in that order; the
# flag triple next to each value lists the expected ``check(...).all()`` result
# for those three constraints respectively.
EXAMPLES = [
    (constraint, expected, value)
    for value, flags in [
        ([[2., 0], [2., 2]], (False, False, False)),
        ([[3., -5], [-5., 3]], (True, False, False)),
        ([[1., 2], [2., 4]], (True, True, False)),
        ([[[1., -2], [-2., 1]], [[2., 3], [3., 2]]], (True, False, False)),
        ([[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]], (True, True, False)),
        ([[[4., 2], [2., 4]], [[3., -1], [-1., 3]]], (True, True, True)),
    ]
    for constraint, expected in zip(
        (constraints.symmetric, constraints.positive_semidefinite, constraints.positive_definite),
        flags,
    )
]
|
|
|
|
# Constraint specs shared by ``test_biject_to`` and ``test_transform_to``.
# Each entry is ``(constraint_fn, *args)``: with no args the first element is
# used as-is, otherwise it is called with ``args`` (list args become float64
# tensors in ``build_constraint``).
CONSTRAINTS = [
    # Parameter-free constraints.
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    # One-sided bounds: a tensor-valued bound first, then scalar bounds.
    (constraints.greater_than, [-10., -2, 0, 2, 10]),
    *((constraints.greater_than, bound) for bound in (0, 2, -2)),
    *((constraints.greater_than_eq, bound) for bound in (0, 2, -2)),
    (constraints.less_than, [-10., -2, 0, 2, 10]),
    *((constraints.less_than, bound) for bound in (0, 2, -2)),
    # Two-sided bounds.
    (constraints.unit_interval,),
    *(
        (interval_fn, *bounds)
        for interval_fn in (constraints.interval, constraints.half_open_interval)
        for bounds in (([-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]), (-2, -1), (1, 2))
    ),
    # Structured (simplex / matrix) supports.
    (constraints.simplex,),
    (constraints.corr_cholesky,),
    (constraints.lower_cholesky,),
    (constraints.positive_definite,),
]
|
|
|
|
|
|
def build_constraint(constraint_fn, args, is_cuda=False):
    """Instantiate a constraint from a ``(constraint_fn, *args)`` spec.

    Args:
        constraint_fn: a ready-made constraint (returned unchanged when
            ``args`` is empty) or a callable that builds one from ``args``.
        args: positional arguments for ``constraint_fn``. Python lists are
            converted to float64 tensors; any other value (e.g. a scalar
            bound) is passed through unchanged.
        is_cuda: if true, converted tensors are created on the CUDA device.

    Returns:
        ``constraint_fn`` if ``args`` is empty, otherwise
        ``constraint_fn(*converted_args)``.
    """
    if not args:
        return constraint_fn
    device = "cuda" if is_cuda else "cpu"

    def _convert(arg):
        # Use ``torch.tensor(..., dtype=, device=)`` rather than the legacy
        # ``torch.DoubleTensor`` / ``torch.cuda.DoubleTensor`` constructors,
        # which are discouraged (the CUDA one emits a deprecation warning).
        if isinstance(arg, list):
            return torch.tensor(arg, dtype=torch.double, device=device)
        return arg

    return constraint_fn(*(_convert(arg) for arg in args))
|
|
|
|
@pytest.mark.parametrize(('constraint_fn', 'result', 'value'), EXAMPLES)
@pytest.mark.parametrize('is_cuda', [
    False,
    pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, reason='CUDA not found.')),
])
def test_constraint(constraint_fn, result, value, is_cuda):
    """Check that ``constraint_fn.check`` accepts or rejects ``value``.

    ``value`` is converted to a float64 tensor (on CUDA when ``is_cuda``) and
    ``constraint_fn.check(...).all()`` must equal the expected ``result``.
    """
    # ``torch.tensor(..., dtype=, device=)`` replaces the discouraged legacy
    # ``torch.DoubleTensor`` / ``torch.cuda.DoubleTensor`` constructors.
    device = "cuda" if is_cuda else "cpu"
    tensor = torch.tensor(value, dtype=torch.double, device=device)
    assert constraint_fn.check(tensor).all() == result, (
        f"{constraint_fn}.check({value}) expected {result}"
    )
|
|
|
|
|
|
@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [
    False,
    pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, reason='CUDA not found.')),
])
def test_biject_to(constraint_fn, args, is_cuda):
    """Round-trip check for ``biject_to(constraint)``.

    Skips when ``biject_to`` raises ``NotImplementedError``. Otherwise checks
    that the transform reports itself bijective, maps random samples into the
    constraint's support, inverts back to the original samples, and returns a
    log-abs-det-Jacobian whose shape is the sample shape minus the domain's
    event dims.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        transform = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
    assert transform.bijective, f"biject_to({constraint}) is not bijective"

    # corr_cholesky: last dim must be D * (D - 1) / 2 = 6 with D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    # Sample on the CPU, then move to the GPU if requested.
    x = torch.randn(size, size, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    # The message is only built if the assertion fails.
    assert constraint.check(y).all(), '\n'.join((
        f"Failed to biject_to({constraint})",
        f"x = {x}",
        f"biject_to(...)(x) = {y}",
    ))

    x_back = transform.inv(y)
    assert torch.allclose(x, x_back), f"Error in biject_to({constraint}) inverse"

    log_det = transform.log_abs_det_jacobian(x, y)
    assert log_det.shape == x.shape[:x.dim() - transform.domain.event_dim]
|
|
|
|
|
|
@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [
    False,
    pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, reason='CUDA not found.')),
])
def test_transform_to(constraint_fn, args, is_cuda):
    """Check that ``transform_to(constraint)`` lands in the constraint's support.

    The round trip is checked in the constrained space,
    ``t(t.inv(t(x))) ≈ t(x)``, matching the "pseudoinverse" wording of the
    failure message.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    transform = transform_to(constraint)

    # corr_cholesky: last dim must be D * (D - 1) / 2 = 6 with D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    # Sample on the CPU, then move to the GPU if requested.
    x = torch.randn(size, size, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    assert constraint.check(y).all(), f"Failed to transform_to({constraint})"

    y_roundtrip = transform(transform.inv(y))
    assert torch.allclose(y, y_roundtrip), f"Error in transform_to({constraint}) pseudoinverse"
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: delegate to ``run_tests`` from
    # ``torch.testing._internal.common_utils``.
    run_tests()
|