Fix support of MixtureSameFamily [bugfix]. (#151317)

Fixes https://github.com/pyro-ppl/pyro/issues/3419 which is actually a `torch` bug that can be replicated by the below code:

```
from torch import rand
from torch.distributions import MixtureSameFamily, Categorical, Binomial

max_count = 20
probs = rand(10, 5)
binom_probs = rand(10, 5)

d = MixtureSameFamily(Categorical(probs=probs), Binomial(max_count, binom_probs))
d.log_prob(d.sample())
```

which results in:

```
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    d.log_prob(d.sample())
  File "pytorch\torch\distributions\mixture_same_family.py", line 168, in log_prob
    self._validate_sample(x)
  File "pytorch\torch\distributions\distribution.py", line 315, in _validate_sample
    valid = support.check(value)
            ^^^^^^^^^^^^^^^^^^^^
  File "pytorch\torch\distributions\constraints.py", line 307, in check
    (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
                                                      ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 1
```

### Fix explanation (only for cases when the component distribution contains parameters with batch dimenisons)

- The failure is due to sample validation taking place before padding in `MixtureSameFamily.log_prob`, and hence the fix is to pad before doing sample validation.
- The fix itself does not alter the calculations at all. It only affects the sample validation process.
- The failure does not occur with the component distribution set to the `Normal` distribution, as its validation is not defined elementwise (the validation itself is elementwise).
- I've split the `test_mixture_same_family_log_prob` test into two tests based on the `Normal` and `Binomial` distributions.
- Initially, the `Binomial` version of the test did not fail, but this was due to the component distribution having equal batch dimensions of (5, 5) so I changed it to (10, 5).

### Updated fix explanation (for all cases)

- The previous fix caused a bug in sample shape validation (which is done correctly) due to the padding taking place before the sample validation.
- The updated fix corrects the support to reflect the fact that the support of `MixtureSameFamily` is equal to the support of its components distribution with the first event dimension removed.
- This issue was already anticipated in the [code](331423e5c2/torch/distributions/mixture_same_family.py (L127)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151317
Approved by: https://github.com/albanD, https://github.com/fritzo
This commit is contained in:
Ben Zickel
2025-05-14 19:24:33 +00:00
committed by PyTorch MergeBot
parent 534b66fe30
commit a54bf43baa
3 changed files with 75 additions and 7 deletions

View File

@ -2558,10 +2558,10 @@ class TestDistributions(DistributionsTestCase):
)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_log_prob(self):
probs = torch.rand(5, 5).softmax(dim=-1)
loc = torch.randn(5, 5)
scale = torch.rand(5, 5)
def test_mixture_same_family_normal_log_prob(self):
probs = torch.rand(10, 5).softmax(dim=-1)
loc = torch.randn(10, 5)
scale = torch.rand(10, 5)
def ref_log_prob(idx, x, log_prob):
p = probs[idx].numpy()
@ -2577,6 +2577,27 @@ class TestDistributions(DistributionsTestCase):
ref_log_prob,
)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_binomial_log_prob(self):
max_count = 20
probs = torch.rand(10, 5).softmax(dim=-1)
binom_probs = torch.rand(10, 5)
def ref_log_prob(idx, x, log_prob):
p = probs[idx].numpy()
binom_p = binom_probs[idx].numpy()
mix = scipy.stats.multinomial(1, p)
comp = scipy.stats.binom(max_count, binom_p)
expected = scipy.special.logsumexp(comp.logpmf(x) + np.log(mix.p))
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
self._check_log_prob(
MixtureSameFamily(
Categorical(probs=probs), Binomial(max_count, binom_probs)
),
ref_log_prob,
)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_sample(self):
probs = torch.rand(5).softmax(dim=-1)

View File

@ -18,6 +18,7 @@ The following constraints are implemented:
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.MixtureSameFamilyConstraint(base_constraint)``
- ``constraints.multinomial``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
@ -56,6 +57,7 @@ __all__ = [
"less_than",
"lower_cholesky",
"lower_triangular",
"MixtureSameFamilyConstraint",
"multinomial",
"nonnegative",
"nonnegative_integer",
@ -265,6 +267,52 @@ class _IndependentConstraint(Constraint):
return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
class MixtureSameFamilyConstraint(Constraint):
"""
Constraint for the :class:`~torch.distribution.MixtureSameFamily`
distribution that adds back the rightmost batch dimension before
performing the validity check with the component distribution
constraint.
Args:
base_constraint: The ``Constraint`` object of
the component distribution of
the :class:`~torch.distribution.MixtureSameFamily` distribution.
"""
def __init__(self, base_constraint):
assert isinstance(base_constraint, Constraint)
self.base_constraint = base_constraint
super().__init__()
@property
def is_discrete(self) -> bool: # type: ignore[override]
return self.base_constraint.is_discrete
@property
def event_dim(self) -> int: # type: ignore[override]
return self.base_constraint.event_dim
def check(self, value):
"""
Check validity of ``value`` as a possible outcome of sampling
the :class:`~torch.distribution.MixtureSameFamily` distribution.
"""
unsqueezed_value = value.unsqueeze(-1 - self.event_dim)
result = self.base_constraint.check(unsqueezed_value)
if value.dim() < self.event_dim:
raise ValueError(
f"Expected value.dim() >= {self.event_dim} but got {value.dim()}"
)
num_dim_to_keep = value.dim() - self.event_dim
result = result.reshape(result.shape[:num_dim_to_keep] + (-1,))
result = result.all(-1)
return result
def __repr__(self):
return f"{self.__class__.__name__}({repr(self.base_constraint)})"
class _Boolean(Constraint):
"""
Constrain to the two values `{0, 1}`.

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch
from torch import Tensor
from torch.distributions import Categorical, constraints
from torch.distributions.constraints import MixtureSameFamilyConstraint
from torch.distributions.distribution import Distribution
@ -124,9 +125,7 @@ class MixtureSameFamily(Distribution):
@constraints.dependent_property
def support(self):
# FIXME this may have the wrong shape when support contains batched
# parameters
return self._component_distribution.support
return MixtureSameFamilyConstraint(self._component_distribution.support)
@property
def mixture_distribution(self) -> Categorical: