Implement Multinomial distribution (#4624)

Alican Bozkurt
2018-01-13 05:26:14 -05:00
committed by Adam Paszke
parent 8eded5aece
commit 9b6441ecbc
5 changed files with 200 additions and 13 deletions

View File

@@ -31,8 +31,8 @@ import torch
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
Dirichlet, Exponential, Gamma, Gumbel,
Laplace, Normal, OneHotCategorical, Pareto,
Dirichlet, Exponential, Gamma, Gumbel, Laplace,
Normal, OneHotCategorical, Multinomial, Pareto,
StudentT, Uniform, kl_divergence)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.constraints import Constraint, is_dependent
@@ -69,6 +69,10 @@ EXAMPLES = [
{'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
{'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
]),
Example(Multinomial, [
{'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
{'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
]),
Example(Cauchy, [
{'loc': 0.0, 'scale': 1.0},
{'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
@@ -294,6 +298,53 @@ class TestDistributions(TestCase):
(2, 5, 2, 3, 5))
self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
def test_multinomial_1d(self):
total_count = 10
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_multinomial_1d_log_prob(self):
total_count = 10
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
dist = Multinomial(total_count, probs=p)
x = dist.sample()
log_prob = dist.log_prob(x)
expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
self.assertEqual(log_prob.data, expected)
dist = Multinomial(total_count, logits=p.log())
x = dist.sample()
log_prob = dist.log_prob(x)
expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
self.assertEqual(log_prob.data, expected)
def test_multinomial_2d(self):
total_count = 10
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
p = Variable(torch.Tensor(probabilities), requires_grad=True)
s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
set_rng_seed(0)
self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
p.grad.zero_()
self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
# sample check for extreme value of probs
self.assertEqual(Multinomial(total_count, s).sample().data,
torch.Tensor([[total_count, 0], [0, total_count]]))
# check entropy computation
self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
def test_categorical_1d(self):
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
# TODO: this should return a 0-dim tensor once we have Scalar support
@@ -1096,13 +1147,16 @@ class TestDistributionShapes(TestCase):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
actual_shape = dist.entropy().size()
expected_shape = dist._batch_shape
if not expected_shape:
expected_shape = torch.Size((1,)) # TODO Remove this once scalars are supported.
message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
Dist.__name__, i, len(params), expected_shape, actual_shape)
self.assertEqual(actual_shape, expected_shape, message=message)
try:
actual_shape = dist.entropy().size()
expected_shape = dist._batch_shape
if not expected_shape:
expected_shape = torch.Size((1,)) # TODO Remove this once scalars are supported.
message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
Dist.__name__, i, len(params), expected_shape, actual_shape)
self.assertEqual(actual_shape, expected_shape, message=message)
except NotImplementedError:
continue
def test_bernoulli_shape_scalar_params(self):
bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@ class TestDistributionShapes(TestCase):
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
def test_multinomial_shape(self):
dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
def test_categorical_shape(self):
dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -1375,11 +1439,14 @@ class TestConstraints(TestCase):
for name, value in param.items():
if not (torch.is_tensor(value) or isinstance(value, Variable)):
value = torch.Tensor([value])
if Dist in (Categorical, OneHotCategorical) and name == 'probs':
if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
# These distributions accept positive probs, but elsewhere we
# use a stricter constraint to the simplex.
value = value / value.sum(-1, True)
constraint = dist.params[name]
try:
constraint = dist.params[name]
except KeyError:
continue # ignore optional parameters
if is_dependent(constraint):
continue
message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@ class TestNumericalStability(TestCase):
log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
def test_multinomial_log_prob(self):
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
p = Variable(tensor_type([0, 1]), requires_grad=True)
s = Variable(tensor_type([0, 10]))
multinomial = Multinomial(10, p)
log_pdf = multinomial.log_prob(s)
self.assertEqual(log_pdf.data[0], 0)
def test_multinomial_log_prob_with_logits(self):
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
multinomial = Multinomial(10, logits=p)
log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
self.assertEqual(log_pdf_prob_1.data[0], 0)
log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
if __name__ == '__main__':
run_tests()
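The two multinomial numerical-stability tests above exercise an edge case handled explicitly in the new Multinomial.log_prob: a category with zero probability (logit of -inf) that also receives a zero count must contribute 0 to the log probability, but under IEEE floating point 0 * (-inf) evaluates to nan. This is why log_prob masks those logits before summing (see multinomial.py below). A minimal sketch of the issue, using only core torch operations:

import torch

logits = torch.Tensor([-float('inf'), 0.0])
counts = torch.Tensor([0.0, 10.0])
print(counts * logits)  # [nan, 0]: 0 * -inf is nan, not the 0 we want
masked = logits.clone()
masked[(counts == 0) & (masked == -float('inf'))] = 0  # same masking as in log_prob
print((counts * masked).sum())  # 0: the impossible, unobserved category contributes nothing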

View File

@@ -42,6 +42,7 @@ from .gamma import Gamma
from .gumbel import Gumbel
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .multinomial import Multinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
@@ -60,6 +61,7 @@ __all__ = [
'Gamma',
'Gumbel',
'Laplace',
'Multinomial',
'Normal',
'OneHotCategorical',
'Pareto',

View File

@@ -7,10 +7,12 @@ from torch.distributions.utils import probs_to_logits, logits_to_probs, log_sum_
class Categorical(Distribution):
r"""
Creates a categorical distribution parameterized by `probs`.
Creates a categorical distribution parameterized by either `probs` or
`logits` (but not both).
.. note::
It is equivalent to the distribution that ``multinomial()`` samples from.
It is equivalent to the distribution that :func:`torch.multinomial`
samples from.
Samples are integers from `0 ... K-1` where `K` is probs.size(-1).
@@ -30,6 +32,7 @@ class Categorical(Distribution):
Args:
probs (Tensor or Variable): event probabilities
logits (Tensor or Variable): event log probabilities
"""
params = {'probs': constraints.simplex}
has_enumerate_support = True
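To make the reworded docstring concrete: probs and logits are two parameterizations of the same distribution, with logits read as unnormalized log probabilities. A minimal sketch (not part of this diff) showing the two forms agreeing:

import torch
from torch.distributions import Categorical

p = torch.Tensor([0.25, 0.75])
d_probs = Categorical(probs=p)
d_logits = Categorical(logits=p.log())  # same distribution, specified via log probabilities
print(d_probs.log_prob(torch.Tensor([1])))   # log(0.75)
print(d_logits.log_prob(torch.Tensor([1])))  # same value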

View File

@@ -10,6 +10,7 @@ __all__ = [
'integer_interval',
'interval',
'is_dependent',
'less_than',
'lower_triangular',
'nonnegative_integer',
'positive',
@@ -112,6 +113,17 @@ class _GreaterThan(Constraint):
return self.lower_bound <= value
class _LessThan(Constraint):
"""
Constrain to a real half line `(-inf, upper_bound]`.
"""
def __init__(self, upper_bound):
self.upper_bound = upper_bound
def check(self, value):
return value <= self.upper_bound
class _Interval(Constraint):
"""
Constrain to a real interval `[lower_bound, upper_bound]`.
@@ -150,6 +162,7 @@ integer_interval = _IntegerInterval
real = _Real()
positive = _GreaterThan(0)
greater_than = _GreaterThan
less_than = _LessThan
unit_interval = _Interval(0, 1)
interval = _Interval
simplex = _Simplex()
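The new less_than factory mirrors the existing greater_than: less_than(upper_bound) builds a constraint whose check method tests value <= upper_bound elementwise. A small usage sketch (not part of the diff):

import torch
from torch.distributions import constraints

c = constraints.less_than(1.0)
print(c.check(torch.Tensor([0.5, 2.0])))  # elementwise result: [1, 0], since 0.5 <= 1.0 and 2.0 > 1.0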

View File

@@ -0,0 +1,85 @@
import torch
from torch.distributions.distribution import Distribution
from torch.autograd import Variable
from torch.distributions import Categorical
from numbers import Number
from torch.distributions import constraints
from torch.distributions.utils import log_sum_exp, broadcast_all
class Multinomial(Distribution):
r"""
Creates a Multinomial distribution parameterized by `total_count` and
either `probs` or `logits` (but not both). The innermost dimension of
`probs` indexes over categories. All other dimensions index over batches.
Note that `total_count` need not be specified if only :meth:`log_prob` is
called (see example below).
- :meth:`sample` requires a single shared `total_count` for all
parameters and samples.
- :meth:`log_prob` allows different `total_count` for each parameter and
sample.
Example::
>>> m = Multinomial(100, torch.Tensor([ 1, 1, 1, 1]))
>>> x = m.sample() # equal probability of 0, 1, 2, 3
21
24
30
25
[torch.FloatTensor of size 4]
>>> Multinomial(probs=torch.Tensor([1, 1, 1, 1])).log_prob(x)
-4.1338
[torch.FloatTensor of size 1]
Args:
total_count (int): number of trials
probs (Tensor or Variable): event probabilities
logits (Tensor or Variable): event log probabilities
"""
params = {'logits': constraints.real} # Let logits be the canonical parameterization.
def __init__(self, total_count=1, probs=None, logits=None):
if not isinstance(total_count, Number):
raise NotImplementedError('inhomogeneous total_count is not supported')
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
batch_shape = probs.size()[:-1] if probs is not None else logits.size()[:-1]
event_shape = probs.size()[-1:] if probs is not None else logits.size()[-1:]
super(Multinomial, self).__init__(batch_shape, event_shape)
@constraints.dependent_property
def support(self):
return constraints.integer_interval(0, self.total_count)
@property
def logits(self):
return self._categorical.logits
@property
def probs(self):
return self._categorical.probs
def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
# samples.shape is (total_count, sample_shape, batch_shape), need to change it to
# (sample_shape, batch_shape, total_count)
shifted_idx = list(range(samples.dim()))
shifted_idx.append(shifted_idx.pop(0))
samples = samples.permute(*shifted_idx)
counts = samples.new(self._extended_shape(sample_shape)).zero_()
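# scatter_add_ bumps the count at the sampled category index for each of the total_count draws,
# turning the raw categorical draws into per-category counts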
counts.scatter_add_(-1, samples, torch.ones_like(samples))
return counts.type_as(self.probs)
def log_prob(self, value):
self._validate_log_prob_arg(value)
logits, value = broadcast_all(self.logits.clone(), value)
log_factorial_n = torch.lgamma(value.sum(-1) + 1)
log_factorial_xs = torch.lgamma(value + 1).sum(-1)
logits[(value == 0) & (logits == -float('inf'))] = 0
log_powers = (logits * value).sum(-1)
return log_factorial_n - log_factorial_xs + log_powers
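For reference, log_prob evaluates the standard multinomial log probability mass function, log n! - sum_i log x_i! + sum_i x_i log p_i, with the cloned and masked logits supplying the log p_i term. A short sketch, assuming the module added in this commit, that checks the result against the formula computed by hand:

import torch
from torch.autograd import Variable
from torch.distributions import Multinomial

p = Variable(torch.Tensor([0.2, 0.3, 0.5]))
x = Variable(torch.Tensor([2.0, 3.0, 5.0]))  # counts summing to total_count = 10
dist = Multinomial(10, probs=p)

# log n! - sum_i log x_i! + sum_i x_i * log p_i, computed directly
expected = torch.lgamma(x.sum() + 1) - torch.lgamma(x + 1).sum() + (x * p.log()).sum()
print(dist.log_prob(x))  # should agree with `expected` up to floating point error
print(expected)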