Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters. e.g.

```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use it in our models to enable dynamic broadcasting. It has also been requested by a few users on the distributions Slack, and we believe it will be useful to the larger community.

Note that currently there is no convenient and efficient way to expand distribution instances:
- Many distributions use `TransformedDistribution` under the hood (or wrap another distribution instance, e.g. `OneHotCategorical` uses a `Categorical` instance), or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them, and construct new instances.
- In the few cases where this is even possible, the resulting implementation would be inefficient, since it would go through a lot of the broadcasting and args validation logic in `__init__` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, it bypasses `__init__` (using `__new__` and populating the relevant attributes), since no broadcasting or args validation is needed (that was already done when the instance was first created). This can result in significant savings compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate-determining steps in many applications). e.g.

```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)
>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc fritzo, apaszke, vishwakftw, alicanb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341
Differential Revision: D9728485
Pulled By: soumith
fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
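As a concrete illustration of how this interacts with the `Independent` wrapper defined in this file, here is a minimal usage sketch (the shapes below are illustrative, not taken from the PR):

```python
import torch
from torch.distributions import Independent, Normal

# Base Normal with batch_shape [5, 1]; reinterpret its rightmost batch dim as an event dim.
base = Normal(torch.zeros(5, 1), torch.ones(5, 1))
d = Independent(base, 1)
print(d.batch_shape, d.event_shape)    # torch.Size([5]) torch.Size([1])

# Expand only the remaining batch dimension; no parameter memory is copied.
e = d.expand(torch.Size([3, 5]))
print(e.batch_shape, e.event_shape)    # torch.Size([3, 5]) torch.Size([1])
print(e.sample().shape)                # torch.Size([3, 5, 1])
```

Because `.expand` only calls `Tensor.expand` on the underlying parameters (here `loc` and `scale`), the expanded instance shares storage with the original, and construction-time argument validation is not repeated.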
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost


class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size(()), torch.Size((3,))]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size((3,)), torch.Size(())]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size(()), torch.Size((3,))]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """
    arg_constraints = {}

    def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                             "actual {} vs {}".format(reinterpreted_batch_ndims,
                                                      len(base_distribution.batch_shape)))
        shape = base_distribution.batch_shape + base_distribution.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        # The rightmost `reinterpreted_batch_ndims` batch dims become part of the event shape.
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        # The reinterpreted dims are still batch dims of the base distribution, so expand them too.
        new.base_dist = self.base_dist.expand(batch_shape +
                                              self.event_shape[:self.reinterpreted_batch_ndims])
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        # Arguments were validated when `self` was constructed, so skip re-validation here.
        super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        return self.base_dist.support

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        # Sum the base log-probs over the reinterpreted batch dims to get one value per event.
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError("Enumeration over cartesian product is not implemented")
        return self.base_dist.enumerate_support(expand=expand)