Fix failing test `test_save_load_transform` in [test_transforms.py](https://github.com/pytorch/pytorch/blob/main/test/distributions/test_transforms.py)

Running `pytest test_transforms.py -k test_save_load_transform` failed with:

```
...
  File "/workspace/pytorch/test/distributions/test_transforms.py", line 555, in test_save_load_transform
    other = torch.load(stream)
            ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1444, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
    (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
    WeightsUnpickler error: Unsupported global: GLOBAL torch.distributions.transformed_distribution.TransformedDistribution was not an allowed global by default. Please use `torch.serialization.add_safe_globals([TransformedDistribution])` or the `torch.serialization.safe_globals([TransformedDistribution])` context manager to allowlist this global if you trust this class/function.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140494
Approved by: https://github.com/mikaylagawarecki
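For reference, option (2) is what the updated test below does: the classes reconstructed during unpickling are allowlisted before `torch.load` is called. A minimal standalone sketch of that pattern (mirroring the test; the exact classes to allowlist depend on what was saved, and on older PyTorch releases where `weights_only` defaults to `False` the context manager is not needed):

```python
import io

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Pickle a TransformedDistribution the same way the test does.
dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
buffer = io.BytesIO()
torch.save(dist, buffer)
buffer.seek(0)

# With weights-only loading, non-tensor globals must be explicitly allowlisted,
# either process-wide via torch.serialization.add_safe_globals([...]) or, as
# here, scoped to a single load with the safe_globals context manager.
with torch.serialization.safe_globals(
    [TransformedDistribution, AffineTransform, Normal]
):
    restored = torch.load(buffer)
```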
# Owner(s): ["module: distributions"]

import io
from numbers import Number

import pytest

import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    constraints,
    Dirichlet,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import (
    _InverseTransform,
    AbsTransform,
    AffineTransform,
    ComposeTransform,
    CorrCholeskyTransform,
    CumulativeDistributionTransform,
    ExpTransform,
    identity_transform,
    IndependentTransform,
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
    PowerTransform,
    ReshapeTransform,
    SigmoidTransform,
    SoftmaxTransform,
    SoftplusTransform,
    StickBreakingTransform,
    TanhTransform,
    Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests


def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=-2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        PositiveDefiniteTransform(cache_size=cache_size),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                ExpTransform(cache_size=cache_size),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(0, 1, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                AffineTransform(1, -2, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(
            AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
        ),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    transforms += [t.inv for t in transforms]
    return transforms


def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(
                transform.loc.expand(shape),
                transform.scale.expand(shape),
                cache_size=transform._cache_size,
            )
        except RuntimeError:
            return AffineTransform(
                transform.loc.reshape(shape),
                transform.scale.reshape(shape),
                cache_size=transform._cache_size,
            )
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform


# Generate pytest ids
def transform_id(x):
    assert isinstance(x, Transform)
    name = (
        f"Inv({type(x._inv).__name__})"
        if isinstance(x, _InverseTransform)
        else f"{type(x).__name__}"
    )
    return f"{name}(cache_size={x._cache_size})"


def generate_data(transform):
    torch.manual_seed(1)
    while isinstance(transform, IndependentTransform):
        transform = transform.base_transform
    if isinstance(transform, ReshapeTransform):
        return torch.randn(transform.in_shape)
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)
    domain = transform.domain
    while (
        isinstance(domain, constraints.independent)
        and domain is not constraints.real_vector
    ):
        domain = domain.base_constraint
    codomain = transform.codomain
    x = torch.empty(4, 5)
    positive_definite_constraints = [
        constraints.lower_cholesky,
        constraints.positive_definite,
    ]
    if domain in positive_definite_constraints:
        x = torch.randn(6, 6)
        x = x.tril(-1) + x.diag().exp().diag_embed()
        if domain is constraints.positive_definite:
            return x @ x.T
        return x
    elif codomain in positive_definite_constraints:
        return torch.randn(6, 6)
    elif domain is constraints.real:
        return x.normal_()
    elif domain is constraints.real_vector:
        # For corr_cholesky the last dim in the vector
        # must be of size dim * (dim - 1) // 2
        x = torch.empty(3, 6)
        x = x.normal_()
        return x
    elif domain is constraints.positive:
        return x.normal_().exp()
    elif domain is constraints.unit_interval:
        return x.uniform_()
    elif isinstance(domain, constraints.interval):
        x = x.uniform_()
        x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
        return x
    elif domain is constraints.simplex:
        x = x.normal_().exp()
        x /= x.sum(-1, True)
        return x
    elif domain is constraints.corr_cholesky:
        x = torch.empty(4, 5, 5)
        x = x.normal_().tril()
        x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
        return x
    raise ValueError(f"Unsupported domain: {domain}")


TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = (
    TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    assert transform.inv.inv is transform


@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
|
|
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
|
|
def test_equality(x, y):
|
|
if x is y:
|
|
assert x == y
|
|
else:
|
|
assert x != y
|
|
assert identity_transform == identity_transform.inv
|
|
|
|
|
|
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
|
|
def test_with_cache(transform):
|
|
if transform._cache_size == 0:
|
|
transform = transform.with_cache(1)
|
|
assert transform._cache_size == 1
|
|
x = generate_data(transform).requires_grad_()
|
|
try:
|
|
y = transform(x)
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
y2 = transform(x)
|
|
assert y2 is y
|
|
|
|
|
|
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
|
|
@pytest.mark.parametrize("test_cached", [True, False])
|
|
def test_forward_inverse(transform, test_cached):
|
|
x = generate_data(transform).requires_grad_()
|
|
assert transform.domain.check(x).all() # verify that the input data are valid
|
|
try:
|
|
y = transform(x)
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
assert y.shape == transform.forward_shape(x.shape)
|
|
if test_cached:
|
|
x2 = transform.inv(y) # should be implemented at least by caching
|
|
else:
|
|
try:
|
|
x2 = transform.inv(y.clone()) # bypass cache
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
assert x2.shape == transform.inverse_shape(y.shape)
|
|
y2 = transform(x2)
|
|
if transform.bijective:
|
|
# verify function inverse
|
|
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
|
|
[
|
|
f"{transform} t.inv(t(-)) error",
|
|
f"x = {x}",
|
|
f"y = t(x) = {y}",
|
|
f"x2 = t.inv(y) = {x2}",
|
|
]
|
|
)
|
|
else:
|
|
# verify weaker function pseudo-inverse
|
|
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
|
|
[
|
|
f"{transform} t(t.inv(t(-))) error",
|
|
f"x = {x}",
|
|
f"y = t(x) = {y}",
|
|
f"x2 = t.inv(y) = {x2}",
|
|
f"y2 = t(x2) = {y2}",
|
|
]
|
|
)
|
|
|
|
|
|
def test_compose_transform_shapes():
    transform0 = ExpTransform()
    transform1 = SoftmaxTransform()
    transform2 = LowerCholeskyTransform()

    assert transform0.event_dim == 0
    assert transform1.event_dim == 1
    assert transform2.event_dim == 2
    assert ComposeTransform([transform0, transform1]).event_dim == 1
    assert ComposeTransform([transform0, transform2]).event_dim == 2
    assert ComposeTransform([transform1, transform2]).event_dim == 2


transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))


@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    x = dist.rsample()
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip("Not implemented.")


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
|
|
def test_jit_fwd(transform):
|
|
x = generate_data(transform).requires_grad_()
|
|
|
|
def f(x):
|
|
return transform(x)
|
|
|
|
try:
|
|
traced_f = torch.jit.trace(f, (x,))
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
|
|
# check on different inputs
|
|
x = generate_data(transform).requires_grad_()
|
|
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
|
|
|
|
|
|
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
|
|
def test_jit_inv(transform):
|
|
y = generate_data(transform.inv).requires_grad_()
|
|
|
|
def f(y):
|
|
return transform.inv(y)
|
|
|
|
try:
|
|
traced_f = torch.jit.trace(f, (y,))
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
|
|
# check on different inputs
|
|
y = generate_data(transform.inv).requires_grad_()
|
|
assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)
|
|
|
|
|
|
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
|
|
def test_jit_jacobian(transform):
|
|
x = generate_data(transform).requires_grad_()
|
|
|
|
def f(x):
|
|
y = transform(x)
|
|
return transform.log_abs_det_jacobian(x, y)
|
|
|
|
try:
|
|
traced_f = torch.jit.trace(f, (x,))
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
|
|
# check on different inputs
|
|
x = generate_data(transform).requires_grad_()
|
|
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
|
|
|
|
|
|
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
|
|
def test_jacobian(transform):
|
|
x = generate_data(transform)
|
|
try:
|
|
y = transform(x)
|
|
actual = transform.log_abs_det_jacobian(x, y)
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
# Test shape
|
|
target_shape = x.shape[: x.dim() - transform.domain.event_dim]
|
|
assert actual.shape == target_shape
|
|
|
|
# Expand if required
|
|
transform = reshape_transform(transform, x.shape)
|
|
ndims = len(x.shape)
|
|
event_dim = ndims - transform.domain.event_dim
|
|
x_ = x.view((-1,) + x.shape[event_dim:])
|
|
n = x_.shape[0]
|
|
# Reshape to squash batch dims to a single batch dim
|
|
transform = reshape_transform(transform, x_.shape)
|
|
|
|
# 1. Transforms with unit jacobian
|
|
if isinstance(transform, ReshapeTransform) or isinstance(
|
|
transform.inv, ReshapeTransform
|
|
):
|
|
expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
|
|
expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
|
|
# 2. Transforms with 0 off-diagonal elements
|
|
elif transform.domain.event_dim == 0:
|
|
jac = jacobian(transform, x_)
|
|
# assert off-diagonal elements are zero
|
|
assert torch.allclose(jac, jac.diagonal().diag_embed())
|
|
expected = jac.diagonal().abs().log().reshape(x.shape)
|
|
# 3. Transforms with non-0 off-diagonal elements
|
|
else:
|
|
if isinstance(transform, CorrCholeskyTransform):
|
|
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
|
|
elif isinstance(transform.inv, CorrCholeskyTransform):
|
|
jac = jacobian(
|
|
lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
|
|
tril_matrix_to_vec(x_, diag=-1),
|
|
)
|
|
elif isinstance(transform, StickBreakingTransform):
|
|
jac = jacobian(lambda x: transform(x)[..., :-1], x_)
|
|
else:
|
|
jac = jacobian(transform, x_)
|
|
|
|
# Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
|
|
# However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
|
|
# after reshaping the event dims (see above) to give a batched square matrix whose determinant
|
|
# can be computed.
|
|
gather_idx_shape = list(jac.shape)
|
|
gather_idx_shape[-2] = 1
|
|
gather_idxs = (
|
|
torch.arange(n)
|
|
.reshape((n,) + (1,) * (len(jac.shape) - 1))
|
|
.expand(gather_idx_shape)
|
|
)
|
|
jac = jac.gather(-2, gather_idxs).squeeze(-2)
|
|
out_ndims = jac.shape[-2]
|
|
jac = jac[
|
|
..., :out_ndims
|
|
] # Remove extra zero-valued dims (for inverse stick-breaking).
|
|
expected = torch.slogdet(jac).logabsdet
|
|
|
|
assert torch.allclose(actual, expected, atol=1e-5)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, *event_dims)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
|
|
def test_compose_reshape(batch_shape):
|
|
transforms = [
|
|
ReshapeTransform((), ()),
|
|
ReshapeTransform((2,), (1, 2)),
|
|
ReshapeTransform((3, 1, 2), (6,)),
|
|
ReshapeTransform((6,), (2, 3)),
|
|
]
|
|
transform = ComposeTransform(transforms)
|
|
assert transform.codomain.event_dim == 2
|
|
assert transform.domain.event_dim == 2
|
|
data = torch.randn(batch_shape + (3, 2))
|
|
assert transform(data).shape == batch_shape + (2, 3)
|
|
|
|
dist = TransformedDistribution(Normal(data, 1), transforms)
|
|
assert dist.batch_shape == batch_shape
|
|
assert dist.event_shape == (2, 3)
|
|
assert dist.support.event_dim == 2
|
|
|
|
|
|
@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
|
|
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
|
|
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
|
|
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
|
|
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
|
|
def test_transformed_distribution(
|
|
base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
|
|
):
|
|
shape = torch.Size([2, 3, 4, 5])
|
|
base_dist = Normal(0, 1)
|
|
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
|
|
if base_event_dim:
|
|
base_dist = Independent(base_dist, base_event_dim)
|
|
transforms = [
|
|
AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
|
|
ReshapeTransform((4, 5), (20,)),
|
|
ReshapeTransform((3, 20), (6, 10)),
|
|
]
|
|
transforms = transforms[:num_transforms]
|
|
transform = ComposeTransform(transforms)
|
|
|
|
# Check validation in .__init__().
|
|
if base_batch_dim + base_event_dim < transform.domain.event_dim:
|
|
with pytest.raises(ValueError):
|
|
TransformedDistribution(base_dist, transforms)
|
|
return
|
|
d = TransformedDistribution(base_dist, transforms)
|
|
|
|
# Check sampling is sufficiently expanded.
|
|
x = d.sample(sample_shape)
|
|
assert x.shape == sample_shape + d.batch_shape + d.event_shape
|
|
num_unique = len(set(x.reshape(-1).tolist()))
|
|
assert num_unique >= 0.9 * x.numel()
|
|
|
|
# Check log_prob shape on full samples.
|
|
log_prob = d.log_prob(x)
|
|
assert log_prob.shape == sample_shape + d.batch_shape
|
|
|
|
# Check log_prob shape on partial samples.
|
|
y = x
|
|
while y.dim() > len(d.event_shape):
|
|
y = y[0]
|
|
log_prob = d.log_prob(y)
|
|
assert log_prob.shape == d.batch_shape
|
|
|
|
|
|
def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    with torch.serialization.safe_globals(
        [TransformedDistribution, AffineTransform, Normal]
    ):
        other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
|
|
def test_transform_sign(transform: Transform):
|
|
try:
|
|
sign = transform.sign
|
|
except NotImplementedError:
|
|
pytest.skip("Not implemented.")
|
|
|
|
x = generate_data(transform).requires_grad_()
|
|
y = transform(x).sum()
|
|
(derivatives,) = grad(y, [x])
|
|
assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|