mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Implement Student's t-distribution (#4510)
This commit is contained in:
committed by
Adam Paszke
parent
5c641cc14f
commit
c9bc6c2bc3
@ -85,6 +85,12 @@ Probability distributions - torch.distributions
|
||||
.. autoclass:: Pareto
|
||||
:members:
|
||||
|
||||
:hidden:`StudentT`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: StudentT
|
||||
:members:
|
||||
|
||||
:hidden:`Uniform`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -32,11 +32,10 @@ from common import TestCase, run_tests, set_rng_seed
|
||||
from torch.autograd import Variable, gradcheck
|
||||
from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
|
||||
Dirichlet, Exponential, Gamma, Laplace,
|
||||
Normal, OneHotCategorical, Pareto, Uniform)
|
||||
Normal, OneHotCategorical, Pareto, StudentT, Uniform)
|
||||
from torch.distributions.constraints import Constraint, is_dependent
|
||||
from torch.distributions.utils import _get_clamping_buffer
|
||||
|
||||
|
||||
TEST_NUMPY = True
|
||||
try:
|
||||
import numpy as np
|
||||
@ -75,12 +74,12 @@ EXAMPLES = [
|
||||
'scale': Variable(torch.Tensor([[1.0], [1.0]]))}
|
||||
]),
|
||||
Example(Chi2, [
|
||||
{
|
||||
'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
|
||||
},
|
||||
{
|
||||
'df': Variable(torch.exp(torch.randn(1)), requires_grad=True),
|
||||
},
|
||||
{'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
|
||||
{'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
|
||||
]),
|
||||
Example(StudentT, [
|
||||
{'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
|
||||
{'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
|
||||
]),
|
||||
Example(Dirichlet, [
|
||||
{'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
|
||||
@ -673,6 +672,46 @@ class TestDistributions(TestCase):
|
||||
'rel error {}'.format(rel_error),
|
||||
'max error {}'.format(rel_error.max())]))
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
def test_studentT_shape(self):
|
||||
df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
|
||||
df_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
|
||||
self.assertEqual(StudentT(df).sample().size(), (2, 3))
|
||||
self.assertEqual(StudentT(df).sample_n(5).size(), (5, 2, 3))
|
||||
self.assertEqual(StudentT(df_1d).sample_n(1).size(), (1, 1))
|
||||
self.assertEqual(StudentT(df_1d).sample().size(), (1,))
|
||||
self.assertEqual(StudentT(0.5).sample().size(), (1,))
|
||||
self.assertEqual(StudentT(0.5).sample_n(1).size(), (1,))
|
||||
|
||||
def ref_log_prob(idx, x, log_prob):
|
||||
d = df.data.view(-1)[idx]
|
||||
expected = scipy.stats.t.logpdf(x, d)
|
||||
self.assertAlmostEqual(log_prob, expected, places=3)
|
||||
|
||||
self._check_log_prob(StudentT(df), ref_log_prob)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
def test_studentT_sample(self):
|
||||
set_rng_seed(11) # see Note [Randomized statistical tests]
|
||||
for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
|
||||
self._check_sampler_sampler(StudentT(df=df, loc=loc, scale=scale),
|
||||
scipy.stats.t(df=df, loc=loc, scale=scale),
|
||||
'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale))
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
def test_studentT_log_prob(self):
|
||||
set_rng_seed(0) # see Note [Randomized statistical tests]
|
||||
num_samples = 10
|
||||
for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
|
||||
dist = StudentT(df=df, loc=loc, scale=scale)
|
||||
x = dist.sample((num_samples,))
|
||||
actual_log_prob = dist.log_prob(x)
|
||||
for i in range(num_samples):
|
||||
expected_log_prob = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
|
||||
self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
|
||||
|
||||
# TODO: add test_studentT_sample_grad once standard_t_grad() is implemented
|
||||
|
||||
def test_dirichlet_shape(self):
|
||||
alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
|
||||
alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
|
||||
@ -806,6 +845,18 @@ class TestDistributions(TestCase):
|
||||
(1, 2)),
|
||||
(Laplace(loc=torch.Tensor([0]), scale=torch.Tensor([[1]])),
|
||||
(1, 1)),
|
||||
(StudentT(df=torch.Tensor([1, 1]), loc=1),
|
||||
(2,)),
|
||||
(StudentT(df=1, scale=torch.Tensor([1, 1])),
|
||||
(2,)),
|
||||
(StudentT(df=torch.Tensor([1, 1]), loc=torch.Tensor([1])),
|
||||
(2,)),
|
||||
(StudentT(df=torch.Tensor([1, 1]), scale=torch.Tensor([[1], [1]])),
|
||||
(2, 2)),
|
||||
(StudentT(df=torch.Tensor([1, 1]), loc=torch.Tensor([[1]])),
|
||||
(1, 2)),
|
||||
(StudentT(df=torch.Tensor([1]), scale=torch.Tensor([[1]])),
|
||||
(1, 1)),
|
||||
]
|
||||
|
||||
for dist, expected_size in valid_examples:
|
||||
@ -832,6 +883,14 @@ class TestDistributions(TestCase):
|
||||
(Laplace, {
|
||||
'loc': torch.Tensor([0, 0]),
|
||||
'scale': torch.Tensor([1, 1, 1])
|
||||
}),
|
||||
(StudentT, {
|
||||
'df': torch.Tensor([1, 1]),
|
||||
'scale': torch.Tensor([1, 1, 1])
|
||||
}),
|
||||
(StudentT, {
|
||||
'df': torch.Tensor([1, 1]),
|
||||
'loc': torch.Tensor([1, 1, 1])
|
||||
})
|
||||
]
|
||||
|
||||
@ -982,6 +1041,25 @@ class TestDistributionShapes(TestCase):
|
||||
self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
|
||||
self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
|
||||
|
||||
def test_studentT_shape_scalar_params(self):
|
||||
st = StudentT(1)
|
||||
self.assertEqual(st._batch_shape, torch.Size())
|
||||
self.assertEqual(st._event_shape, torch.Size())
|
||||
self.assertEqual(st.sample().size(), torch.Size((1,)))
|
||||
self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
|
||||
self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
|
||||
self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
|
||||
self.assertEqual(st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
|
||||
|
||||
def test_studentT_shape_tensor_params(self):
|
||||
st = StudentT(torch.Tensor([1, 1]))
|
||||
self.assertEqual(st._batch_shape, torch.Size((2,)))
|
||||
self.assertEqual(st._event_shape, torch.Size(()))
|
||||
self.assertEqual(st.sample().size(), torch.Size((2,)))
|
||||
self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
|
||||
self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
|
||||
self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
|
||||
|
||||
def test_pareto_shape_scalar_params(self):
|
||||
pareto = Pareto(1, 1)
|
||||
self.assertEqual(pareto._batch_shape, torch.Size())
|
||||
|
@ -43,6 +43,7 @@ from .laplace import Laplace
|
||||
from .normal import Normal
|
||||
from .one_hot_categorical import OneHotCategorical
|
||||
from .pareto import Pareto
|
||||
from .studentT import StudentT
|
||||
from .uniform import Uniform
|
||||
|
||||
__all__ = [
|
||||
@ -59,5 +60,6 @@ __all__ = [
|
||||
'Normal',
|
||||
'OneHotCategorical',
|
||||
'Pareto',
|
||||
'StudentT',
|
||||
'Uniform',
|
||||
]
|
||||
|
63
torch/distributions/studentT.py
Normal file
63
torch/distributions/studentT.py
Normal file
@ -0,0 +1,63 @@
|
||||
from numbers import Number
|
||||
import torch
|
||||
import math
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions import Chi2
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
||||
class StudentT(Distribution):
|
||||
r"""
|
||||
Creates a Student's t-distribution parameterized by `df`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> m = StudentT(torch.Tensor([2.0]))
|
||||
>>> m.sample() # Student's t-distributed with degrees of freedom=2
|
||||
0.1046
|
||||
[torch.FloatTensor of size 1]
|
||||
|
||||
Args:
|
||||
df (float or Tensor or Variable): degrees of freedom
|
||||
"""
|
||||
params = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, df, loc=0., scale=1.):
|
||||
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
|
||||
self._chi2 = Chi2(df)
|
||||
batch_shape = torch.Size() if isinstance(df, Number) else self.df.size()
|
||||
super(StudentT, self).__init__(batch_shape)
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()):
|
||||
# NOTE: This does not agree with scipy implementation as much as other distributions.
|
||||
# (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
|
||||
# parameters seems to help.
|
||||
|
||||
# X ~ Normal(0, 1)
|
||||
# Z ~ Chi2(df)
|
||||
# Y = X / sqrt(Z / df) ~ StudentT(df)
|
||||
shape = self._extended_shape(sample_shape)
|
||||
X = self.df.new(*shape).normal_()
|
||||
Z = self._chi2.rsample(sample_shape)
|
||||
Y = X * torch.rsqrt(Z / self.df)
|
||||
return self.loc + self.scale * Y
|
||||
|
||||
def log_prob(self, value):
|
||||
self._validate_log_prob_arg(value)
|
||||
y = (value - self.loc) / self.scale
|
||||
Z = (self.scale.log() +
|
||||
0.5 * self.df.log() +
|
||||
0.5 * math.log(math.pi) +
|
||||
torch.lgamma(0.5 * self.df) -
|
||||
torch.lgamma(0.5 * (self.df + 1.)))
|
||||
return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z
|
||||
|
||||
def entropy(self):
|
||||
lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
|
||||
return (self.scale.log() +
|
||||
0.5 * (self.df + 1) *
|
||||
(torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
|
||||
0.5 * self.df.log() + lbeta)
|
Reference in New Issue
Block a user