Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
UFMT formatting on test/distributions, test/error_messages, test/forward_backward_compatibility (#123527)
Partially addresses #123062. UFMT formatting on: test/distributions, test/error_messages, test/forward_backward_compatibility. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123527 Approved by: https://github.com/huydhn
Committed by PyTorch MergeBot · parent c96bd3de06 · commit 266e278ccf
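For context, UFMT is PyTorch's lint rule that combines usort (import sorting) and Black (code formatting), and this commit simply applies it to the files listed below after removing them from the linter's exclude list. The following is a minimal, hedged sketch of reproducing that kind of reformat locally by calling usort and Black directly; it assumes both packages are installed and stands in for the canonical workflow, which runs the UFMT linter through lintrunner.

# Hedged sketch: approximate the UFMT reformat applied in this PR by running
# usort and black over the touched files. This is an illustration only and
# assumes the usort and black packages are installed; the real PyTorch
# workflow goes through lintrunner's UFMT linter instead.
import subprocess

FILES = [
    "test/distributions/test_constraints.py",
    "test/distributions/test_distributions.py",
    "test/distributions/test_transforms.py",
    "test/distributions/test_utils.py",
    "test/error_messages/storage.py",
    "test/forward_backward_compatibility/check_forward_backward_compatibility.py",
    "test/forward_backward_compatibility/dump_all_function_schemas.py",
]

subprocess.run(["usort", "format", *FILES], check=True)  # sort imports in place
subprocess.run(["black", *FILES], check=True)            # reformat code in place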
@@ -1143,11 +1143,6 @@ exclude_patterns = [
    'test/distributed/test_nccl.py',
    'test/distributed/test_pg_wrapper.py',
    'test/distributed/test_store.py',
    'test/distributions/test_constraints.py',
    'test/distributions/test_distributions.py',
    'test/distributions/test_transforms.py',
    'test/distributions/test_utils.py',
    'test/error_messages/storage.py',
    'test/expect/__init__.py',
    'test/export/test_db.py',
    'test/export/test_export.py',
@@ -1158,8 +1153,6 @@ exclude_patterns = [
    'test/export/test_upgrade.py',
    'test/export/test_verifier.py',
    'test/export/test_unflatten.py',
    'test/forward_backward_compatibility/check_forward_backward_compatibility.py',
    'test/forward_backward_compatibility/dump_all_function_schemas.py',
    'test/functorch/attn_ft.py',
    'test/functorch/attn_positional.py',
    'test/functorch/common_utils.py',
@@ -9,46 +9,70 @@ from torch.testing._internal.common_utils import run_tests


EXAMPLES = [
    (constraints.symmetric, False, [[2., 0], [2., 2]]),
    (constraints.positive_semidefinite, False, [[2., 0], [2., 2]]),
    (constraints.positive_definite, False, [[2., 0], [2., 2]]),
    (constraints.symmetric, True, [[3., -5], [-5., 3]]),
    (constraints.positive_semidefinite, False, [[3., -5], [-5., 3]]),
    (constraints.positive_definite, False, [[3., -5], [-5., 3]]),
    (constraints.symmetric, True, [[1., 2], [2., 4]]),
    (constraints.positive_semidefinite, True, [[1., 2], [2., 4]]),
    (constraints.positive_definite, False, [[1., 2], [2., 4]]),
    (constraints.symmetric, True, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
    (constraints.positive_semidefinite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
    (constraints.positive_definite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
    (constraints.symmetric, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
    (constraints.positive_semidefinite, True, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
    (constraints.positive_definite, False, [[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]]),
    (constraints.symmetric, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
    (constraints.positive_semidefinite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
    (constraints.positive_definite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
    (constraints.symmetric, False, [[2.0, 0], [2.0, 2]]),
    (constraints.positive_semidefinite, False, [[2.0, 0], [2.0, 2]]),
    (constraints.positive_definite, False, [[2.0, 0], [2.0, 2]]),
    (constraints.symmetric, True, [[3.0, -5], [-5.0, 3]]),
    (constraints.positive_semidefinite, False, [[3.0, -5], [-5.0, 3]]),
    (constraints.positive_definite, False, [[3.0, -5], [-5.0, 3]]),
    (constraints.symmetric, True, [[1.0, 2], [2.0, 4]]),
    (constraints.positive_semidefinite, True, [[1.0, 2], [2.0, 4]]),
    (constraints.positive_definite, False, [[1.0, 2], [2.0, 4]]),
    (constraints.symmetric, True, [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]]),
    (
        constraints.positive_semidefinite,
        False,
        [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
    ),
    (
        constraints.positive_definite,
        False,
        [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
    ),
    (constraints.symmetric, True, [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]]),
    (
        constraints.positive_semidefinite,
        True,
        [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
    ),
    (
        constraints.positive_definite,
        False,
        [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
    ),
    (constraints.symmetric, True, [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]]),
    (
        constraints.positive_semidefinite,
        True,
        [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
    ),
    (
        constraints.positive_definite,
        True,
        [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
    ),
]

CONSTRAINTS = [
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    (constraints.greater_than, [-10., -2, 0, 2, 10]),
    (constraints.greater_than, [-10.0, -2, 0, 2, 10]),
    (constraints.greater_than, 0),
    (constraints.greater_than, 2),
    (constraints.greater_than, -2),
    (constraints.greater_than_eq, 0),
    (constraints.greater_than_eq, 2),
    (constraints.greater_than_eq, -2),
    (constraints.less_than, [-10., -2, 0, 2, 10]),
    (constraints.less_than, [-10.0, -2, 0, 2, 10]),
    (constraints.less_than, 0),
    (constraints.less_than, 2),
    (constraints.less_than, -2),
    (constraints.unit_interval,),
    (constraints.interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    (constraints.interval, -2, -1),
    (constraints.interval, 1, 2),
    (constraints.half_open_interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    (constraints.half_open_interval, -2, -1),
    (constraints.half_open_interval, 1, 2),
    (constraints.simplex,),
@@ -64,25 +88,40 @@ def build_constraint(constraint_fn, args, is_cuda=False):
    t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
    return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args))

@pytest.mark.parametrize(('constraint_fn', 'result', 'value'), EXAMPLES)
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])

@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_constraint(constraint_fn, result, value, is_cuda):
    t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
    assert constraint_fn.check(t(value)).all() == result


@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_biject_to(constraint_fn, args, is_cuda):
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        t = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
        pytest.skip("`biject_to` not implemented.")
    assert t.bijective, f"biject_to({constraint}) is not bijective"
    if constraint_fn is constraints.corr_cholesky:
        # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
@@ -92,22 +131,32 @@ def test_biject_to(constraint_fn, args, is_cuda):
    if is_cuda:
        x = x.cuda()
    y = t(x)
    assert constraint.check(y).all(), '\n'.join([
        f"Failed to biject_to({constraint})",
        f"x = {x}",
        f"biject_to(...)(x) = {y}",
    ])
    assert constraint.check(y).all(), "\n".join(
        [
            f"Failed to biject_to({constraint})",
            f"x = {x}",
            f"biject_to(...)(x) = {y}",
        ]
    )
    x2 = t.inv(y)
    assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse"

    j = t.log_abs_det_jacobian(x, y)
    assert j.shape == x.shape[:x.dim() - t.domain.event_dim]
    assert j.shape == x.shape[: x.dim() - t.domain.event_dim]


@pytest.mark.parametrize(('constraint_fn', 'args'), [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_transform_to(constraint_fn, args, is_cuda):
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    t = transform_to(constraint)
File diff suppressed because it is too large
@@ -8,15 +8,34 @@ import pytest
import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
                                             CorrCholeskyTransform, CumulativeDistributionTransform,
                                             ExpTransform, IndependentTransform,
                                             LowerCholeskyTransform, PowerTransform,
                                             ReshapeTransform, SigmoidTransform, TanhTransform,
                                             SoftmaxTransform, SoftplusTransform, StickBreakingTransform,
                                             identity_transform, Transform, _InverseTransform,
                                             PositiveDefiniteTransform)
from torch.distributions import (
    constraints,
    Dirichlet,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import (
    _InverseTransform,
    AbsTransform,
    AffineTransform,
    ComposeTransform,
    CorrCholeskyTransform,
    CumulativeDistributionTransform,
    ExpTransform,
    identity_transform,
    IndependentTransform,
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
    PowerTransform,
    ReshapeTransform,
    SigmoidTransform,
    SoftmaxTransform,
    SoftplusTransform,
    StickBreakingTransform,
    TanhTransform,
    Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests

@@ -25,57 +44,53 @@ def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2,
                       cache_size=cache_size),
        PowerTransform(exponent=-2,
                       cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(),
                       cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(),
                       cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=-2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5),
                        torch.randn(5),
                        cache_size=cache_size),
        AffineTransform(torch.randn(4, 5),
                        torch.randn(4, 5),
                        cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        PositiveDefiniteTransform(cache_size=cache_size),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
            AffineTransform(1, -2, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
        ]),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                ExpTransform(cache_size=cache_size),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(0, 1, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                AffineTransform(1, -2, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(
            AffineTransform(torch.randn(5),
                            torch.randn(5),
                            cache_size=cache_size),
            1),
            AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
        ),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    transforms += [t.inv for t in transforms]
@@ -88,9 +103,17 @@ def reshape_transform(transform, shape):
    if isinstance(transform.loc, Number):
        return transform
    try:
        return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size)
        return AffineTransform(
            transform.loc.expand(shape),
            transform.scale.expand(shape),
            cache_size=transform._cache_size,
        )
    except RuntimeError:
        return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size)
        return AffineTransform(
            transform.loc.reshape(shape),
            transform.scale.reshape(shape),
            cache_size=transform._cache_size,
        )
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
@@ -106,8 +129,12 @@ def reshape_transform(transform, shape):
# Generate pytest ids
def transform_id(x):
    assert isinstance(x, Transform)
    name = f'Inv({type(x._inv).__name__})' if isinstance(x, _InverseTransform) else f'{type(x).__name__}'
    return f'{name}(cache_size={x._cache_size})'
    name = (
        f"Inv({type(x._inv).__name__})"
        if isinstance(x, _InverseTransform)
        else f"{type(x).__name__}"
    )
    return f"{name}(cache_size={x._cache_size})"


def generate_data(transform):
@@ -119,12 +146,17 @@ def generate_data(transform):
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)
    domain = transform.domain
    while (isinstance(domain, constraints.independent) and
           domain is not constraints.real_vector):
    while (
        isinstance(domain, constraints.independent)
        and domain is not constraints.real_vector
    ):
        domain = domain.base_constraint
    codomain = transform.codomain
    x = torch.empty(4, 5)
    positive_definite_constraints = [constraints.lower_cholesky, constraints.positive_definite]
    positive_definite_constraints = [
        constraints.lower_cholesky,
        constraints.positive_definite,
    ]
    if domain in positive_definite_constraints:
        x = torch.randn(6, 6)
        x = x.tril(-1) + x.diag().exp().diag_embed()
@@ -159,21 +191,23 @@ def generate_data(transform):
        x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
        return x
    raise ValueError(f'Unsupported domain: {domain}')
    raise ValueError(f"Unsupported domain: {domain}")


TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
ALL_TRANSFORMS = (
    TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform, ids=transform_id):
    assert transform.inv.inv is transform


@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    if x is y:
        assert x == y
@@ -182,7 +216,7 @@ def test_equality(x, y):
    assert identity_transform == identity_transform.inv


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    if transform._cache_size == 0:
        transform = transform.with_cache(1)
@@ -191,20 +225,20 @@ def test_with_cache(transform):
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")
    y2 = transform(x)
    assert y2 is y


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize('test_cached', [True, False])
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
    x = generate_data(transform).requires_grad_()
    assert transform.domain.check(x).all()  # verify that the input data are valid
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")
    assert y.shape == transform.forward_shape(x.shape)
    if test_cached:
        x2 = transform.inv(y)  # should be implemented at least by caching
@@ -212,26 +246,30 @@ def test_forward_inverse(transform, test_cached):
        try:
            x2 = transform.inv(y.clone())  # bypass cache
        except NotImplementedError:
            pytest.skip('Not implemented.')
            pytest.skip("Not implemented.")
    assert x2.shape == transform.inverse_shape(y.shape)
    y2 = transform(x2)
    if transform.bijective:
        # verify function inverse
        assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
            f'{transform} t.inv(t(-)) error',
            f'x = {x}',
            f'y = t(x) = {y}',
            f'x2 = t.inv(y) = {x2}',
        ])
        assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t.inv(t(-)) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
            ]
        )
    else:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
            f'{transform} t(t.inv(t(-))) error',
            f'x = {x}',
            f'y = t(x) = {y}',
            f'x2 = t.inv(y) = {x2}',
            f'y2 = t(x2) = {y2}',
        ])
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t(t.inv(t(-))) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
                f"y2 = t(x2) = {y2}",
            ]
        )


def test_compose_transform_shapes():
@@ -255,33 +293,36 @@ base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))


@pytest.mark.parametrize(('batch_shape', 'event_shape', 'dist'), [
    ((4, 4), (), base_dist0),
    ((4,), (4,), base_dist1),
    ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
    ((3, 4, 4), (), base_dist2),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
])
@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
@@ -289,10 +330,10 @@ def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    x = generate_data(transform).requires_grad_()

@@ -302,14 +343,14 @@ def test_jit_fwd(transform):
    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    y = generate_data(transform.inv).requires_grad_()

@@ -319,14 +360,14 @@ def test_jit_inv(transform):
    try:
        traced_f = torch.jit.trace(f, (y,))
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")

    # check on different inputs
    y = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    x = generate_data(transform).requires_grad_()

@@ -337,23 +378,23 @@ def test_jit_jacobian(transform):
    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")
    # Test shape
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    target_shape = x.shape[: x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
@@ -366,7 +407,9 @@ def test_jacobian(transform):
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
    if isinstance(transform, ReshapeTransform) or isinstance(
        transform.inv, ReshapeTransform
    ):
        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
    # 2. Transforms with 0 off-diagonal elements
@@ -380,8 +423,10 @@ def test_jacobian(transform):
    if isinstance(transform, CorrCholeskyTransform):
        jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
    elif isinstance(transform.inv, CorrCholeskyTransform):
        jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                       tril_matrix_to_vec(x_, diag=-1))
        jac = jacobian(
            lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
            tril_matrix_to_vec(x_, diag=-1),
        )
    elif isinstance(transform, StickBreakingTransform):
        jac = jacobian(lambda x: transform(x)[..., :-1], x_)
    else:
@@ -393,20 +438,28 @@ def test_jacobian(transform):
    # can be computed.
    gather_idx_shape = list(jac.shape)
    gather_idx_shape[-2] = 1
    gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape)
    gather_idxs = (
        torch.arange(n)
        .reshape((n,) + (1,) * (len(jac.shape) - 1))
        .expand(gather_idx_shape)
    )
    jac = jac.gather(-2, gather_idxs).squeeze(-2)
    out_ndims = jac.shape[-2]
    jac = jac[..., :out_ndims]  # Remove extra zero-valued dims (for inverse stick-breaking).
    jac = jac[
        ..., :out_ndims
    ]  # Remove extra zero-valued dims (for inverse stick-breaking).
    expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize("event_dims",
                         [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
                         ids=str)
@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)
@@ -426,10 +479,12 @@ def test_compose_affine(event_dims):

@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    transforms = [ReshapeTransform((), ()),
                  ReshapeTransform((2,), (1, 2)),
                  ReshapeTransform((3, 1, 2), (6,)),
                  ReshapeTransform((6,), (2, 3))]
    transforms = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
@@ -447,16 +502,19 @@ def test_compose_reshape(batch_shape):
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
def test_transformed_distribution(
    base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
                  ReshapeTransform((4, 5), (20,)),
                  ReshapeTransform((3, 20), (6, 10))]
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

@@ -498,17 +556,17 @@ def test_save_load_transform():
    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
    try:
        sign = transform.sign
    except NotImplementedError:
        pytest.skip('Not implemented.')
        pytest.skip("Not implemented.")

    x = generate_data(transform).requires_grad_()
    y = transform(x).sum()
    derivatives, = grad(y, [x])
    assert torch.less(torch.as_tensor(0.), derivatives * sign).all()
    (derivatives,) = grad(y, [x])
    assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()


if __name__ == "__main__":
@@ -6,12 +6,16 @@ import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests

@pytest.mark.parametrize('shape', [
    (2, 2),
    (3, 3),
    (2, 4, 4),
    (2, 2, 4, 4),
])

@pytest.mark.parametrize(
    "shape",
    [
        (2, 2),
        (3, 3),
        (2, 4, 4),
        (2, 2, 4, 4),
    ],
)
def test_tril_matrix_to_vec(shape):
    mat = torch.randn(shape)
    n = mat.shape[-1]
@@ -6,66 +6,71 @@ def check_error(desc, fn, *required_substrings):
        fn()
    except Exception as e:
        error_message = e.args[0]
        print('=' * 80)
        print("=" * 80)
        print(desc)
        print('-' * 80)
        print("-" * 80)
        print(error_message)
        print('')
        print("")
        for sub in required_substrings:
            assert sub in error_message
        return
    raise AssertionError(f"given function ({desc}) didn't raise an error")


check_error("Wrong argument types", lambda: torch.FloatStorage(object()), "object")

check_error(
    'Wrong argument types',
    lambda: torch.FloatStorage(object()),
    'object')
    "Unknown keyword argument", lambda: torch.FloatStorage(content=1234.0), "keyword"
)

check_error('Unknown keyword argument',
            lambda: torch.FloatStorage(content=1234.),
            'keyword')
check_error(
    "Invalid types inside a sequence",
    lambda: torch.FloatStorage(["a", "b"]),
    "list",
    "str",
)

check_error('Invalid types inside a sequence',
            lambda: torch.FloatStorage(['a', 'b']),
            'list', 'str')
check_error("Invalid size type", lambda: torch.FloatStorage(1.5), "float")

check_error('Invalid size type',
            lambda: torch.FloatStorage(1.5),
            'float')
check_error(
    "Invalid offset", lambda: torch.FloatStorage(torch.FloatStorage(2), 4), "2", "4"
)

check_error('Invalid offset',
            lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
            '2', '4')
check_error(
    "Negative offset", lambda: torch.FloatStorage(torch.FloatStorage(2), -1), "2", "-1"
)

check_error('Negative offset',
            lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
            '2', '-1')
check_error(
    "Invalid size",
    lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
    "2",
    "1",
    "5",
)

check_error('Invalid size',
            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
            '2', '1', '5')
check_error(
    "Negative size",
    lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
    "2",
    "1",
    "-5",
)

check_error('Negative size',
            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
            '2', '1', '-5')

check_error('Invalid index type',
            lambda: torch.FloatStorage(10)['first item'],
            'str')
check_error("Invalid index type", lambda: torch.FloatStorage(10)["first item"], "str")


def assign():
    torch.FloatStorage(10)[1:-1] = '1'
check_error('Invalid value type',
            assign,
            'str')
    torch.FloatStorage(10)[1:-1] = "1"

check_error('resize_ with invalid type',
            lambda: torch.FloatStorage(10).resize_(1.5),
            'float')

check_error('fill_ with invalid type',
            lambda: torch.IntStorage(10).fill_('asdf'),
            'str')
check_error("Invalid value type", assign, "str")

check_error(
    "resize_ with invalid type", lambda: torch.FloatStorage(10).resize_(1.5), "float"
)

check_error(
    "fill_ with invalid type", lambda: torch.IntStorage(10).fill_("asdf"), "str"
)

# TODO: frombuffer
@@ -108,14 +108,12 @@ ALLOW_LIST = [
    ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
    # TODO: FIXME: prims shouldn't be checked
    ("prims::.*", datetime.date(9999, 1, 1)),

    ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
    ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
    ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
    # BetterTransformer 1.0 internal operators
    ("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),
    ("aten::_native_decoder_only_multi_head_attention",
     datetime.date(9999, 1, 1)),
    ("aten::_native_decoder_only_multi_head_attention", datetime.date(9999, 1, 1)),
    ("c10d::_allgather_base_", datetime.date(2023, 12, 30)),
    ("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)),
    ("c10d::broadcast_", datetime.date(2023, 12, 30)),
@@ -149,9 +147,12 @@ ALLOW_LIST_COMPILED = [
        re.compile(item[0]),
        item[1],
        re.compile(item[2]) if len(item) > 2 else None,
    ) for item in ALLOW_LIST if item[1] >= datetime.date.today()
    )
    for item in ALLOW_LIST
    if item[1] >= datetime.date.today()
]


def allow_listed(schema):
    for item in ALLOW_LIST_COMPILED:
        if item[0].search(str(schema)):
@@ -171,6 +172,7 @@ dont_parse_list = [
    ("__backends__.nnc", datetime.date(2099, 9, 17)),
]


def has_valid_upgraders(schema, version_map):
    # we want to parse through the map to find if
    # the schema has valid upgraders. Since the
@@ -199,6 +201,7 @@ def has_valid_upgraders(schema, version_map):

    return False


def dont_parse(schema_line):
    for item in dont_parse_list:
        if item[1] < datetime.date.today():
@@ -208,6 +211,7 @@ def dont_parse(schema_line):
            return True
    return False


def load_schemas_to_dict():
    new_schemas = torch._C._jit_get_all_schemas()
    new_schemas += torch._C._jit_get_custom_class_schemas()
@@ -216,6 +220,7 @@ def load_schemas_to_dict():
        new_schema_dict[s.name].append(s)
    return new_schema_dict


def process_version_map(version_map):
    # version map maps full schema name to
    # list of upgraders. Since we only have
@@ -225,12 +230,13 @@ def process_version_map(version_map):
    # Dict[schema_name, Dict[overload, List[schema]]]

    output = defaultdict(dict)
    for (key, entries) in version_map.items():
    for key, entries in version_map.items():
        operator_name = key.split(".")[0]
        schema_entries = [parse_schema(entry.old_schema) for entry in entries]
        output[operator_name][key] = schema_entries
    return output


def check_bc(existing_schemas):
    new_schema_dict = load_schemas_to_dict()
    version_map = process_version_map(torch._C._get_operator_version_map())
@@ -272,6 +278,7 @@ def check_bc(existing_schemas):
    )
    return is_bc


def check_fc(existing_schemas):
    new_schema_dict = load_schemas_to_dict()
    is_fc = True
@@ -285,7 +292,9 @@ def check_fc(existing_schemas):
    found = False
    possible_failure_reasons = []
    for matching_new_schema in matching_new_schemas:
        is_compatible, reason = matching_new_schema.check_forward_compatible_with(existing_schema)
        is_compatible, reason = matching_new_schema.check_forward_compatible_with(
            existing_schema
        )
        if is_compatible:
            found = True
            break
@@ -1,24 +1,25 @@

import argparse

import torch


def dump(filename):
    schemas = torch._C._jit_get_all_schemas()
    schemas += torch._C._jit_get_custom_class_schemas()
    with open(filename, 'w') as f:
    with open(filename, "w") as f:
        for s in schemas:
            f.write(str(s))
            f.write('\n')
            f.write("\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process some integers.')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some integers.")
    parser.add_argument(
        '-f',
        '--filename',
        help='filename to dump the schemas',
        "-f",
        "--filename",
        help="filename to dump the schemas",
        type=str,
        default='schemas.txt')
        default="schemas.txt",
    )
    args = parser.parse_args()
    dump(args.filename)