pytorch/torch/distributions/independent.py
Neeraj Pradhan 80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
torch.Size([100, 10])
```
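Because `.expand` calls `torch.Tensor.expand` on the distribution's parameters, the new `batch_shape` has to satisfy that method's shape rules. The plain-Python sketch below illustrates those rules with a hypothetical helper, `expand_shape`, which is not a PyTorch function. It assumes the documented `Tensor.expand` behavior: new leading dimensions may be prepended, only size-1 dimensions can grow, and `-1` keeps an existing dimension unchanged.

```python
def expand_shape(old_shape, new_shape):
    """Return the shape torch.Tensor.expand(new_shape) would produce.

    Hypothetical helper for illustration only; raises ValueError where
    Tensor.expand would refuse the request.
    """
    old_shape, new_shape = tuple(old_shape), tuple(new_shape)
    n_new = len(new_shape) - len(old_shape)
    if n_new < 0:
        raise ValueError("expand cannot drop dimensions")
    result = []
    for i, size in enumerate(new_shape):
        if i < n_new:
            # Newly prepended dimension: an explicit size is required.
            if size == -1:
                raise ValueError("-1 is not allowed for new leading dimensions")
            result.append(size)
            continue
        old = old_shape[i - n_new]
        if size == -1 or size == old:
            result.append(old)  # keep the existing size
        elif old == 1:
            # Singleton dim is broadcast; in torch it gets stride 0, so no copy is made.
            result.append(size)
        else:
            raise ValueError(f"cannot expand size {old} to {size} at dim {i}")
    return tuple(result)


# Mirrors the Normal example: parameters of shape (100, 1) expanded to batch (100, 10).
print(expand_shape((100, 1), (100, 10)))  # (100, 10)
```

This is why `d.expand([100, 10])` works for parameters of shape `(100, 1)`, whereas a request such as expanding a size-2 dimension to 10 would be rejected.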

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions Slack, and we believe it will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` under the hood, wrap another distribution instance (e.g. `OneHotCategorical` uses a `Categorical` instance), or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them, and construct new instances.
 - In the few cases where this is possible at all, the resulting implementation would be inefficient, since it would go through a lot of broadcasting and argument-validation logic in `__init__` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, it bypasses the distribution's own `__init__` (creating the instance via `__new__` and populating the relevant attributes directly), since no broadcasting or argument validation is needed; that was already done when the original instance was created. This can yield significant savings compared to constructing new instances via `__init__` (that said, `sample` and `log_prob` will probably remain the rate-determining steps in many applications).
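The `__new__`-based pattern can be illustrated without PyTorch. In the sketch below, `ToyBernoulli` is a hypothetical stand-in for a distribution: its `__init__` performs (toy) argument validation, while `expand` creates the new object with `type(self).__new__(type(self))` and copies attributes that were already validated, similar in spirit to how the `expand` method of the listed `Independent` class obtains its instance via `_get_checked_instance`. Parameters are a plain list and shapes are tuples; sharing the list stands in for the zero-copy `Tensor.expand`.

```python
class ToyBernoulli:
    """Hypothetical stand-in for a distribution with validated parameters."""

    init_calls = 0  # counts how often the validating __init__ runs

    def __init__(self, probs, batch_shape, validate_args=True):
        ToyBernoulli.init_calls += 1
        if validate_args:
            # Stand-in for the broadcasting and argument validation in __init__.
            if not all(0.0 <= p <= 1.0 for p in probs):
                raise ValueError("probs must lie in [0, 1]")
        self.probs = probs
        self.batch_shape = tuple(batch_shape)
        self._validate_args = validate_args

    def expand(self, batch_shape):
        # Skip __init__: the parameters were validated when self was created.
        new = type(self).__new__(type(self))
        new.probs = self.probs  # shared, not copied
        new.batch_shape = tuple(batch_shape)
        new._validate_args = self._validate_args
        return new


d = ToyBernoulli([0.5] * 4, (4, 1))
e = d.expand((4, 10))
print(e.batch_shape, e.probs is d.probs, ToyBernoulli.init_calls)  # (4, 10) True 1
```

The validation cost is paid once, at construction; each `expand` call only allocates a small Python object, which is the effect the timing comparison measures.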

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost


class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size(()), torch.Size((3,))]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size((3,)), torch.Size(())]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size(()), torch.Size((3,))]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """
    arg_constraints = {}

    def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                             "actual {} vs {}".format(reinterpreted_batch_ndims,
                                                      len(base_distribution.batch_shape)))
        shape = base_distribution.batch_shape + base_distribution.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(batch_shape +
                                              self.event_shape[:self.reinterpreted_batch_ndims])
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        return self.base_dist.support

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError("Enumeration over cartesian product is not implemented")
        return self.base_dist.enumerate_support(expand=expand)
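The shape bookkeeping in `Independent.__init__` and `Independent.expand`, and the reduction that `log_prob` and `entropy` perform through `_sum_rightmost`, can be traced with plain tuples and nested lists. The helpers below (`split_shapes`, `base_expand_shape`, `sum_rightmost`) are hypothetical illustrations, not PyTorch functions; shapes are tuples rather than `torch.Size`, and `sum_rightmost` only mimics `_sum_rightmost` by summing the last `n` dimensions of a rectangular nested list.

```python
def split_shapes(base_batch, base_event, reinterpreted_batch_ndims):
    """Mirror Independent.__init__: move trailing batch dims into the event shape."""
    if reinterpreted_batch_ndims > len(base_batch):
        raise ValueError("reinterpreted_batch_ndims exceeds len(base_batch)")
    shape = tuple(base_batch) + tuple(base_event)
    event_dim = reinterpreted_batch_ndims + len(base_event)
    return shape[:len(shape) - event_dim], shape[len(shape) - event_dim:]


def base_expand_shape(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Mirror Independent.expand: the shape handed to base_dist.expand()."""
    return tuple(batch_shape) + tuple(event_shape[:reinterpreted_batch_ndims])


def _depth(x):
    """Number of nesting levels of a rectangular nested list."""
    depth = 0
    while isinstance(x, list):
        x = x[0]
        depth += 1
    return depth


def sum_rightmost(value, n):
    """Sum the rightmost n dims of a nested list (cf. _sum_rightmost)."""
    keep = _depth(value) - n
    if keep < 0:
        raise ValueError("cannot sum more dims than the value has")

    def total(x):  # sum every element of a (possibly nested) list
        return sum(total(v) for v in x) if isinstance(x, list) else x

    def recurse(x, levels_left):  # walk down the kept (batch) dims
        if levels_left == 0:
            return total(x)
        return [recurse(v, levels_left - 1) for v in x]

    return recurse(value, keep)


# Docstring example: Normal with batch (3,) and event () -> Independent(normal, 1).
print(split_shapes((3,), (), 1))  # ((), (3,))
# Expanding Independent(Normal over (2, 3), 1) to batch (4, 2) expands the base to (4, 2, 3).
print(base_expand_shape((4, 2), (3,), 1))  # (4, 2, 3)
# Stand-in per-element log-probs of shape (2, 3) reduce to shape (2,) when n = 1.
print(sum_rightmost([[1, 2, 3], [4, 5, 6]], 1))  # [6, 15]
```

In `expand`, the reinterpreted dims are the leading `reinterpreted_batch_ndims` entries of `event_shape`, so they are appended to the new batch shape before the base distribution is expanded; the base distribution's own event shape is left to the base `expand`.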