mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix support of MixtureSameFamily [bugfix]. (#151317)
Fixes https://github.com/pyro-ppl/pyro/issues/3419 which is actually a `torch` bug that can be replicated by the below code:
```
from torch import rand
from torch.distributions import MixtureSameFamily, Categorical, Binomial
max_count = 20
probs = rand(10, 5)
binom_probs = rand(10, 5)
d = MixtureSameFamily(Categorical(probs=probs), Binomial(max_count, binom_probs))
d.log_prob(d.sample())
```
which results in:
```
Traceback (most recent call last):
File "test.py", line 11, in <module>
d.log_prob(d.sample())
File "pytorch\torch\distributions\mixture_same_family.py", line 168, in log_prob
self._validate_sample(x)
File "pytorch\torch\distributions\distribution.py", line 315, in _validate_sample
valid = support.check(value)
^^^^^^^^^^^^^^^^^^^^
File "pytorch\torch\distributions\constraints.py", line 307, in check
(value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 1
```
### Fix explanation (only for cases when the component distribution contains parameters with batch dimensions)
- The failure is due to sample validation taking place before padding in `MixtureSameFamily.log_prob`, and hence the fix is to pad before doing sample validation.
- The fix itself does not alter the calculations at all. It only affects the sample validation process.
- The failure does not occur with the component distribution set to the `Normal` distribution, as its support (`constraints.real`) has no batched parameters to broadcast the sample against; its validity check only inspects the value itself, elementwise.
- I've split the `test_mixture_same_family_log_prob` test into two tests based on the `Normal` and `Binomial` distributions.
- Initially, the `Binomial` version of the test did not fail, but this was due to the component distribution having equal batch dimensions of (5, 5) so I changed it to (10, 5).
### Updated fix explanation (for all cases)
- The previous fix caused a bug in sample shape validation (which is done correctly) due to the padding taking place before the sample validation.
- The updated fix corrects the support to reflect the fact that the support of `MixtureSameFamily` is equal to the support of its component distribution with the first event dimension removed.
- This issue was already anticipated in the [code](https://github.com/pytorch/pytorch/blob/331423e5c2/torch/distributions/mixture_same_family.py#L127).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151317
Approved by: https://github.com/albanD, https://github.com/fritzo
This commit is contained in:
committed by
PyTorch MergeBot
parent
534b66fe30
commit
a54bf43baa
@ -2558,10 +2558,10 @@ class TestDistributions(DistributionsTestCase):
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
def test_mixture_same_family_log_prob(self):
|
||||
probs = torch.rand(5, 5).softmax(dim=-1)
|
||||
loc = torch.randn(5, 5)
|
||||
scale = torch.rand(5, 5)
|
||||
def test_mixture_same_family_normal_log_prob(self):
|
||||
probs = torch.rand(10, 5).softmax(dim=-1)
|
||||
loc = torch.randn(10, 5)
|
||||
scale = torch.rand(10, 5)
|
||||
|
||||
def ref_log_prob(idx, x, log_prob):
|
||||
p = probs[idx].numpy()
|
||||
@ -2577,6 +2577,27 @@ class TestDistributions(DistributionsTestCase):
|
||||
ref_log_prob,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_binomial_log_prob(self):
    # Compares MixtureSameFamily(Categorical, Binomial).log_prob against a
    # SciPy reference. Mixture weights and Binomial success probabilities
    # are both created with shape (10, 5).
    total_count = 20
    mixing_probs = torch.rand(10, 5).softmax(dim=-1)
    success_probs = torch.rand(10, 5)

    def ref_log_prob(idx, x, log_prob):
        # Reference value: logsumexp over the components of
        # (component log-pmf + log mixture weight) for batch entry ``idx``.
        weights = scipy.stats.multinomial(1, mixing_probs[idx].numpy())
        component = scipy.stats.binom(total_count, success_probs[idx].numpy())
        expected = scipy.special.logsumexp(
            component.logpmf(x) + np.log(weights.p)
        )
        self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

    mixture = MixtureSameFamily(
        Categorical(probs=mixing_probs),
        Binomial(total_count, success_probs),
    )
    self._check_log_prob(mixture, ref_log_prob)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
def test_mixture_same_family_sample(self):
|
||||
probs = torch.rand(5).softmax(dim=-1)
|
||||
|
@ -18,6 +18,7 @@ The following constraints are implemented:
|
||||
- ``constraints.less_than(upper_bound)``
|
||||
- ``constraints.lower_cholesky``
|
||||
- ``constraints.lower_triangular``
|
||||
- ``constraints.MixtureSameFamilyConstraint(base_constraint)``
|
||||
- ``constraints.multinomial``
|
||||
- ``constraints.nonnegative``
|
||||
- ``constraints.nonnegative_integer``
|
||||
@ -56,6 +57,7 @@ __all__ = [
|
||||
"less_than",
|
||||
"lower_cholesky",
|
||||
"lower_triangular",
|
||||
"MixtureSameFamilyConstraint",
|
||||
"multinomial",
|
||||
"nonnegative",
|
||||
"nonnegative_integer",
|
||||
@ -265,6 +267,52 @@ class _IndependentConstraint(Constraint):
|
||||
return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
|
||||
|
||||
|
||||
class MixtureSameFamilyConstraint(Constraint):
    """
    Constraint for the
    :class:`~torch.distributions.mixture_same_family.MixtureSameFamily`
    distribution that adds back the rightmost batch dimension before
    performing the validity check with the component distribution
    constraint.

    Args:
        base_constraint: The ``Constraint`` object of
            the component distribution of the
            :class:`~torch.distributions.mixture_same_family.MixtureSameFamily`
            distribution.
    """

    def __init__(self, base_constraint):
        assert isinstance(base_constraint, Constraint)
        self.base_constraint = base_constraint
        super().__init__()

    @property
    def is_discrete(self) -> bool:  # type: ignore[override]
        """Whether the wrapped component constraint is discrete."""
        return self.base_constraint.is_discrete

    @property
    def event_dim(self) -> int:  # type: ignore[override]
        """Number of event dimensions, taken from the wrapped constraint."""
        return self.base_constraint.event_dim

    def check(self, value):
        """
        Check validity of ``value`` as a possible outcome of sampling the
        :class:`~torch.distributions.mixture_same_family.MixtureSameFamily`
        distribution.

        Args:
            value: Tensor to validate; it must have at least ``event_dim``
                dimensions.

        Returns:
            The result of the base constraint check, reduced with ``all``
            over every dimension after the first
            ``value.dim() - event_dim`` ones.

        Raises:
            ValueError: If ``value`` has fewer than ``event_dim`` dimensions.
        """
        event_dim = self.event_dim
        # Validate the rank first: unsqueezing a tensor with too few
        # dimensions would fail with an unrelated "dimension out of range"
        # error before this message could ever be raised.
        if value.dim() < event_dim:
            raise ValueError(
                f"Expected value.dim() >= {event_dim} but got {value.dim()}"
            )
        # Insert a singleton dimension just left of the event dimensions so
        # that ``value`` broadcasts against the component dimension of the
        # base constraint's parameters.
        unsqueezed_value = value.unsqueeze(-1 - event_dim)
        result = self.base_constraint.check(unsqueezed_value)
        # Flatten everything after the leading dimensions to keep and
        # require all of those entries to be valid.
        num_dim_to_keep = value.dim() - event_dim
        result = result.reshape(result.shape[:num_dim_to_keep] + (-1,))
        return result.all(-1)

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.base_constraint)})"
|
||||
|
||||
|
||||
class _Boolean(Constraint):
|
||||
"""
|
||||
Constrain to the two values `{0, 1}`.
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical, constraints
|
||||
from torch.distributions.constraints import MixtureSameFamilyConstraint
|
||||
from torch.distributions.distribution import Distribution
|
||||
|
||||
|
||||
@ -124,9 +125,7 @@ class MixtureSameFamily(Distribution):
|
||||
|
||||
@constraints.dependent_property
def support(self):
    """Support of the mixture, derived from the component distribution.

    The component distribution's support is wrapped in
    ``MixtureSameFamilyConstraint`` rather than returned directly.
    """
    # Returning the component support unchanged makes sample validation fail
    # with a shape mismatch when that support carries batched parameters
    # (e.g. ``Binomial``); the wrapper restores the dimension the component
    # constraint expects before delegating the check.
    return MixtureSameFamilyConstraint(self._component_distribution.support)
|
||||
|
||||
@property
|
||||
def mixture_distribution(self) -> Categorical:
|
||||
|
Reference in New Issue
Block a user