Implement Student's t-distribution (#4510)

Alican Bozkurt
2018-01-08 04:23:48 -05:00
committed by Adam Paszke
parent 5c641cc14f
commit c9bc6c2bc3
4 changed files with 157 additions and 8 deletions
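
The new class follows the existing torch.distributions API (rsample,
log_prob, entropy). A minimal usage sketch against this revision (values
are random; no real session output is shown)::

    import torch
    from torch.distributions import StudentT

    m = StudentT(torch.Tensor([2.0]))  # df=2, with loc=0. and scale=1. by default
    s = m.rsample()                    # reparameterized draw (has_rsample=True)
    lp = m.log_prob(s)                 # closed-form log-density at s
    h = m.entropy()                    # closed-form differential entropy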

docs/source/distributions.rst

@@ -85,6 +85,12 @@ Probability distributions - torch.distributions
 .. autoclass:: Pareto
     :members:
 
+:hidden:`StudentT`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: StudentT
+    :members:
+
 :hidden:`Uniform`
 ~~~~~~~~~~~~~~~~~~~~~~~

test/test_distributions.py

@@ -32,11 +32,10 @@ from common import TestCase, run_tests, set_rng_seed
 from torch.autograd import Variable, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
                                  Dirichlet, Exponential, Gamma, Laplace,
-                                 Normal, OneHotCategorical, Pareto, Uniform)
+                                 Normal, OneHotCategorical, Pareto, StudentT, Uniform)
 from torch.distributions.constraints import Constraint, is_dependent
 from torch.distributions.utils import _get_clamping_buffer
 
 TEST_NUMPY = True
 try:
     import numpy as np
@@ -75,12 +74,12 @@ EXAMPLES = [
          'scale': Variable(torch.Tensor([[1.0], [1.0]]))}
     ]),
     Example(Chi2, [
-        {
-            'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
-        },
-        {
-            'df': Variable(torch.exp(torch.randn(1)), requires_grad=True),
-        },
+        {'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
+        {'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
+    ]),
+    Example(StudentT, [
+        {'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
+        {'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
     ]),
     Example(Dirichlet, [
         {'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
@@ -673,6 +672,46 @@ class TestDistributions(TestCase):
                                      'rel error {}'.format(rel_error),
                                      'max error {}'.format(rel_error.max())]))
 
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_studentT_shape(self):
+        df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+        df_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
+        self.assertEqual(StudentT(df).sample().size(), (2, 3))
+        self.assertEqual(StudentT(df).sample_n(5).size(), (5, 2, 3))
+        self.assertEqual(StudentT(df_1d).sample_n(1).size(), (1, 1))
+        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
+        self.assertEqual(StudentT(0.5).sample().size(), (1,))
+        self.assertEqual(StudentT(0.5).sample_n(1).size(), (1,))
+
+        def ref_log_prob(idx, x, log_prob):
+            d = df.data.view(-1)[idx]
+            expected = scipy.stats.t.logpdf(x, d)
+            self.assertAlmostEqual(log_prob, expected, places=3)
+
+        self._check_log_prob(StudentT(df), ref_log_prob)
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_studentT_sample(self):
+        set_rng_seed(11)  # see Note [Randomized statistical tests]
+        for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
+            self._check_sampler_sampler(StudentT(df=df, loc=loc, scale=scale),
+                                        scipy.stats.t(df=df, loc=loc, scale=scale),
+                                        'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale))
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_studentT_log_prob(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        num_samples = 10
+        for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
+            dist = StudentT(df=df, loc=loc, scale=scale)
+            x = dist.sample((num_samples,))
+            actual_log_prob = dist.log_prob(x)
+            for i in range(num_samples):
+                expected_log_prob = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
+                self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
+
+    # TODO: add test_studentT_sample_grad once standard_t_grad() is implemented
+
     def test_dirichlet_shape(self):
         alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
         alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
@@ -806,6 +845,18 @@ class TestDistributions(TestCase):
              (1, 2)),
            (Laplace(loc=torch.Tensor([0]), scale=torch.Tensor([[1]])),
             (1, 1)),
+           (StudentT(df=torch.Tensor([1, 1]), loc=1),
+            (2,)),
+           (StudentT(df=1, scale=torch.Tensor([1, 1])),
+            (2,)),
+           (StudentT(df=torch.Tensor([1, 1]), loc=torch.Tensor([1])),
+            (2,)),
+           (StudentT(df=torch.Tensor([1, 1]), scale=torch.Tensor([[1], [1]])),
+            (2, 2)),
+           (StudentT(df=torch.Tensor([1, 1]), loc=torch.Tensor([[1]])),
+            (1, 2)),
+           (StudentT(df=torch.Tensor([1]), scale=torch.Tensor([[1]])),
+            (1, 1)),
        ]
 
        for dist, expected_size in valid_examples:
@@ -832,6 +883,14 @@ class TestDistributions(TestCase):
            (Laplace, {
                'loc': torch.Tensor([0, 0]),
                'scale': torch.Tensor([1, 1, 1])
            }),
+           (StudentT, {
+               'df': torch.Tensor([1, 1]),
+               'scale': torch.Tensor([1, 1, 1])
+           }),
+           (StudentT, {
+               'df': torch.Tensor([1, 1]),
+               'loc': torch.Tensor([1, 1, 1])
+           })
        ]
@@ -982,6 +1041,25 @@ class TestDistributionShapes(TestCase):
         self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
 
+    def test_studentT_shape_scalar_params(self):
+        st = StudentT(1)
+        self.assertEqual(st._batch_shape, torch.Size())
+        self.assertEqual(st._event_shape, torch.Size())
+        self.assertEqual(st.sample().size(), torch.Size((1,)))
+        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
+        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
+    def test_studentT_shape_tensor_params(self):
+        st = StudentT(torch.Tensor([1, 1]))
+        self.assertEqual(st._batch_shape, torch.Size((2,)))
+        self.assertEqual(st._event_shape, torch.Size(()))
+        self.assertEqual(st.sample().size(), torch.Size((2,)))
+        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
+        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
+
     def test_pareto_shape_scalar_params(self):
         pareto = Pareto(1, 1)
         self.assertEqual(pareto._batch_shape, torch.Size())

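The log_prob tests above boil down to a pointwise comparison with
scipy.stats.t. The same check written standalone (a sketch mirroring the
test logic; df, loc and scale here are arbitrary picks from the tested
grid)::

    import scipy.stats
    import torch
    from torch.distributions import StudentT

    df, loc, scale = 5.0, -1.0, 2.0
    dist = StudentT(df=df, loc=loc, scale=scale)
    x = dist.sample((10,))
    actual = dist.log_prob(x)
    for i in range(10):
        # agreement to ~3 decimal places, as in assertAlmostEqual(places=3)
        expected = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
        assert abs(actual[i] - expected) < 1e-3
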
torch/distributions/__init__.py

@@ -43,6 +43,7 @@ from .laplace import Laplace
 from .normal import Normal
 from .one_hot_categorical import OneHotCategorical
 from .pareto import Pareto
+from .studentT import StudentT
 from .uniform import Uniform
 
 __all__ = [
@@ -59,5 +60,6 @@ __all__ = [
     'Normal',
     'OneHotCategorical',
     'Pareto',
+    'StudentT',
     'Uniform',
 ]

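A quick way to confirm the re-export wires up as intended (a sketch; uses
only the bindings added above)::

    import torch.distributions as D

    assert 'StudentT' in D.__all__            # listed in the public API
    assert D.StudentT is D.studentT.StudentT  # same object as the submodule's
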
torch/distributions/studentT.py

@@ -0,0 +1,63 @@
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import Chi2, constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+
+class StudentT(Distribution):
+    r"""
+    Creates a Student's t-distribution parameterized by degrees of freedom
+    `df`, location `loc` and scale `scale`.
+
+    Example::
+
+        >>> m = StudentT(torch.Tensor([2.0]))
+        >>> m.sample()  # Student's t-distributed with degrees of freedom=2
+         0.1046
+        [torch.FloatTensor of size 1]
+
+    Args:
+        df (float or Tensor or Variable): degrees of freedom
+        loc (float or Tensor or Variable): mean of the distribution
+        scale (float or Tensor or Variable): scale of the distribution
+    """
+    params = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
+    support = constraints.real
+    has_rsample = True
+
+    def __init__(self, df, loc=0., scale=1.):
+        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
+        self._chi2 = Chi2(df)
+        batch_shape = torch.Size() if isinstance(df, Number) else self.df.size()
+        super(StudentT, self).__init__(batch_shape)
+
+    def rsample(self, sample_shape=torch.Size()):
+        # NOTE: This agrees with the scipy implementation less closely than
+        # other distributions do (see
+        # https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb).
+        # Using DoubleTensor parameters seems to help.
+
+        #   X ~ Normal(0, 1)
+        #   Z ~ Chi2(df)
+        #   Y = X / sqrt(Z / df) ~ StudentT(df)
+        shape = self._extended_shape(sample_shape)
+        X = self.df.new(*shape).normal_()
+        Z = self._chi2.rsample(sample_shape)
+        Y = X * torch.rsqrt(Z / self.df)
+        return self.loc + self.scale * Y
+
+    def log_prob(self, value):
+        self._validate_log_prob_arg(value)
+        y = (value - self.loc) / self.scale
+        Z = (self.scale.log() +
+             0.5 * self.df.log() +
+             0.5 * math.log(math.pi) +
+             torch.lgamma(0.5 * self.df) -
+             torch.lgamma(0.5 * (self.df + 1.)))
+        return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z
+
+    def entropy(self):
+        lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
+        return (self.scale.log() +
+                0.5 * (self.df + 1) *
+                (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
+                0.5 * self.df.log() + lbeta)
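
For reference, the sampling identity rsample relies on, checked standalone
with stock torch ops (a sketch; the df value and sample count are arbitrary,
and the chi-square draw uses summed squared normals rather than the Chi2
class)::

    import torch

    # If X ~ Normal(0, 1) and Z ~ Chi2(df) independently,
    # then X / sqrt(Z / df) ~ StudentT(df).
    df, n = 4.0, 100000
    X = torch.randn(n)
    Z = torch.randn(n, int(df)).pow(2).sum(1)  # Chi2(4) as a sum of squares
    T = X * torch.rsqrt(Z / df)

    # Sanity check: Var[T] = df / (df - 2) = 2.0 for df=4, so the empirical
    # variance of the draws should land near 2 (law of large numbers, not exact).
    print(T.var())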