mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Implement Transforms (#4771)

This commit is contained in:
    commit 967bceb16b (committed by Adam Paszke)
    parent 3ecd25b065
@@ -91,6 +91,12 @@ Probability distributions - torch.distributions

.. autoclass:: Laplace
    :members:

:hidden:`LogNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LogNormal
    :members:

:hidden:`Normal`
~~~~~~~~~~~~~~~~~~~~~~~

@@ -121,6 +127,12 @@ Probability distributions - torch.distributions

.. autoclass:: StudentT
    :members:

:hidden:`TransformedDistribution`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TransformedDistribution
    :members:

:hidden:`Uniform`
~~~~~~~~~~~~~~~~~~~~~~~

@@ -135,3 +147,17 @@ Probability distributions - torch.distributions

.. autofunction:: kl_divergence
.. autofunction:: register_kl

`Transforms`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.transforms
    :members:
    :member-order: bysource

`Constraints`
~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: torch.distributions.constraints
    :members:
    :member-order: bysource
@@ -31,13 +31,21 @@ from random import shuffle
import torch
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, grad, gradcheck, variable
-from torch.distributions import Distribution
-from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric,
-                                 Gumbel, Laplace, Normal, OneHotCategorical, Multinomial,
-                                 Pareto, Poisson, StudentT, Uniform, kl_divergence)
-from torch.distributions.dirichlet import _Dirichlet_backward
+from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
+                                 Cauchy, Chi2, Dirichlet, Distribution,
+                                 Exponential, FisherSnedecor, Gamma, Geometric,
+                                 Gumbel, Laplace, LogNormal, Multinomial,
+                                 Normal, OneHotCategorical, Pareto, Poisson,
+                                 StudentT, Uniform, constraints, kl_divergence)
+from torch.distributions.constraints import Constraint, is_dependent
+from torch.distributions.dirichlet import _Dirichlet_backward
+from torch.distributions.transforms import (AbsTransform, AffineTransform,
+                                            BoltzmannTransform,
+                                            ComposeTransform, ExpTransform,
+                                            LowerCholeskyTransform,
+                                            SigmoidTransform,
+                                            StickBreakingTransform,
+                                            identity_transform)
from torch.distributions.utils import _finfo, probs_to_logits

TEST_NUMPY = True
@@ -167,6 +175,20 @@ EXAMPLES = [
            'scale': torch.Tensor([1e-5, 1e-5]),
        },
    ]),
    Example(LogNormal, [
        {
            'loc': Variable(torch.randn(5, 5), requires_grad=True),
            'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
        },
        {
            'loc': Variable(torch.randn(1), requires_grad=True),
            'scale': Variable(torch.randn(1).abs(), requires_grad=True),
        },
        {
            'loc': torch.Tensor([1.0, 0.0]),
            'scale': torch.Tensor([1e-5, 1e-5]),
        },
    ]),
    Example(Normal, [
        {
            'loc': Variable(torch.randn(5, 5), requires_grad=True),
@@ -665,6 +687,51 @@ class TestDistributions(TestCase):
        loc.grad.zero_()
        scale.grad.zero_()

    def test_lognormal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        std_1d = Variable(torch.randn(1), requires_grad=True)
        mean_delta = torch.Tensor([1.0, 0.0])
        std_delta = torch.Tensor([1e-5, 1e-5])
        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
        self.assertEqual(LogNormal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
        self.assertEqual(LogNormal(0.2, .6).sample_n(1).size(), (1,))
        self.assertEqual(LogNormal(-0.7, 50.0).sample_n(1).size(), (1,))

        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
                         torch.Tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(LogNormal, (mean, std))
        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
        self._gradcheck_log_prob(LogNormal, (0.0, std))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_logprob(self):
        mean = Variable(torch.randn(5, 1), requires_grad=True)
        std = Variable(torch.randn(5, 1).abs(), requires_grad=True)

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(LogNormal(mean, std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(LogNormal(mean, std),
                                        scipy.stats.lognorm(scale=math.exp(mean), s=std),
                                        'LogNormal(loc={}, scale={})'.format(mean, std))

    def test_normal(self):
        loc = Variable(torch.randn(5, 5), requires_grad=True)
        scale = Variable(torch.randn(5, 5).abs(), requires_grad=True)
@@ -1420,7 +1487,7 @@ class TestDistributionShapes(TestCase):
            dist = Dist(**param)
            try:
                actual_shape = dist.entropy().size()
-                expected_shape = dist._batch_shape if dist._batch_shape else torch.Size(SCALAR_SHAPE)
+                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size(SCALAR_SHAPE)
                message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
                    Dist.__name__, i + 1, len(params), expected_shape, actual_shape)
                self.assertEqual(actual_shape, expected_shape, message=message)
@@ -1717,7 +1784,7 @@ class TestKL(TestCase):
        def __init__(self, probs):
            super(Binomial30, self).__init__(30, probs)

-    # These are pairs of distributions with 4 x 4 paramters as specified.
+    # These are pairs of distributions with 4 x 4 parameters as specified.
    # The first of the pair e.g. bernoulli[0] varies column-wise and the second
    # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
    bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
@@ -1728,6 +1795,7 @@
    gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
    gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
    laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
    lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
    normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
    pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75])
    poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
@@ -1776,6 +1844,7 @@
        (gumbel, gumbel),
        (gumbel, normal),
        (laplace, laplace),
        (lognormal, lognormal),
        (laplace, normal),
        (normal, gumbel),
        (normal, normal),
@@ -2100,5 +2169,164 @@ class TestLazyLogitsInitialization(TestCase):
            self.assertFalse('logits' in vars(dist), msg=message)


class TestTransforms(TestCase):
    def setUp(self):
        self.transforms = []
        transforms_by_cache_size = {}
        for cache_size in [0, 1]:
            transforms = [
                AbsTransform(cache_size=cache_size),
                ExpTransform(cache_size=cache_size),
                SigmoidTransform(cache_size=cache_size),
                AffineTransform(Variable(torch.Tensor(5).normal_()),
                                Variable(torch.Tensor(5).normal_()),
                                cache_size=cache_size),
                AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
                                Variable(torch.Tensor(4, 5).normal_()),
                                cache_size=cache_size),
                BoltzmannTransform(cache_size=cache_size),
                StickBreakingTransform(cache_size=cache_size),
                LowerCholeskyTransform(cache_size=cache_size),
                ComposeTransform([
                    AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
                                    Variable(torch.Tensor(4, 5).normal_()),
                                    cache_size=cache_size),
                ]),
                ComposeTransform([
                    AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
                                    Variable(torch.Tensor(4, 5).normal_()),
                                    cache_size=cache_size),
                    ExpTransform(cache_size=cache_size),
                ]),
            ]
            for t in transforms[:]:
                transforms.append(t.inv)
            transforms.append(identity_transform)
            self.transforms += transforms
            if cache_size == 0:
                self.unique_transforms = transforms[:]

    def _generate_data(self, transform):
        domain = transform.domain
        codomain = transform.codomain
        x = torch.Tensor(4, 5)
        if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
            x = torch.Tensor(6, 6)
            x = x.normal_()
            return x
        elif domain is constraints.real:
            return x.normal_()
        elif domain is constraints.positive:
            return x.normal_().exp()
        elif domain is constraints.unit_interval:
            return x.uniform_()
        elif domain is constraints.simplex:
            x = x.normal_().exp()
            x /= x.sum(-1, True)
            return x
        raise ValueError('Unsupported domain: {}'.format(domain))

    def test_inv_inv(self):
        for t in self.transforms:
            self.assertTrue(t.inv.inv is t)

    def test_equality(self):
        transforms = self.unique_transforms
        for x, y in product(transforms, transforms):
            if x is y:
                self.assertTrue(x == y)
                self.assertFalse(x != y)
            else:
                self.assertFalse(x == y)
                self.assertTrue(x != y)

        self.assertTrue(identity_transform == identity_transform.inv)
        self.assertFalse(identity_transform != identity_transform.inv)

    def test_forward_inverse_cache(self):
        for transform in self.transforms:
            x = Variable(self._generate_data(transform), requires_grad=True)
            try:
                y = transform(x)
            except NotImplementedError:
                continue
            x2 = transform.inv(y)  # should be implemented at least by caching
            y2 = transform(x2)  # should be implemented at least by caching
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))

    def test_forward_inverse_no_cache(self):
        for transform in self.transforms:
            x = Variable(self._generate_data(transform), requires_grad=True)
            try:
                y = transform(x)
                x2 = transform.inv(y.clone())  # bypass cache
                y2 = transform(x2)
            except NotImplementedError:
                continue
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))

    def test_univariate_forward_jacobian(self):
        for transform in self.transforms:
            x = Variable(self._generate_data(transform), requires_grad=True)
            try:
                y = transform(x)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = torch.abs(grad([y.sum()], [x])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                'Bad {}.log_abs_det_jacobian() disagrees with ()'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))

    def test_univariate_inverse_jacobian(self):
        for transform in self.transforms:
            y = Variable(self._generate_data(transform.inv), requires_grad=True)
            try:
                x = transform.inv(y)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = -torch.abs(grad([x.sum()], [y])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))


if __name__ == '__main__':
    run_tests()
@@ -32,6 +32,7 @@ policy, the code for implementing REINFORCE would be as follows::

from .bernoulli import Bernoulli
from .beta import Beta
from .transforms import *
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy

@@ -45,12 +46,14 @@ from .geometric import Geometric
from .gumbel import Gumbel
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .log_normal import LogNormal
from .multinomial import Multinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
-from .studentT import StudentT
from .poisson import Poisson
+from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform

__all__ = [

@@ -68,6 +71,7 @@ __all__ = [
    'Geometric',
    'Gumbel',
    'Laplace',
    'LogNormal',
    'Multinomial',
    'Normal',
    'OneHotCategorical',

@@ -75,6 +79,8 @@ __all__ = [
    'StudentT',
    'Poisson',
    'Uniform',
    'TransformedDistribution',
    'kl_divergence',
    'register_kl',
]
__all__.extend(transforms.__all__)
@@ -11,6 +11,7 @@ __all__ = [
    'interval',
    'is_dependent',
    'less_than',
    'lower_cholesky',
    'lower_triangular',
    'nonnegative_integer',
    'positive',

@@ -165,7 +166,19 @@ class _LowerTriangular(Constraint):
    Constrain to lower-triangular square matrices.
    """
    def check(self, value):
-        return (torch.tril(value) == value).min(-1).min(-1)
+        return (torch.tril(value) == value).min(-1)[0].min(-1)[0]


class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """
    def check(self, value):
        n = value.size(-1)
        diag_mask = torch.eye(n, n, out=value.new(n, n))
        lower_triangular = (torch.tril(value) == value).min(-1)[0].min(-1)[0]
        positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0]
        return lower_triangular & positive_diagonal


# Public interface.

@@ -183,3 +196,4 @@ unit_interval = _Interval(0, 1)
interval = _Interval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
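As a quick sanity check of the new constraint, here is a sketch (not part of this diff, written against the plain tensor API rather than the Variable-era API used above):

import torch
from torch.distributions import constraints

good = torch.Tensor([[1.0, 0.0],
                     [0.5, 2.0]])   # lower triangular with a positive diagonal
bad = torch.Tensor([[1.0, 3.0],
                    [0.5, 2.0]])    # nonzero entry above the diagonal

print(constraints.lower_cholesky.check(good))  # truthy (1 or True depending on the version)
print(constraints.lower_cholesky.check(bad))   # falsy (0 or False)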
@@ -1,22 +1,24 @@
+import math
import warnings
from functools import total_ordering

import torch
-import math

-from .distribution import Distribution
from .bernoulli import Bernoulli
-from .binomial import Binomial
from .beta import Beta
+from .binomial import Binomial
from .dirichlet import Dirichlet
+from .distribution import Distribution
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .laplace import Laplace
+from .log_normal import LogNormal
from .normal import Normal
from .pareto import Pareto
from .poisson import Poisson
+from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from torch.autograd import Variable, variable
@@ -268,6 +270,13 @@ def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
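A sketch (not part of this diff) of what the new rule buys: a shared invertible transform leaves KL divergence unchanged, so KL between two LogNormals reduces to KL between their base Normals. Written against a recent PyTorch where scalar parameters and torch.allclose are available:

import torch
from torch.distributions import LogNormal, Normal, kl_divergence

p, q = LogNormal(0.0, 1.0), LogNormal(1.0, 2.0)
base_p, base_q = Normal(0.0, 1.0), Normal(1.0, 2.0)

# Dispatch hits the (TransformedDistribution, TransformedDistribution) rule,
# sees that both sides use the same ExpTransform, and defers to the base KL.
assert torch.allclose(kl_divergence(p, q), kl_divergence(base_p, base_q))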
torch/distributions/log_normal.py (new file, 42 lines)

@@ -0,0 +1,42 @@
from torch.distributions import constraints
from torch.distributions.transforms import ExpTransform
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution


class LogNormal(TransformedDistribution):
    r"""
    Creates a log-normal distribution parameterized by
    `mean` and `std` where::

        X ~ Normal(loc, scale)
        Y = exp(X) ~ LogNormal(loc, scale)

    Example::

        >>> m = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
        >>> m.sample()  # log-normal distributed with mean=0 and stddev=1
         0.1046
        [torch.FloatTensor of size 1]

    Args:
        loc (float or Tensor or Variable): mean of log of the distribution
        scale (float or Tensor or Variable): standard deviation of log of the distribution
    """
    params = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, loc, scale):
        super(LogNormal, self).__init__(Normal(loc, scale), ExpTransform())

    @property
    def loc(self):
        return self.base_dist.loc

    @property
    def scale(self):
        return self.base_dist.scale

    def entropy(self):
        return self.base_dist.entropy() + self.loc
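A short usage sketch (not part of this file), written against the plain tensor API, checking that LogNormal.log_prob matches the change-of-variables formula log p_Y(y) = log p_X(log y) - log y:

import torch
from torch.distributions import LogNormal, Normal

loc, scale = 0.0, 1.0
y = torch.Tensor([0.5, 1.0, 2.0])

lp = LogNormal(loc, scale).log_prob(y)
manual = Normal(loc, scale).log_prob(y.log()) - y.log()
assert torch.allclose(lp, manual)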
torch/distributions/transformed_distribution.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.distributions.distribution import Distribution


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution. Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log det (dX/dY)
    """
    def __init__(self, base_distribution, transforms=[], *args, **kwargs):
        super(TransformedDistribution, self).__init__(*args, **kwargs)
        self.base_dist = base_distribution
        if isinstance(transforms, Transform):
            self.transforms = [transforms, ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms

    @constraints.dependent_property
    def params(self):
        return self.base_dist.params  # TODO add params of transforms?

    @constraints.dependent_property
    def support(self):
        return self.transforms[-1].codomain if self.transforms else self.base_dist.support

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def batch_shape(self):
        return self.base_dist.batch_shape

    @property
    def event_shape(self):
        return self.base_dist.event_shape

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution and applies `transform()` for every transform in
        the list.
        """
        x = self.base_dist.sample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            log_prob -= transform.log_abs_det_jacobian(x, y)
            y = x
        log_prob += self.base_dist.log_prob(y)
        return log_prob
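A sketch (not part of this file) showing the pieces together: a hand-rolled log-normal built from Normal plus ExpTransform, with log_prob checked against the inverse-transform computation performed above. Assumes a PyTorch where torch.allclose is available:

import torch
from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(torch.zeros(3), torch.ones(3))
d = TransformedDistribution(base, [ExpTransform()])

y = d.rsample()                  # sample x from Normal, then y = exp(x)
x = y.log()                      # invert the transform by hand
expected = base.log_prob(x) - x  # subtract log|dy/dx| = x
assert torch.allclose(d.log_prob(y), expected)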
torch/distributions/transforms.py (new file, 431 lines)

@@ -0,0 +1,431 @@
import weakref

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property
from torch.nn.functional import sigmoid

__all__ = [
    'AbsTransform',
    'AffineTransform',
    'BoltzmannTransform',
    'ComposeTransform',
    'ExpTransform',
    'LowerCholeskyTransform',
    'SigmoidTransform',
    'StickBreakingTransform',
    'Transform',
    'identity_transform',
]


class Transform(object):
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
    """
    bijective = False

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError('cache_size must be 0 or 1')

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError
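To make the contract concrete, here is a minimal sketch of a custom subclass (hypothetical, not part of this module): it implements `_call`, `_inverse`, and, since it declares `bijective = True`, `log_abs_det_jacobian`:

from torch.distributions import constraints
from torch.distributions.transforms import Transform


class SquareTransform(Transform):
    """Hypothetical transform y = x ** 2, restricted to positive x so it is bijective."""
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True

    def __eq__(self, other):
        return isinstance(other, SquareTransform)

    def _call(self, x):
        return x.pow(2)

    def _inverse(self, y):
        return y.sqrt()

    def log_abs_det_jacobian(self, x, y):
        # |dy/dx| = 2 * x on the positive domain
        return (2 * x).log()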
class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """
    def __init__(self, transform):
        super(_InverseTransform, self).__init__()
        self._inv = transform

    @constraints.dependent_property
    def domain(self):
        return self._inv.codomain

    @constraints.dependent_property
    def codomain(self):
        return self._inv.domain

    @property
    def bijective(self):
        return self._inv.bijective

    @property
    def inv(self):
        return self._inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        return self._inv == other._inv

    def __call__(self, x):
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        return -self._inv.log_abs_det_jacobian(y, x)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
    """
    def __init__(self, parts):
        super(ComposeTransform, self).__init__()
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property
    def domain(self):
        if not self.parts:
            return constraints.real
        return self.parts[0].domain

    @constraints.dependent_property
    def codomain(self):
        if not self.parts:
            return constraints.real
        return self.parts[-1].codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return x.new([0]).expand_as(x)
        result = 0
        for part in self.parts:
            y = part(x)
            result += part.log_abs_det_jacobian(x, y)
            x = y
        return result


identity_transform = ComposeTransform([])
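A usage sketch (not part of this module): composing an affine map with exp; `.inv` reverses the order of the parts, so the round trip recovers the input. Assumes the plain tensor API and torch.allclose:

import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])

x = torch.Tensor([0.0, 0.5, 1.0])
y = t(x)             # exp(1 + 2 * x)
x_back = t.inv(y)    # log first, then (y - 1) / 2
assert torch.allclose(x_back, x)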
class ExpTransform(Transform):
    """
    Transform via the mapping `y = exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x


class SigmoidTransform(Transform):
    """
    Transform via the mapping `y = sigmoid(x)` and `x = logit(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return sigmoid(x)

    def _inverse(self, y):
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -(y.reciprocal() + (1 - y).reciprocal()).log()


class AbsTransform(Transform):
    """
    Transform via the mapping `y = abs(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y
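A numeric sanity check (not part of this module) of the closed forms above: for exp the log-Jacobian is x itself, and for sigmoid it is log y + log(1 - y):

import torch
from torch.distributions.transforms import ExpTransform, SigmoidTransform

x = torch.randn(5)

exp_t = ExpTransform()
y = exp_t(x)
assert torch.allclose(exp_t.log_abs_det_jacobian(x, y), x)

sig_t = SigmoidTransform()
y = sig_t(x)
assert torch.allclose(sig_t.log_abs_det_jacobian(x, y),
                      y.log() + (1 - y).log(), atol=1e-6)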
class AffineTransform(Transform):
    """
    Transform via the pointwise affine mapping `y = loc + scale * x`.

    Args:
        loc (Tensor or Variable): Location parameter.
        scale (Tensor or Variable): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super(AffineTransform, self).__init__(cache_size=cache_size)
        self.loc, self.scale = broadcast_all(loc, scale)
        self.event_dim = event_dim

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False
        result = self.loc.eq(other.loc).all() and self.scale.eq(other.scale).all()
        if isinstance(result, Variable):
            result = result.data.view(-1)[0]
        return result

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        result = torch.abs(self.scale).log()
        shape = x.shape
        if self.event_dim:
            result_size = result.size()[:-self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[:-self.event_dim]
        return result.expand(shape)
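A sketch (not part of this module) of what `event_dim` does to the Jacobian term: with `event_dim=1` the per-coordinate contributions are summed over the last dimension, so the result has one fewer dimension:

import torch
from torch.distributions.transforms import AffineTransform

loc = torch.zeros(3)
scale = torch.Tensor([1.0, 2.0, 4.0])
x = torch.randn(10, 3)

elementwise = AffineTransform(loc, scale)             # event_dim=0
vectorwise = AffineTransform(loc, scale, event_dim=1)

y = elementwise(x)
print(elementwise.log_abs_det_jacobian(x, y).shape)   # torch.Size([10, 3])
print(vectorwise.log_abs_det_jacobian(x, y).shape)    # torch.Size([10])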
class BoltzmannTransform(Transform):
    """
    Transform from unconstrained space to the simplex via `y = exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, BoltzmannTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        probs /= probs.sum(-1, True)
        return probs

    def _inverse(self, y):
        probs = y
        return probs.log()
class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """
    domain = constraints.real
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        shape = x.shape[:-1] + (1 + x.shape[-1],)
        one = x.new([1]).expand(x.shape[:-1] + (1,))
        numer = sigmoid(x)
        denom = (1 - numer).cumprod(-1)
        probs = torch.cat([numer, one], -1) * torch.cat([one, denom], -1)
        return probs

    def _inverse(self, y):
        pmf = y
        cmf = pmf.cumsum(-1)
        sf = 1 - cmf
        units = y[..., :-1] / sf[..., :-1]
        return units.log()

    # TODO implement .log_abs_det_jacobian()
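A sketch (not part of this module): the map sends K unconstrained coordinates to a point on the (K+1)-element simplex, and the inverse recovers the logits:

import torch
from torch.distributions.transforms import StickBreakingTransform

t = StickBreakingTransform()
x = torch.randn(4, 3)          # unconstrained
y = t(x)                       # shape (4, 4); each row sums to 1

assert y.shape == (4, 4)
assert torch.allclose(y.sum(-1), torch.ones(4))
assert torch.allclose(t.inv(y), x, atol=1e-4)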
class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """
    domain = constraints.real
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call(self, x):
        if x.dim() != 2:
            raise NotImplementedError
        return x.tril(-1) + x.diag().exp().diag()

    def _inverse(self, y):
        if y.dim() != 2:
            raise NotImplementedError
        return y.tril(-1) + y.diag().log().diag()
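A sketch (not part of this module): mapping an arbitrary square matrix to a valid Cholesky factor, which can then parameterize a positive definite matrix:

import torch
from torch.distributions import constraints
from torch.distributions.transforms import LowerCholeskyTransform

t = LowerCholeskyTransform()
raw = torch.randn(4, 4)       # unconstrained parameters
L = t(raw)                    # lower triangular, diagonal passed through exp()
assert bool(constraints.lower_cholesky.check(L))

cov = torch.mm(L, L.t())      # a positive definite covariance matrix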