Implement Transforms (#4771)

This commit is contained in:
Alican Bozkurt
2018-01-28 15:17:16 -05:00
committed by Adam Paszke
parent 3ecd25b065
commit 967bceb16b
8 changed files with 849 additions and 13 deletions

View File

@ -91,6 +91,12 @@ Probability distributions - torch.distributions
.. autoclass:: Laplace
:members:
:hidden:`LogNormal`
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LogNormal
:members:
:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~
@ -121,6 +127,12 @@ Probability distributions - torch.distributions
.. autoclass:: StudentT
:members:
:hidden:`TransformedDistribution`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TransformedDistribution
:members:
:hidden:`Uniform`
~~~~~~~~~~~~~~~~~~~~~~~
@ -135,3 +147,17 @@ Probability distributions - torch.distributions
.. autofunction:: kl_divergence
.. autofunction:: register_kl
`Transforms`
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.distributions.transforms
:members:
:member-order: bysource
`Constraints`
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.distributions.constraints
:members:
:member-order: bysource

View File

@ -31,13 +31,21 @@ from random import shuffle
import torch
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, grad, gradcheck, variable
from torch.distributions import Distribution
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric,
Gumbel, Laplace, Normal, OneHotCategorical, Multinomial,
Pareto, Poisson, StudentT, Uniform, kl_divergence)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, FisherSnedecor, Gamma, Geometric,
Gumbel, Laplace, LogNormal, Multinomial,
Normal, OneHotCategorical, Pareto, Poisson,
StudentT, Uniform, constraints, kl_divergence)
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.transforms import (AbsTransform, AffineTransform,
BoltzmannTransform,
ComposeTransform, ExpTransform,
LowerCholeskyTransform,
SigmoidTransform,
StickBreakingTransform,
identity_transform)
from torch.distributions.utils import _finfo, probs_to_logits
TEST_NUMPY = True
@ -167,6 +175,20 @@ EXAMPLES = [
'scale': torch.Tensor([1e-5, 1e-5]),
},
]),
Example(LogNormal, [
{
'loc': Variable(torch.randn(5, 5), requires_grad=True),
'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
'loc': Variable(torch.randn(1), requires_grad=True),
'scale': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'loc': torch.Tensor([1.0, 0.0]),
'scale': torch.Tensor([1e-5, 1e-5]),
},
]),
Example(Normal, [
{
'loc': Variable(torch.randn(5, 5), requires_grad=True),
@ -665,6 +687,51 @@ class TestDistributions(TestCase):
loc.grad.zero_()
scale.grad.zero_()
def test_lognormal(self):
mean = Variable(torch.randn(5, 5), requires_grad=True)
std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
mean_1d = Variable(torch.randn(1), requires_grad=True)
std_1d = Variable(torch.randn(1), requires_grad=True)
mean_delta = torch.Tensor([1.0, 0.0])
std_delta = torch.Tensor([1e-5, 1e-5])
self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
self.assertEqual(LogNormal(mean, std).sample_n(7).size(), (7, 5, 5))
self.assertEqual(LogNormal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
self.assertEqual(LogNormal(0.2, .6).sample_n(1).size(), (1,))
self.assertEqual(LogNormal(-0.7, 50.0).sample_n(1).size(), (1,))
# sample check for extreme value of mean, std
set_rng_seed(1)
self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
torch.Tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
prec=1e-4)
self._gradcheck_log_prob(LogNormal, (mean, std))
self._gradcheck_log_prob(LogNormal, (mean, 1.0))
self._gradcheck_log_prob(LogNormal, (0.0, std))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_logprob(self):
mean = Variable(torch.randn(5, 1), requires_grad=True)
std = Variable(torch.randn(5, 1).abs(), requires_grad=True)
def ref_log_prob(idx, x, log_prob):
m = mean.data.view(-1)[idx]
s = std.data.view(-1)[idx]
expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
self.assertAlmostEqual(log_prob, expected, places=3)
self._check_log_prob(LogNormal(mean, std), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(LogNormal(mean, std),
scipy.stats.lognorm(scale=math.exp(mean), s=std),
'LogNormal(loc={}, scale={})'.format(mean, std))
def test_normal(self):
loc = Variable(torch.randn(5, 5), requires_grad=True)
scale = Variable(torch.randn(5, 5).abs(), requires_grad=True)
@ -1420,7 +1487,7 @@ class TestDistributionShapes(TestCase):
dist = Dist(**param)
try:
actual_shape = dist.entropy().size()
expected_shape = dist._batch_shape if dist._batch_shape else torch.Size(SCALAR_SHAPE)
expected_shape = dist.batch_shape if dist.batch_shape else torch.Size(SCALAR_SHAPE)
message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
Dist.__name__, i + 1, len(params), expected_shape, actual_shape)
self.assertEqual(actual_shape, expected_shape, message=message)
@ -1717,7 +1784,7 @@ class TestKL(TestCase):
def __init__(self, probs):
super(Binomial30, self).__init__(30, probs)
# These are pairs of distributions with 4 x 4 paramters as specified.
# These are pairs of distributions with 4 x 4 parameters as specified.
# The first of the pair e.g. bernoulli[0] varies column-wise and the second
# e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
@ -1728,6 +1795,7 @@ class TestKL(TestCase):
gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75])
poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
@ -1776,6 +1844,7 @@ class TestKL(TestCase):
(gumbel, gumbel),
(gumbel, normal),
(laplace, laplace),
(lognormal, lognormal),
(laplace, normal),
(normal, gumbel),
(normal, normal),
@ -2100,5 +2169,164 @@ class TestLazyLogitsInitialization(TestCase):
self.assertFalse('logits' in vars(dist), msg=message)
class TestTransforms(TestCase):
def setUp(self):
self.transforms = []
transforms_by_cache_size = {}
for cache_size in [0, 1]:
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
AffineTransform(Variable(torch.Tensor(5).normal_()),
Variable(torch.Tensor(5).normal_()),
cache_size=cache_size),
AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
Variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
BoltzmannTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
ComposeTransform([
AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
Variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
Variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
]
for t in transforms[:]:
transforms.append(t.inv)
transforms.append(identity_transform)
self.transforms += transforms
if cache_size == 0:
self.unique_transforms = transforms[:]
def _generate_data(self, transform):
domain = transform.domain
codomain = transform.codomain
x = torch.Tensor(4, 5)
if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
x = torch.Tensor(6, 6)
x = x.normal_()
return x
elif domain is constraints.real:
return x.normal_()
elif domain is constraints.positive:
return x.normal_().exp()
elif domain is constraints.unit_interval:
return x.uniform_()
elif domain is constraints.simplex:
x = x.normal_().exp()
x /= x.sum(-1, True)
return x
raise ValueError('Unsupported domain: {}'.format(domain))
def test_inv_inv(self):
for t in self.transforms:
self.assertTrue(t.inv.inv is t)
def test_equality(self):
transforms = self.unique_transforms
for x, y in product(transforms, transforms):
if x is y:
self.assertTrue(x == y)
self.assertFalse(x != y)
else:
self.assertFalse(x == y)
self.assertTrue(x != y)
self.assertTrue(identity_transform == identity_transform.inv)
self.assertFalse(identity_transform != identity_transform.inv)
def test_forward_inverse_cache(self):
for transform in self.transforms:
x = Variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
except NotImplementedError:
continue
x2 = transform.inv(y) # should be implemented at least by caching
y2 = transform(x2) # should be implemented at least by caching
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, message='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, message='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))
def test_forward_inverse_no_cache(self):
for transform in self.transforms:
x = Variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
x2 = transform.inv(y.clone()) # bypass cache
y2 = transform(x2)
except NotImplementedError:
continue
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, message='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, message='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))
def test_univariate_forward_jacobian(self):
for transform in self.transforms:
x = Variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
expected = torch.abs(grad([y.sum()], [x])[0]).log()
self.assertEqual(actual, expected, message='\n'.join([
'{}.log_abs_det_jacobian() disagrees with ()'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))
def test_univariate_inverse_jacobian(self):
for transform in self.transforms:
y = Variable(self._generate_data(transform.inv), requires_grad=True)
try:
x = transform.inv(y)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
expected = -torch.abs(grad([x.sum()], [y])[0]).log()
self.assertEqual(actual, expected, message='\n'.join([
'{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))
if __name__ == '__main__':
run_tests()

View File

@ -32,6 +32,7 @@ policy, the code for implementing REINFORCE would be as follows::
from .bernoulli import Bernoulli
from .beta import Beta
from .transforms import *
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
@ -45,12 +46,14 @@ from .geometric import Geometric
from .gumbel import Gumbel
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .log_normal import LogNormal
from .multinomial import Multinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .studentT import StudentT
from .poisson import Poisson
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
__all__ = [
@ -68,6 +71,7 @@ __all__ = [
'Geometric',
'Gumbel',
'Laplace',
'LogNormal',
'Multinomial',
'Normal',
'OneHotCategorical',
@ -75,6 +79,8 @@ __all__ = [
'StudentT',
'Poisson',
'Uniform',
'TransformedDistribution',
'kl_divergence',
'register_kl',
]
__all__.extend(transforms.__all__)

View File

@ -11,6 +11,7 @@ __all__ = [
'interval',
'is_dependent',
'less_than',
'lower_cholesky',
'lower_triangular',
'nonnegative_integer',
'positive',
@ -165,7 +166,19 @@ class _LowerTriangular(Constraint):
Constrain to lower-triangular square matrices.
"""
def check(self, value):
return (torch.tril(value) == value).min(-1).min(-1)
return (torch.tril(value) == value).min(-1)[0].min(-1)[0]
class _LowerCholesky(Constraint):
"""
Constrain to lower-triangular square matrices with positive diagonals.
"""
def check(self, value):
n = value.size(-1)
diag_mask = torch.eye(n, n, out=value.new(n, n))
lower_triangular = (torch.tril(value) == value).min(-1)[0].min(-1)[0]
positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0]
return lower_triangular & positive_diagonal
# Public interface.
@ -183,3 +196,4 @@ unit_interval = _Interval(0, 1)
interval = _Interval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
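
For reference, a minimal sketch of how the new `lower_cholesky` constraint behaves on concrete matrices (the values below are illustrative, not taken from the test suite):

```python
import torch
from torch.distributions import constraints

# Lower triangular with a strictly positive diagonal should satisfy the
# new constraint; a nonzero upper-triangular entry should not.
ok = torch.Tensor([[1.0, 0.0],
                   [0.5, 2.0]])
bad = torch.Tensor([[1.0, 3.0],
                    [0.5, 2.0]])
print(constraints.lower_cholesky.check(ok))   # expected truthy (1)
print(constraints.lower_cholesky.check(bad))  # expected falsy (0)
```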

View File

@ -1,22 +1,24 @@
import math
import warnings
from functools import total_ordering
import torch
import math
from .distribution import Distribution
from .bernoulli import Bernoulli
from .binomial import Binomial
from .beta import Beta
from .binomial import Binomial
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .laplace import Laplace
from .log_normal import LogNormal
from .normal import Normal
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from torch.autograd import Variable, variable
@ -268,6 +270,13 @@ def _kl_poisson_poisson(p, q):
return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)
@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
if p.transforms != q.transforms:
raise NotImplementedError
return kl_divergence(p.base_dist, q.base_dist)
@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
result = ((q.high - q.low) / (p.high - p.low)).log()
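
A short sketch of what the `_kl_transformed_transformed` registration above enables: the KL divergence between two `LogNormal` distributions (which are `TransformedDistribution`s sharing an `ExpTransform`) reduces to the KL between their base `Normal`s. Parameter values here are illustrative:

```python
import torch
from torch.distributions import LogNormal, Normal, kl_divergence

p = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
q = LogNormal(torch.Tensor([1.0]), torch.Tensor([2.0]))
# Both lines should print the same value: the transforms match, so the KL
# is delegated to kl_divergence(p.base_dist, q.base_dist).
print(kl_divergence(p, q))
print(kl_divergence(Normal(torch.Tensor([0.0]), torch.Tensor([1.0])),
                    Normal(torch.Tensor([1.0]), torch.Tensor([2.0]))))
```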

View File

@ -0,0 +1,42 @@
from torch.distributions import constraints
from torch.distributions.transforms import ExpTransform
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
class LogNormal(TransformedDistribution):
r"""
Creates a log-normal distribution parameterized by
`loc` and `scale` where::
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example::
>>> m = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample() # log-normal distributed with mean=0 and stddev=1
0.1046
[torch.FloatTensor of size 1]
Args:
loc (float or Tensor or Variable): mean of log of distribution
scale (float or Tensor or Variable): standard deviation of log of the distribution
"""
params = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.positive
has_rsample = True
def __init__(self, loc, scale):
super(LogNormal, self).__init__(Normal(loc, scale), ExpTransform())
@property
def loc(self):
return self.base_dist.loc
@property
def scale(self):
return self.base_dist.scale
def entropy(self):
return self.base_dist.entropy() + self.loc
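
A minimal usage sketch of the relationship documented in the class docstring: the log-density of `y` under `LogNormal(loc, scale)` equals the base `Normal` log-density of `log(y)` minus `log(y)` (the `log|d exp(x)/dx|` correction contributed by `ExpTransform`). Values are illustrative:

```python
import torch
from torch.distributions import LogNormal

m = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
y = m.sample()
print(m.log_prob(y))
# Equivalent computation through the base Normal:
print(m.base_dist.log_prob(y.log()) - y.log())
```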

View File

@ -0,0 +1,80 @@
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.distributions.distribution import Distribution
class TransformedDistribution(Distribution):
r"""
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution.
Let f be the composition of transforms applied,
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log det (dX/dY)
"""
def __init__(self, base_distribution, transforms=[], *args, **kwargs):
super(TransformedDistribution, self).__init__(*args, **kwargs)
self.base_dist = base_distribution
if isinstance(transforms, Transform):
self.transforms = [transforms, ]
elif isinstance(transforms, list):
if not all(isinstance(t, Transform) for t in transforms):
raise ValueError("transforms must be a Transform or a list of Transforms")
self.transforms = transforms
@constraints.dependent_property
def params(self):
return self.base_dist.params # TODO add params of transforms?
@constraints.dependent_property
def support(self):
return self.transforms[-1].codomain if self.transforms else self.base_dist.support
@property
def has_rsample(self):
return self.base_dist.has_rsample
@property
def batch_shape(self):
return self.base_dist.batch_shape
@property
def event_shape(self):
return self.base_dist.event_shape
def sample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched. Samples first from base distribution
and applies `transform()` for every transform in the list.
"""
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
def rsample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped reparameterized sample or sample_shape
shaped batch of reparameterized samples if the distribution parameters
are batched. Samples first from base distribution and applies `transform()`
for every transform in the list.
"""
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the log probability of the base distribution and the log abs det jacobian.
"""
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
log_prob -= transform.log_abs_det_jacobian(x, y)
y = x
log_prob += self.base_dist.log_prob(y)
return log_prob
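
A sketch of the pattern this class enables, built only from pieces added in this commit: a hand-rolled distribution of `Y = exp(1 + 2 * X)` for `X ~ Normal(0, 1)`, sampled with `rsample` and scored with `log_prob`. Parameter values are illustrative:

```python
import torch
from torch.autograd import Variable
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

base = Normal(Variable(torch.Tensor([0.0])), Variable(torch.Tensor([1.0])))
transforms = [AffineTransform(Variable(torch.Tensor([1.0])),
                              Variable(torch.Tensor([2.0]))),
              ExpTransform()]
d = TransformedDistribution(base, transforms)
y = d.rsample()       # reparameterized, since the base Normal has rsample
print(d.log_prob(y))  # computed by inverting the transforms as in log_prob above
```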

View File

@ -0,0 +1,431 @@
import weakref
import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property
from torch.nn.functional import sigmoid
__all__ = [
'AbsTransform',
'AffineTransform',
'BoltzmannTransform',
'ComposeTransform',
'ExpTransform',
'LowerCholeskyTransform',
'SigmoidTransform',
'StickBreakingTransform',
'Transform',
'identity_transform',
]
class Transform(object):
"""
Abstract class for invertible transformations with computable log
det jacobians. They are primarily used in
:class:`torch.distributions.TransformedDistribution`.
Caching is useful for transforms whose inverses are either expensive or
numerically unstable. Note that care must be taken with memoized values
since the autograd graph may be reversed. For example while the following
works with or without caching::
y = t(x)
t.log_abs_det_jacobian(x, y).backward() # x will receive gradients.
However the following will error when caching due to dependency reversal::
y = t(x)
z = t.inv(y)
grad(z.sum(), [y]) # error because z is x
Derived classes should implement one or both of :meth:`_call` or
:meth:`_inverse`. Derived classes that set `bijective=True` should also
implement :meth:`log_abs_det_jacobian`.
Args:
cache_size (int): Size of cache. If zero, no caching is done. If one,
the latest single value is cached. Only 0 and 1 are supported.
Attributes:
domain (:class:`~torch.distributions.constraints.Constraint`):
The constraint representing valid inputs to this transform.
codomain (:class:`~torch.distributions.constraints.Constraint`):
The constraint representing valid outputs to this transform
which are inputs to the inverse transform.
bijective (bool): Whether this transform is bijective. A transform
``t`` is bijective iff ``t.inv(t(x)) == x`` and
``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
the codomain. Transforms that are not bijective should at least
maintain the weaker pseudoinverse properties
``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
"""
bijective = False
def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # default behavior
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError('cache_size must be 0 or 1')
@property
def inv(self):
"""
Returns the inverse :class:`Transform` of this transform.
This should satisfy ``t.inv.inv is t``.
"""
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
return inv
def __eq__(self, other):
return self is other
def __ne__(self, other):
# Necessary for Python2
return not self.__eq__(other)
def __call__(self, x):
"""
Computes the transform `x => y`.
"""
if self._cache_size == 0:
return self._call(x)
x_old, y_old = self._cached_x_y
if x is x_old:
return y_old
y = self._call(x)
self._cached_x_y = x, y
return y
def _inv_call(self, y):
"""
Inverts the transform `y => x`.
"""
if self._cache_size == 0:
return self._inverse(y)
x_old, y_old = self._cached_x_y
if y is y_old:
return x_old
x = self._inverse(y)
self._cached_x_y = x, y
return x
def _call(self, x):
"""
Abstract method to compute forward transformation.
"""
raise NotImplementedError
def _inverse(self, y):
"""
Abstract method to compute inverse transformation.
"""
raise NotImplementedError
def log_abs_det_jacobian(self, x, y):
"""
Computes the log det jacobian `log |dy/dx|` given input and output.
"""
raise NotImplementedError
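
A minimal sketch of the `cache_size=1` behaviour described in the docstring above, using `ExpTransform` (defined later in this file):

```python
import torch
from torch.autograd import Variable
from torch.distributions.transforms import ExpTransform

t = ExpTransform(cache_size=1)
x = Variable(torch.randn(3), requires_grad=True)
y = t(x)
assert t.inv(y) is x  # the inverse of the latest forward call is served from the cache
t.log_abs_det_jacobian(x, y).sum().backward()  # x receives gradients, as documented
```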
class _InverseTransform(Transform):
"""
Inverts a single :class:`Transform`.
This class is private; please instead use the ``Transform.inv`` property.
"""
def __init__(self, transform):
super(_InverseTransform, self).__init__()
self._inv = transform
@constraints.dependent_property
def domain(self):
return self._inv.codomain
@constraints.dependent_property
def codomain(self):
return self._inv.domain
@property
def bijective(self):
return self._inv.bijective
@property
def inv(self):
return self._inv
def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
return self._inv == other._inv
def __call__(self, x):
return self._inv._inv_call(x)
def log_abs_det_jacobian(self, x, y):
return -self._inv.log_abs_det_jacobian(y, x)
class ComposeTransform(Transform):
"""
Composes multiple transforms in a chain.
The transforms being composed are responsible for caching.
Args:
parts (list of :class:`Transform`): A list of transforms to compose.
"""
def __init__(self, parts):
super(ComposeTransform, self).__init__()
self.parts = parts
def __eq__(self, other):
if not isinstance(other, ComposeTransform):
return False
return self.parts == other.parts
@constraints.dependent_property
def domain(self):
if not self.parts:
return constraints.real
return self.parts[0].domain
@constraints.dependent_property
def codomain(self):
if not self.parts:
return constraints.real
return self.parts[-1].codomain
@lazy_property
def bijective(self):
return all(p.bijective for p in self.parts)
@property
def inv(self):
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = ComposeTransform([p.inv for p in reversed(self.parts)])
self._inv = weakref.ref(inv)
inv._inv = weakref.ref(self)
return inv
def __call__(self, x):
for part in self.parts:
x = part(x)
return x
def log_abs_det_jacobian(self, x, y):
if not self.parts:
return x.new([0]).expand_as(x)
result = 0
for part in self.parts:
y = part(x)
result += part.log_abs_det_jacobian(x, y)
x = y
return result
identity_transform = ComposeTransform([])
class ExpTransform(Transform):
"""
Transform via the mapping `y = exp(x)`.
"""
domain = constraints.real
codomain = constraints.positive
bijective = True
def __eq__(self, other):
return isinstance(other, ExpTransform)
def _call(self, x):
return x.exp()
def _inverse(self, y):
return y.log()
def log_abs_det_jacobian(self, x, y):
return x
class SigmoidTransform(Transform):
"""
Transform via the mapping `y = sigmoid(x)` and `x = logit(y)`.
"""
domain = constraints.real
codomain = constraints.unit_interval
bijective = True
def __eq__(self, other):
return isinstance(other, SigmoidTransform)
def _call(self, x):
return sigmoid(x)
def _inverse(self, y):
return y.log() - (-y).log1p()
def log_abs_det_jacobian(self, x, y):
return -(y.reciprocal() + (1 - y).reciprocal()).log()
class AbsTransform(Transform):
"""
Transform via the mapping `y = abs(x)`.
"""
domain = constraints.real
codomain = constraints.positive
def __eq__(self, other):
return isinstance(other, AbsTransform)
def _call(self, x):
return x.abs()
def _inverse(self, y):
return y
class AffineTransform(Transform):
"""
Transform via the pointwise affine mapping `y = loc + scale * x`.
Args:
loc (Tensor or Variable): Location parameter.
scale (Tensor or Variable): Scale parameter.
event_dim (int): Optional size of `event_shape`. This should be zero
for univariate random variables, 1 for distributions over vectors,
2 for distributions over matrices, etc.
"""
domain = constraints.real
codomain = constraints.real
bijective = True
def __init__(self, loc, scale, event_dim=0, cache_size=0):
super(AffineTransform, self).__init__(cache_size=cache_size)
self.loc, self.scale = broadcast_all(loc, scale)
self.event_dim = event_dim
def __eq__(self, other):
if not isinstance(other, AffineTransform):
return False
result = self.loc.eq(other.loc).all() and self.scale.eq(other.scale).all()
if isinstance(result, Variable):
result = result.data.view(-1)[0]
return result
def _call(self, x):
return self.loc + self.scale * x
def _inverse(self, y):
return (y - self.loc) / self.scale
def log_abs_det_jacobian(self, x, y):
result = torch.abs(self.scale).log()
shape = x.shape
if self.event_dim:
result_size = result.size()[:-self.event_dim] + (-1,)
result = result.view(result_size).sum(-1)
shape = shape[:-self.event_dim]
return result.expand(shape)
class BoltzmannTransform(Transform):
"""
Transform from unconstrained space to the simplex via `y = exp(x)` then
normalizing.
This is not bijective and cannot be used for HMC. However this acts mostly
coordinate-wise (except for the final normalization), and thus is
appropriate for coordinate-wise optimization algorithms.
"""
domain = constraints.real
codomain = constraints.simplex
def __eq__(self, other):
return isinstance(other, BoltzmannTransform)
def _call(self, x):
logprobs = x
probs = (logprobs - logprobs.max(-1, True)[0]).exp()
probs /= probs.sum(-1, True)
return probs
def _inverse(self, y):
probs = y
return probs.log()
class StickBreakingTransform(Transform):
"""
Transform from unconstrained space to the simplex of one additional
dimension via a stick-breaking process.
This transform arises as an iterated sigmoid transform in a stick-breaking
construction of the `Dirichlet` distribution: the first logit is
transformed via sigmoid to the first probability and the probability of
everything else, and then the process recurses.
This is bijective and appropriate for use in HMC; however it mixes
coordinates together and is less appropriate for optimization.
"""
domain = constraints.real
codomain = constraints.simplex
bijective = True
def __eq__(self, other):
return isinstance(other, StickBreakingTransform)
def _call(self, x):
shape = x.shape[:-1] + (1 + x.shape[-1],)
one = x.new([1]).expand(x.shape[:-1] + (1,))
numer = sigmoid(x)
denom = (1 - numer).cumprod(-1)
probs = torch.cat([numer, one], -1) * torch.cat([one, denom], -1)
return probs
def _inverse(self, y):
pmf = y
cmf = pmf.cumsum(-1)
sf = 1 - cmf
units = y[..., :-1] / sf[..., :-1]
return units.log()
# TODO implement .log_abs_det_jacobian()
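
A rough sketch of the stick-breaking construction described above: K unconstrained logits map to a point on the (K+1)-dimensional probability simplex, and the bijection can be inverted to recover the logits (values are illustrative):

```python
import torch
from torch.autograd import Variable
from torch.distributions.transforms import StickBreakingTransform

t = StickBreakingTransform()
x = Variable(torch.randn(3))
y = t(x)
print(y.sum())   # ~1.0, with one extra coordinate (4 probabilities)
print(t.inv(y))  # recovers x up to numerical error
```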
class LowerCholeskyTransform(Transform):
"""
Transform from unconstrained matrices to lower-triangular matrices with
nonnegative diagonal entries.
This is useful for parameterizing positive definite matrices in terms of
their Cholesky factorization.
"""
domain = constraints.real
codomain = constraints.lower_cholesky
def __eq__(self, other):
return isinstance(other, LowerCholeskyTransform)
def _call(self, x):
if x.dim() != 2:
raise NotImplementedError
return x.tril(-1) + x.diag().exp().diag()
def _inverse(self, y):
if y.dim() != 2:
raise NotImplementedError
return y.tril(-1) + y.diag().log().diag()
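
Finally, a sketch of the use case mentioned in the `LowerCholeskyTransform` docstring: parameterizing a positive definite matrix by an unconstrained square matrix (sizes and values are illustrative):

```python
import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.transforms import LowerCholeskyTransform

t = LowerCholeskyTransform()
x = Variable(torch.randn(3, 3))
L = t(x)            # lower triangular with a positive (exponentiated) diagonal
print(L.mm(L.t()))  # a positive definite matrix parameterized by x
print(constraints.lower_cholesky.check(L.data))  # expected truthy
```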