UFMT formatting on test/distributions, test/error_messages, test/forward_backward_compatibility (#123527)

Partially addresses #123062

UFMT formatting on:
- test/distributions
- test/error_messages
- test/forward_backward_compatibility

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123527
Approved by: https://github.com/huydhn
Arun Pa
2024-04-09 16:03:42 +00:00
committed by PyTorch MergeBot
parent c96bd3de06
commit 266e278ccf
8 changed files with 3782 additions and 2094 deletions
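For readers who want to reproduce the reformatting, a minimal Python sketch follows. It assumes the `ufmt` CLI (µsort + black) is installed and is run from the repository root; in the PyTorch repo this check is normally driven through `lintrunner`'s UFMT linter, which is why the first hunk below drops these files from the `exclude_patterns` list in `.lintrunner.toml`.

import subprocess

# Directories reformatted by this commit (paths relative to the repo root).
paths = [
    "test/distributions",
    "test/error_messages",
    "test/forward_backward_compatibility",
]

# `ufmt format` rewrites the files in place: usort normalizes import order,
# black normalizes layout (double quotes, line wrapping, trailing commas).
subprocess.run(["ufmt", "format", *paths], check=True)

With the exclusions removed, a subsequent `lintrunner` run should no longer report UFMT violations for these paths.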

.lintrunner.toml

@@ -1143,11 +1143,6 @@ exclude_patterns = [
'test/distributed/test_nccl.py',
'test/distributed/test_pg_wrapper.py',
'test/distributed/test_store.py',
'test/distributions/test_constraints.py',
'test/distributions/test_distributions.py',
'test/distributions/test_transforms.py',
'test/distributions/test_utils.py',
'test/error_messages/storage.py',
'test/expect/__init__.py',
'test/export/test_db.py',
'test/export/test_export.py',
@@ -1158,8 +1153,6 @@ exclude_patterns = [
'test/export/test_upgrade.py',
'test/export/test_verifier.py',
'test/export/test_unflatten.py',
'test/forward_backward_compatibility/check_forward_backward_compatibility.py',
'test/forward_backward_compatibility/dump_all_function_schemas.py',
'test/functorch/attn_ft.py',
'test/functorch/attn_positional.py',
'test/functorch/common_utils.py',

test/distributions/test_constraints.py

@@ -9,46 +9,70 @@ from torch.testing._internal.common_utils import run_tests
EXAMPLES = [
(constraints.symmetric, False, [[2., 0], [2., 2]]),
(constraints.positive_semidefinite, False, [[2., 0], [2., 2]]),
(constraints.positive_definite, False, [[2., 0], [2., 2]]),
(constraints.symmetric, True, [[3., -5], [-5., 3]]),
(constraints.positive_semidefinite, False, [[3., -5], [-5., 3]]),
(constraints.positive_definite, False, [[3., -5], [-5., 3]]),
(constraints.symmetric, True, [[1., 2], [2., 4]]),
(constraints.positive_semidefinite, True, [[1., 2], [2., 4]]),
(constraints.positive_definite, False, [[1., 2], [2., 4]]),
(constraints.symmetric, True, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.positive_semidefinite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.positive_definite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.symmetric, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.positive_semidefinite, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.positive_definite, False, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
(constraints.symmetric, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.positive_semidefinite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.positive_definite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.symmetric, False, [[2.0, 0], [2.0, 2]]),
(constraints.positive_semidefinite, False, [[2.0, 0], [2.0, 2]]),
(constraints.positive_definite, False, [[2.0, 0], [2.0, 2]]),
(constraints.symmetric, True, [[3.0, -5], [-5.0, 3]]),
(constraints.positive_semidefinite, False, [[3.0, -5], [-5.0, 3]]),
(constraints.positive_definite, False, [[3.0, -5], [-5.0, 3]]),
(constraints.symmetric, True, [[1.0, 2], [2.0, 4]]),
(constraints.positive_semidefinite, True, [[1.0, 2], [2.0, 4]]),
(constraints.positive_definite, False, [[1.0, 2], [2.0, 4]]),
(constraints.symmetric, True, [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]]),
(
constraints.positive_semidefinite,
False,
[[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
),
(
constraints.positive_definite,
False,
[[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
),
(constraints.symmetric, True, [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]]),
(
constraints.positive_semidefinite,
True,
[[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
),
(
constraints.positive_definite,
False,
[[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
),
(constraints.symmetric, True, [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]]),
(
constraints.positive_semidefinite,
True,
[[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
),
(
constraints.positive_definite,
True,
[[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
),
]
CONSTRAINTS = [
(constraints.real,),
(constraints.real_vector,),
(constraints.positive,),
(constraints.greater_than, [-10., -2, 0, 2, 10]),
(constraints.greater_than, [-10.0, -2, 0, 2, 10]),
(constraints.greater_than, 0),
(constraints.greater_than, 2),
(constraints.greater_than, -2),
(constraints.greater_than_eq, 0),
(constraints.greater_than_eq, 2),
(constraints.greater_than_eq, -2),
(constraints.less_than, [-10., -2, 0, 2, 10]),
(constraints.less_than, [-10.0, -2, 0, 2, 10]),
(constraints.less_than, 0),
(constraints.less_than, 2),
(constraints.less_than, -2),
(constraints.unit_interval,),
(constraints.interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
(constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
(constraints.interval, -2, -1),
(constraints.interval, 1, 2),
(constraints.half_open_interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
(constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
(constraints.half_open_interval, -2, -1),
(constraints.half_open_interval, 1, 2),
(constraints.simplex,),
@@ -64,25 +88,40 @@ def build_constraint(constraint_fn, args, is_cuda=False):
t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args))
@pytest.mark.parametrize(('constraint_fn', 'result', 'value'), EXAMPLES)
@pytest.mark.parametrize('is_cuda', [False,
pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
reason='CUDA not found.'))])
@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES)
@pytest.mark.parametrize(
"is_cuda",
[
False,
pytest.param(
True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
),
],
)
def test_constraint(constraint_fn, result, value, is_cuda):
t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
assert constraint_fn.check(t(value)).all() == result
@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
reason='CUDA not found.'))])
@pytest.mark.parametrize(
("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
"is_cuda",
[
False,
pytest.param(
True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
),
],
)
def test_biject_to(constraint_fn, args, is_cuda):
constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
try:
t = biject_to(constraint)
except NotImplementedError:
pytest.skip('`biject_to` not implemented.')
pytest.skip("`biject_to` not implemented.")
assert t.bijective, f"biject_to({constraint}) is not bijective"
if constraint_fn is constraints.corr_cholesky:
# (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
@@ -92,22 +131,32 @@ def test_biject_to(constraint_fn, args, is_cuda):
if is_cuda:
x = x.cuda()
y = t(x)
assert constraint.check(y).all(), '\n'.join([
f"Failed to biject_to({constraint})",
f"x = {x}",
f"biject_to(...)(x) = {y}",
])
assert constraint.check(y).all(), "\n".join(
[
f"Failed to biject_to({constraint})",
f"x = {x}",
f"biject_to(...)(x) = {y}",
]
)
x2 = t.inv(y)
assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse"
j = t.log_abs_det_jacobian(x, y)
assert j.shape == x.shape[:x.dim() - t.domain.event_dim]
assert j.shape == x.shape[: x.dim() - t.domain.event_dim]
@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
reason='CUDA not found.'))])
@pytest.mark.parametrize(
("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
"is_cuda",
[
False,
pytest.param(
True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
),
],
)
def test_transform_to(constraint_fn, args, is_cuda):
constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
t = transform_to(constraint)

test/distributions/test_distributions.py (file diff suppressed because it is too large)

test/distributions/test_transforms.py

@@ -8,15 +8,34 @@ import pytest
import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
CorrCholeskyTransform, CumulativeDistributionTransform,
ExpTransform, IndependentTransform,
LowerCholeskyTransform, PowerTransform,
ReshapeTransform, SigmoidTransform, TanhTransform,
SoftmaxTransform, SoftplusTransform, StickBreakingTransform,
identity_transform, Transform, _InverseTransform,
PositiveDefiniteTransform)
from torch.distributions import (
constraints,
Dirichlet,
Independent,
Normal,
TransformedDistribution,
)
from torch.distributions.transforms import (
_InverseTransform,
AbsTransform,
AffineTransform,
ComposeTransform,
CorrCholeskyTransform,
CumulativeDistributionTransform,
ExpTransform,
identity_transform,
IndependentTransform,
LowerCholeskyTransform,
PositiveDefiniteTransform,
PowerTransform,
ReshapeTransform,
SigmoidTransform,
SoftmaxTransform,
SoftplusTransform,
StickBreakingTransform,
TanhTransform,
Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests
@@ -25,57 +44,53 @@ def get_transforms(cache_size):
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
PowerTransform(exponent=2,
cache_size=cache_size),
PowerTransform(exponent=-2,
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
PowerTransform(exponent=2, cache_size=cache_size),
PowerTransform(exponent=-2, cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
SoftmaxTransform(cache_size=cache_size),
SoftplusTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
CorrCholeskyTransform(cache_size=cache_size),
PositiveDefiniteTransform(cache_size=cache_size),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ComposeTransform(
[
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
]
),
ComposeTransform(
[
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
ExpTransform(cache_size=cache_size),
]
),
ComposeTransform(
[
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
]
),
ReshapeTransform((4, 5), (2, 5, 2)),
IndependentTransform(
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
1),
AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
),
CumulativeDistributionTransform(Normal(0, 1)),
]
transforms += [t.inv for t in transforms]
@@ -88,9 +103,17 @@ def reshape_transform(transform, shape):
if isinstance(transform.loc, Number):
return transform
try:
return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size)
return AffineTransform(
transform.loc.expand(shape),
transform.scale.expand(shape),
cache_size=transform._cache_size,
)
except RuntimeError:
return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size)
return AffineTransform(
transform.loc.reshape(shape),
transform.scale.reshape(shape),
cache_size=transform._cache_size,
)
if isinstance(transform, ComposeTransform):
reshaped_parts = []
for p in transform.parts:
@@ -106,8 +129,12 @@ def reshape_transform(transform, shape):
# Generate pytest ids
def transform_id(x):
assert isinstance(x, Transform)
name = f'Inv({type(x._inv).__name__})' if isinstance(x, _InverseTransform) else f'{type(x).__name__}'
return f'{name}(cache_size={x._cache_size})'
name = (
f"Inv({type(x._inv).__name__})"
if isinstance(x, _InverseTransform)
else f"{type(x).__name__}"
)
return f"{name}(cache_size={x._cache_size})"
def generate_data(transform):
@@ -119,12 +146,17 @@ def generate_data(transform):
if isinstance(transform.inv, ReshapeTransform):
return torch.randn(transform.inv.out_shape)
domain = transform.domain
while (isinstance(domain, constraints.independent) and
domain is not constraints.real_vector):
while (
isinstance(domain, constraints.independent)
and domain is not constraints.real_vector
):
domain = domain.base_constraint
codomain = transform.codomain
x = torch.empty(4, 5)
positive_definite_constraints = [constraints.lower_cholesky, constraints.positive_definite]
positive_definite_constraints = [
constraints.lower_cholesky,
constraints.positive_definite,
]
if domain in positive_definite_constraints:
x = torch.randn(6, 6)
x = x.tril(-1) + x.diag().exp().diag_embed()
@@ -159,21 +191,23 @@ def generate_data(transform):
x /= x.norm(dim=-1, keepdim=True)
x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
return x
raise ValueError(f'Unsupported domain: {domain}')
raise ValueError(f"Unsupported domain: {domain}")
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
ALL_TRANSFORMS = (
TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform, ids=transform_id):
assert transform.inv.inv is transform
@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
if x is y:
assert x == y
@@ -182,7 +216,7 @@ def test_equality(x, y):
assert identity_transform == identity_transform.inv
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
if transform._cache_size == 0:
transform = transform.with_cache(1)
@@ -191,20 +225,20 @@ def test_with_cache(transform):
try:
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
y2 = transform(x)
assert y2 is y
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize('test_cached', [True, False])
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
x = generate_data(transform).requires_grad_()
assert transform.domain.check(x).all() # verify that the input data are valid
try:
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
assert y.shape == transform.forward_shape(x.shape)
if test_cached:
x2 = transform.inv(y) # should be implemented at least by caching
@@ -212,26 +246,30 @@ def test_forward_inverse(transform, test_cached):
try:
x2 = transform.inv(y.clone()) # bypass cache
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
assert x2.shape == transform.inverse_shape(y.shape)
y2 = transform(x2)
if transform.bijective:
# verify function inverse
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
f'{transform} t.inv(t(-)) error',
f'x = {x}',
f'y = t(x) = {y}',
f'x2 = t.inv(y) = {x2}',
])
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
[
f"{transform} t.inv(t(-)) error",
f"x = {x}",
f"y = t(x) = {y}",
f"x2 = t.inv(y) = {x2}",
]
)
else:
# verify weaker function pseudo-inverse
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
f'{transform} t(t.inv(t(-))) error',
f'x = {x}',
f'y = t(x) = {y}',
f'x2 = t.inv(y) = {x2}',
f'y2 = t(x2) = {y2}',
])
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
[
f"{transform} t(t.inv(t(-))) error",
f"x = {x}",
f"y = t(x) = {y}",
f"x2 = t.inv(y) = {x2}",
f"y2 = t(x2) = {y2}",
]
)
def test_compose_transform_shapes():
@@ -255,33 +293,36 @@ base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
@pytest.mark.parametrize(('batch_shape', 'event_shape', 'dist'), [
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
])
@pytest.mark.parametrize(
("batch_shape", "event_shape", "dist"),
[
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
assert dist.batch_shape == batch_shape
assert dist.event_shape == event_shape
@@ -289,10 +330,10 @@ def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
try:
dist.log_prob(x) # this should not crash
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
x = generate_data(transform).requires_grad_()
@@ -302,14 +343,14 @@ def test_jit_fwd(transform):
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
y = generate_data(transform.inv).requires_grad_()
@@ -319,14 +360,14 @@ def test_jit_inv(transform):
try:
traced_f = torch.jit.trace(f, (y,))
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
# check on different inputs
y = generate_data(transform.inv).requires_grad_()
assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
x = generate_data(transform).requires_grad_()
@@ -337,23 +378,23 @@ def test_jit_jacobian(transform):
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
x = generate_data(transform)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
# Test shape
target_shape = x.shape[:x.dim() - transform.domain.event_dim]
target_shape = x.shape[: x.dim() - transform.domain.event_dim]
assert actual.shape == target_shape
# Expand if required
@@ -366,7 +407,9 @@ def test_jacobian(transform):
transform = reshape_transform(transform, x_.shape)
# 1. Transforms with unit jacobian
if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
if isinstance(transform, ReshapeTransform) or isinstance(
transform.inv, ReshapeTransform
):
expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
# 2. Transforms with 0 off-diagonal elements
@@ -380,8 +423,10 @@ def test_jacobian(transform):
if isinstance(transform, CorrCholeskyTransform):
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
elif isinstance(transform.inv, CorrCholeskyTransform):
jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
tril_matrix_to_vec(x_, diag=-1))
jac = jacobian(
lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
tril_matrix_to_vec(x_, diag=-1),
)
elif isinstance(transform, StickBreakingTransform):
jac = jacobian(lambda x: transform(x)[..., :-1], x_)
else:
@@ -393,20 +438,28 @@ def test_jacobian(transform):
# can be computed.
gather_idx_shape = list(jac.shape)
gather_idx_shape[-2] = 1
gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape)
gather_idxs = (
torch.arange(n)
.reshape((n,) + (1,) * (len(jac.shape) - 1))
.expand(gather_idx_shape)
)
jac = jac.gather(-2, gather_idxs).squeeze(-2)
out_ndims = jac.shape[-2]
jac = jac[..., :out_ndims] # Remove extra zero-valued dims (for inverse stick-breaking).
jac = jac[
..., :out_ndims
] # Remove extra zero-valued dims (for inverse stick-breaking).
expected = torch.slogdet(jac).logabsdet
assert torch.allclose(actual, expected, atol=1e-5)
@pytest.mark.parametrize("event_dims",
[(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
ids=str)
@pytest.mark.parametrize(
"event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
transforms = [
AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == max(event_dims)
assert transform.domain.event_dim == max(event_dims)
@@ -426,10 +479,12 @@ def test_compose_affine(event_dims):
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
transforms = [ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3))]
transforms = [
ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3)),
]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == 2
assert transform.domain.event_dim == 2
@@ -447,16 +502,19 @@ def test_compose_reshape(batch_shape):
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
num_transforms, sample_shape):
def test_transformed_distribution(
base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
shape = torch.Size([2, 3, 4, 5])
base_dist = Normal(0, 1)
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
if base_event_dim:
base_dist = Independent(base_dist, base_event_dim)
transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
ReshapeTransform((4, 5), (20,)),
ReshapeTransform((3, 20), (6, 10))]
transforms = [
AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
ReshapeTransform((4, 5), (20,)),
ReshapeTransform((3, 20), (6, 10)),
]
transforms = transforms[:num_transforms]
transform = ComposeTransform(transforms)
@@ -498,17 +556,17 @@ def test_save_load_transform():
assert torch.allclose(log_prob, other.log_prob(x))
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
try:
sign = transform.sign
except NotImplementedError:
pytest.skip('Not implemented.')
pytest.skip("Not implemented.")
x = generate_data(transform).requires_grad_()
y = transform(x).sum()
derivatives, = grad(y, [x])
assert torch.less(torch.as_tensor(0.), derivatives * sign).all()
(derivatives,) = grad(y, [x])
assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()
if __name__ == "__main__":

test/distributions/test_utils.py

@@ -6,12 +6,16 @@ import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests
@pytest.mark.parametrize('shape', [
(2, 2),
(3, 3),
(2, 4, 4),
(2, 2, 4, 4),
])
@pytest.mark.parametrize(
"shape",
[
(2, 2),
(3, 3),
(2, 4, 4),
(2, 2, 4, 4),
],
)
def test_tril_matrix_to_vec(shape):
mat = torch.randn(shape)
n = mat.shape[-1]

test/error_messages/storage.py

@@ -6,66 +6,71 @@ def check_error(desc, fn, *required_substrings):
fn()
except Exception as e:
error_message = e.args[0]
print('=' * 80)
print("=" * 80)
print(desc)
print('-' * 80)
print("-" * 80)
print(error_message)
print('')
print("")
for sub in required_substrings:
assert sub in error_message
return
raise AssertionError(f"given function ({desc}) didn't raise an error")
check_error("Wrong argument types", lambda: torch.FloatStorage(object()), "object")
check_error(
'Wrong argument types',
lambda: torch.FloatStorage(object()),
'object')
"Unknown keyword argument", lambda: torch.FloatStorage(content=1234.0), "keyword"
)
check_error('Unknown keyword argument',
lambda: torch.FloatStorage(content=1234.),
'keyword')
check_error(
"Invalid types inside a sequence",
lambda: torch.FloatStorage(["a", "b"]),
"list",
"str",
)
check_error('Invalid types inside a sequence',
lambda: torch.FloatStorage(['a', 'b']),
'list', 'str')
check_error("Invalid size type", lambda: torch.FloatStorage(1.5), "float")
check_error('Invalid size type',
lambda: torch.FloatStorage(1.5),
'float')
check_error(
"Invalid offset", lambda: torch.FloatStorage(torch.FloatStorage(2), 4), "2", "4"
)
check_error('Invalid offset',
lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
'2', '4')
check_error(
"Negative offset", lambda: torch.FloatStorage(torch.FloatStorage(2), -1), "2", "-1"
)
check_error('Negative offset',
lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
'2', '-1')
check_error(
"Invalid size",
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
"2",
"1",
"5",
)
check_error('Invalid size',
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
'2', '1', '5')
check_error(
"Negative size",
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
"2",
"1",
"-5",
)
check_error('Negative size',
lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
'2', '1', '-5')
check_error('Invalid index type',
lambda: torch.FloatStorage(10)['first item'],
'str')
check_error("Invalid index type", lambda: torch.FloatStorage(10)["first item"], "str")
def assign():
torch.FloatStorage(10)[1:-1] = '1'
check_error('Invalid value type',
assign,
'str')
torch.FloatStorage(10)[1:-1] = "1"
check_error('resize_ with invalid type',
lambda: torch.FloatStorage(10).resize_(1.5),
'float')
check_error('fill_ with invalid type',
lambda: torch.IntStorage(10).fill_('asdf'),
'str')
check_error("Invalid value type", assign, "str")
check_error(
"resize_ with invalid type", lambda: torch.FloatStorage(10).resize_(1.5), "float"
)
check_error(
"fill_ with invalid type", lambda: torch.IntStorage(10).fill_("asdf"), "str"
)
# TODO: frombuffer

test/forward_backward_compatibility/check_forward_backward_compatibility.py

@@ -108,14 +108,12 @@ ALLOW_LIST = [
("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
# BetterTransformer 1.0 internal operators
("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),
("aten::_native_decoder_only_multi_head_attention",
datetime.date(9999, 1, 1)),
("aten::_native_decoder_only_multi_head_attention", datetime.date(9999, 1, 1)),
("c10d::_allgather_base_", datetime.date(2023, 12, 30)),
("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)),
("c10d::broadcast_", datetime.date(2023, 12, 30)),
@@ -149,9 +147,12 @@ ALLOW_LIST_COMPILED = [
re.compile(item[0]),
item[1],
re.compile(item[2]) if len(item) > 2 else None,
) for item in ALLOW_LIST if item[1] >= datetime.date.today()
)
for item in ALLOW_LIST
if item[1] >= datetime.date.today()
]
def allow_listed(schema):
for item in ALLOW_LIST_COMPILED:
if item[0].search(str(schema)):
@@ -171,6 +172,7 @@ dont_parse_list = [
("__backends__.nnc", datetime.date(2099, 9, 17)),
]
def has_valid_upgraders(schema, version_map):
# we want to parse through the map to find if
# the schema has valid upgraders. Since the
@@ -199,6 +201,7 @@ def has_valid_upgraders(schema, version_map):
return False
def dont_parse(schema_line):
for item in dont_parse_list:
if item[1] < datetime.date.today():
@@ -208,6 +211,7 @@ def dont_parse(schema_line):
return True
return False
def load_schemas_to_dict():
new_schemas = torch._C._jit_get_all_schemas()
new_schemas += torch._C._jit_get_custom_class_schemas()
@@ -216,6 +220,7 @@ def load_schemas_to_dict():
new_schema_dict[s.name].append(s)
return new_schema_dict
def process_version_map(version_map):
# version map maps full schema name to
# list of upgraders. Since we only have
@@ -225,12 +230,13 @@ def process_version_map(version_map):
# Dict[schema_name, Dict[overload, List[schema]]]
output = defaultdict(dict)
for (key, entries) in version_map.items():
for key, entries in version_map.items():
operator_name = key.split(".")[0]
schema_entries = [parse_schema(entry.old_schema) for entry in entries]
output[operator_name][key] = schema_entries
return output
def check_bc(existing_schemas):
new_schema_dict = load_schemas_to_dict()
version_map = process_version_map(torch._C._get_operator_version_map())
@@ -272,6 +278,7 @@ def check_bc(existing_schemas):
)
return is_bc
def check_fc(existing_schemas):
new_schema_dict = load_schemas_to_dict()
is_fc = True
@@ -285,7 +292,9 @@ def check_fc(existing_schemas):
found = False
possible_failure_reasons = []
for matching_new_schema in matching_new_schemas:
is_compatible, reason = matching_new_schema.check_forward_compatible_with(existing_schema)
is_compatible, reason = matching_new_schema.check_forward_compatible_with(
existing_schema
)
if is_compatible:
found = True
break

test/forward_backward_compatibility/dump_all_function_schemas.py

@@ -1,24 +1,25 @@
import argparse
import torch
def dump(filename):
schemas = torch._C._jit_get_all_schemas()
schemas += torch._C._jit_get_custom_class_schemas()
with open(filename, 'w') as f:
with open(filename, "w") as f:
for s in schemas:
f.write(str(s))
f.write('\n')
f.write("\n")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process some integers.')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument(
'-f',
'--filename',
help='filename to dump the schemas',
"-f",
"--filename",
help="filename to dump the schemas",
type=str,
default='schemas.txt')
default="schemas.txt",
)
args = parser.parse_args()
dump(args.filename)