mirror of https://github.com/pytorch/pytorch.git

Implement Multinomial distribution (#4624)

committed by Adam Paszke
parent 8eded5aece
commit 9b6441ecbc
test/test_distributions.py
@@ -31,8 +31,8 @@ import torch
 from common import TestCase, run_tests, set_rng_seed
 from torch.autograd import Variable, grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, Gamma, Gumbel,
-                                 Laplace, Normal, OneHotCategorical, Pareto,
+                                 Dirichlet, Exponential, Gamma, Gumbel, Laplace,
+                                 Normal, OneHotCategorical, Multinomial, Pareto,
                                  StudentT, Uniform, kl_divergence)
 from torch.distributions.dirichlet import _Dirichlet_backward
 from torch.distributions.constraints import Constraint, is_dependent
@@ -69,6 +69,10 @@ EXAMPLES = [
         {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
         {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
     ]),
+    Example(Multinomial, [
+        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
+        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+    ]),
     Example(Cauchy, [
         {'loc': 0.0, 'scale': 1.0},
         {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
@@ -294,6 +298,53 @@ class TestDistributions(TestCase):
                          (2, 5, 2, 3, 5))
         self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
 
+    def test_multinomial_1d(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
+        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_multinomial_1d_log_prob(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        dist = Multinomial(total_count, probs=p)
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+        dist = Multinomial(total_count, logits=p.log())
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+    def test_multinomial_2d(self):
+        total_count = 10
+        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
+        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
+        p = Variable(torch.Tensor(probabilities), requires_grad=True)
+        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
+        set_rng_seed(0)
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        p.grad.zero_()
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+
+        # sample check for extreme value of probs
+        self.assertEqual(Multinomial(total_count, s).sample().data,
+                         torch.Tensor([[total_count, 0], [0, total_count]]))
+
+        # check entropy computation
+        self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
+
     def test_categorical_1d(self):
         p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
         # TODO: this should return a 0-dim tensor once we have Scalar support
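Taken together, these tests pin down the contract of the new distribution: for a length-`K` `probs` vector, `sample()` returns a length-`K` vector of per-category counts that sum to `total_count`, `log_prob` is differentiable with respect to `probs` (which is what `_gradcheck_log_prob` verifies), and reparameterized sampling is unsupported, so `rsample` raises `NotImplementedError`. A minimal sketch of that contract in the 0.3-era `Variable` style used throughout this diff; the concrete counts are illustrative, not deterministic:

    import torch
    from torch.autograd import Variable
    from torch.distributions import Multinomial

    p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
    m = Multinomial(10, p)            # probs are normalized internally
    counts = m.sample()               # e.g. counts.data holds [2, 3, 5]
    assert counts.size() == (3,)      # one slot per category
    assert counts.data.sum() == 10    # counts always total total_count
    m.log_prob(counts).backward()     # gradients flow to p through log_prob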
@@ -1096,13 +1147,16 @@ class TestDistributionShapes(TestCase):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                actual_shape = dist.entropy().size()
-                expected_shape = dist._batch_shape
-                if not expected_shape:
-                    expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
-                message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
-                    Dist.__name__, i, len(params), expected_shape, actual_shape)
-                self.assertEqual(actual_shape, expected_shape, message=message)
+                try:
+                    actual_shape = dist.entropy().size()
+                    expected_shape = dist._batch_shape
+                    if not expected_shape:
+                        expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
+                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
+                        Dist.__name__, i, len(params), expected_shape, actual_shape)
+                    self.assertEqual(actual_shape, expected_shape, message=message)
+                except NotImplementedError:
+                    continue
 
     def test_bernoulli_shape_scalar_params(self):
         bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@ class TestDistributionShapes(TestCase):
         self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
         self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
+    def test_multinomial_shape(self):
+        dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(dist._batch_shape, torch.Size((3,)))
+        self.assertEqual(dist._event_shape, torch.Size((2,)))
+        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
+        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
+
     def test_categorical_shape(self):
         dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
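The `log_prob` shape assertions follow from how `Multinomial` splits `probs` into batch and event dimensions: the trailing category dimension is the event and is summed out, and whatever remains of the value's shape broadcasts against the batch shape. A worked reading of the last assertion above, assuming NumPy-style broadcasting:

    import torch
    from torch.distributions import Multinomial

    dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
    # value (3, 1, 2): the trailing 2 matches event_shape (2,) and is consumed
    # by the sum over categories; the remaining (3, 1) broadcasts against
    # batch_shape (3,), and broadcast((3, 1), (3,)) == (3, 3).
    assert dist.log_prob(torch.ones(3, 1, 2)).size() == torch.Size((3, 3))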
@@ -1375,11 +1439,14 @@ class TestConstraints(TestCase):
             for name, value in param.items():
                 if not (torch.is_tensor(value) or isinstance(value, Variable)):
                     value = torch.Tensor([value])
-                if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
                     # These distributions accept positive probs, but elsewhere we
                     # use a stricter constraint to the simplex.
                     value = value / value.sum(-1, True)
-                constraint = dist.params[name]
+                try:
+                    constraint = dist.params[name]
+                except KeyError:
+                    continue  # ignore optional parameters
                 if is_dependent(constraint):
                     continue
                 message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@ class TestNumericalStability(TestCase):
             log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
             self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
 
+    def test_multinomial_log_prob(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([0, 1]), requires_grad=True)
+            s = Variable(tensor_type([0, 10]))
+            multinomial = Multinomial(10, p)
+            log_pdf = multinomial.log_prob(s)
+            self.assertEqual(log_pdf.data[0], 0)
+
+    def test_multinomial_log_prob_with_logits(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+            multinomial = Multinomial(10, logits=p)
+            log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+            self.assertEqual(log_pdf_prob_1.data[0], 0)
+            log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+            self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+
 
 if __name__ == '__main__':
     run_tests()
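Both stability tests exercise the convention 0 * log 0 = 0: a zero-probability category contributes nothing as long as its count is zero. Working the first test by hand, with p = (0, 1), n = 10 and x = (0, 10):

    \log P(x) = \log\frac{10!}{0!\,10!} + 0\cdot\log 0 + 10\cdot\log 1 = 0 + 0 + 0 = 0

The logits variant checks the complementary case as well: x = (10, 0) forces the term 10 * log 0, so the log-probability is correctly -inf. The implementation enforces the convention by zeroing logits exactly where `value == 0` and `logits == -inf` (see `log_prob` in multinomial.py below).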
torch/distributions/__init__.py
@@ -42,6 +42,7 @@ from .gamma import Gamma
 from .gumbel import Gumbel
 from .kl import kl_divergence, register_kl
 from .laplace import Laplace
+from .multinomial import Multinomial
 from .normal import Normal
 from .one_hot_categorical import OneHotCategorical
 from .pareto import Pareto
@@ -60,6 +61,7 @@ __all__ = [
     'Gamma',
     'Gumbel',
     'Laplace',
+    'Multinomial',
     'Normal',
     'OneHotCategorical',
     'Pareto',
torch/distributions/categorical.py
@@ -7,10 +7,12 @@ from torch.distributions.utils import probs_to_logits, logits_to_probs, log_sum_exp
 
 class Categorical(Distribution):
     r"""
-    Creates a categorical distribution parameterized by `probs`.
+    Creates a categorical distribution parameterized by either `probs` or
+    `logits` (but not both).
 
     .. note::
-        It is equivalent to the distribution that ``multinomial()`` samples from.
+        It is equivalent to the distribution that :func:`torch.multinomial`
+        samples from.
 
     Samples are integers from `0 ... K-1` where `K` is probs.size(-1).
 
@@ -30,6 +32,7 @@ class Categorical(Distribution):
 
     Args:
         probs (Tensor or Variable): event probabilities
+        logits (Tensor or Variable): event log probabilities
     """
     params = {'probs': constraints.simplex}
     has_enumerate_support = True
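The rewritten note states an equivalence that can be checked directly: `Categorical(probs)` and `torch.multinomial(probs, n, replacement=True)` draw indices from the same distribution, with unnormalized `probs` handled identically by both. A small illustrative sketch:

    import torch
    from torch.distributions import Categorical

    probs = torch.Tensor([0.1, 0.2, 0.3, 0.4])
    c = Categorical(probs).sample(torch.Size((1000,)))    # 1000 indices in 0..3
    m = torch.multinomial(probs, 1000, replacement=True)  # same distribution
    # Both empirical index frequencies converge to probs / probs.sum().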
torch/distributions/constraints.py
@@ -10,6 +10,7 @@ __all__ = [
     'integer_interval',
     'interval',
     'is_dependent',
+    'less_than',
     'lower_triangular',
     'nonnegative_integer',
     'positive',
@@ -112,6 +113,17 @@ class _GreaterThan(Constraint):
         return self.lower_bound <= value
 
 
+class _LessThan(Constraint):
+    """
+    Constrain to a real half line `[-inf, upper_bound]`.
+    """
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+
+    def check(self, value):
+        return value <= self.upper_bound
+
+
 class _Interval(Constraint):
     """
     Constrain to a real interval `[lower_bound, upper_bound]`.
@@ -150,6 +162,7 @@ integer_interval = _IntegerInterval
 real = _Real()
 positive = _GreaterThan(0)
 greater_than = _GreaterThan
+less_than = _LessThan
 unit_interval = _Interval(0, 1)
 interval = _Interval
 simplex = _Simplex()
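`less_than` mirrors the existing `greater_than` factory just above it: the class is private, the exported name is the constructor, and `check` is elementwise. A minimal usage sketch:

    import torch
    from torch.distributions import constraints

    c = constraints.less_than(1.0)
    ok = c.check(torch.Tensor([0.5, 2.0]))  # elementwise: [1, 0] (a 0.3-era ByteTensor)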
torch/distributions/multinomial.py (new file)
@@ -0,0 +1,85 @@
import torch
from torch.distributions.distribution import Distribution
from torch.autograd import Variable
from torch.distributions import Categorical
from numbers import Number
from torch.distributions import constraints
from torch.distributions.utils import log_sum_exp, broadcast_all


class Multinomial(Distribution):
    r"""
    Creates a Multinomial distribution parameterized by `total_count` and
    either `probs` or `logits` (but not both). The innermost dimension of
    `probs` indexes over categories. All other dimensions index over batches.

    Note that `total_count` need not be specified if only :meth:`log_prob` is
    called (see example below).

    -   :meth:`sample` requires a single shared `total_count` for all
        parameters and samples.
    -   :meth:`log_prob` allows different `total_count` for each parameter and
        sample.

    Example::

        >>> m = Multinomial(100, torch.Tensor([1, 1, 1, 1]))
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
         21
         24
         30
         25
        [torch.FloatTensor of size 4]

        >>> Multinomial(probs=torch.Tensor([1, 1, 1, 1])).log_prob(x)
        -4.1338
        [torch.FloatTensor of size 1]

    Args:
        total_count (int): number of trials
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): event log probabilities
    """
    params = {'logits': constraints.real}  # Let logits be the canonical parameterization.

    def __init__(self, total_count=1, probs=None, logits=None):
        if not isinstance(total_count, Number):
            raise NotImplementedError('inhomogeneous total_count is not supported')
        self.total_count = total_count
        self._categorical = Categorical(probs=probs, logits=logits)
        batch_shape = probs.size()[:-1] if probs is not None else logits.size()[:-1]
        event_shape = probs.size()[-1:] if probs is not None else logits.size()[-1:]
        super(Multinomial, self).__init__(batch_shape, event_shape)

    @constraints.dependent_property
    def support(self):
        return constraints.integer_interval(0, self.total_count)

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
        # (sample_shape, batch_shape, total_count)
        shifted_idx = list(range(samples.dim()))
        shifted_idx.append(shifted_idx.pop(0))
        samples = samples.permute(*shifted_idx)
        counts = samples.new(self._extended_shape(sample_shape)).zero_()
        counts.scatter_add_(-1, samples, torch.ones_like(samples))
        return counts.type_as(self.probs)

    def log_prob(self, value):
        self._validate_log_prob_arg(value)
        logits, value = broadcast_all(self.logits.clone(), value)
        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
        logits[(value == 0) & (logits == -float('inf'))] = 0
        log_powers = (logits * value).sum(-1)
        return log_factorial_n - log_factorial_xs + log_powers
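`sample` above reduces multinomial sampling to `total_count` categorical draws, rotates the draw dimension innermost with `permute`, and histograms the drawn indices with `scatter_add_`. The counting step in isolation, as a standalone sketch with hypothetical draws:

    import torch

    # total_count=5 categorical draws over K=3 categories
    samples = torch.LongTensor([2, 0, 2, 1, 2])      # category index per draw
    counts = torch.zeros(3)                          # one bin per category
    counts.scatter_add_(-1, samples, torch.ones(5))  # counts[samples[i]] += 1
    # counts is now [1., 1., 3.], i.e. one Multinomial(5, probs) sample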
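`log_prob` evaluates the multinomial log-pmf directly, computing the factorials via `lgamma(k + 1)`:

    \log P(x \mid n, p) = \log n! - \sum_i \log x_i! + \sum_i x_i \log p_i, \qquad n = \sum_i x_i

The masking line `logits[(value == 0) & (logits == -float('inf'))] = 0` implements the 0 * log 0 = 0 convention tested above, so zero-count categories cannot poison the sum even when their probability is exactly zero.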