pytorch/torch/distributions/utils.py
gchanan e37f02469d Favor Variables over Tensors for scalar constructors in torch.distrib… (#4791)
* Favor Variables over Tensors for scalar constructors in torch.distributions.

Current behavior:
1) Distribution constructors given only Python number arguments upcast those numbers to Tensors.
2) Python number arguments of distribution constructors that also contain tensors or variables are upcast
to the first tensor/variable type.

This PR changes the above to favor Variables as follows:
1) Python numbers will now be upcast to Variables.
2) An error will be raised if the first tensor/variable type is not a Variable.

This is done in preparation for the introduction of Scalars (0-dimensional tensors), which are only available on the Variable API.
Note that we are (separately) merging Variable and Tensor, so this PR should have no real long-term effect.

Also note that the above means we don't change the behavior of constructors without python number arguments.

* Fix tests that require numpy.
2018-01-23 11:49:15 -05:00
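
A rough sketch of the behavior this commit describes (illustrative only; it uses `broadcast_all` from the file below, which the distribution constructors use to process their arguments):

    import torch
    from torch.autograd import Variable
    from torch.distributions.utils import broadcast_all

    # 1) Arguments that are all Python numbers are now upcast to Variables, not Tensors.
    loc, scale = broadcast_all(0.0, 1.0)
    assert isinstance(loc, Variable) and isinstance(scale, Variable)

    # 2) Mixing a Python number with a raw torch.Tensor now raises an error;
    #    wrap the tensor in a Variable first.
    # broadcast_all(torch.Tensor([0.0]), 1.0)           # would raise ValueError
    broadcast_all(Variable(torch.Tensor([0.0])), 1.0)   # OK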

173 lines
6.1 KiB
Python

from collections import namedtuple
from functools import update_wrapper
from numbers import Number
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable, variable
# This follows semantics of numpy.finfo.
_Finfo = namedtuple('_Finfo', ['eps', 'tiny'])
_FINFO = {
    torch.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
    torch.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
    torch.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
    torch.cuda.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
    torch.cuda.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
    torch.cuda.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
}


def _finfo(tensor):
    """
    Return floating point info about a `Tensor` or `Variable`:
    - `.eps` is the smallest number that can be added to 1 without being lost.
    - `.tiny` is the smallest usable positive number (much smaller than `.eps`).

    Args:
        tensor (Tensor or Variable): tensor or variable of floating point data.

    Returns:
        _Finfo: a `namedtuple` with fields `.eps` and `.tiny`.
    """
    return _FINFO[tensor.storage_type()]
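
# Example (illustrative sketch): the constants are looked up by storage type,
# so float and double data get different values.
#
#     >>> x = Variable(torch.randn(3))     # backed by FloatStorage
#     >>> _finfo(x).eps
#     1.19209e-07
#     >>> _finfo(x.double()).tiny
#     2.22507385851e-308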


def expand_n(v, n):
    r"""
    Cleanly expand float or Tensor or Variable parameters.
    """
    if isinstance(v, Number):
        return torch.Tensor([v]).expand(n, 1)
    else:
        return v.expand(n, *v.size())
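
# Example (illustrative sketch): numbers become an (n, 1) Tensor, while
# tensors/Variables gain a leading dimension of size n.
#
#     >>> expand_n(2.0, 5).size()
#     torch.Size([5, 1])
#     >>> expand_n(Variable(torch.zeros(3)), 5).size()
#     torch.Size([5, 3])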


def _broadcast_shape(shapes):
    """
    Given a list of tensor sizes, returns the size of the resulting broadcasted
    tensor.

    Args:
        shapes (list of torch.Size): list of tensor sizes
    """
    shape = torch.Size([1])
    for s in shapes:
        shape = torch._C._infer_size(s, shape)
    return shape
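
# Example (illustrative sketch): the sizes are folded together with the usual
# broadcasting rules.
#
#     >>> _broadcast_shape([torch.Size([3, 1]), torch.Size([1, 4]), torch.Size([4])])
#     torch.Size([3, 4])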


def broadcast_all(*values):
    """
    Given a list of values (possibly containing numbers), returns a list where each
    value is broadcasted based on the following rules:
      - `torch.Tensor` and `torch.autograd.Variable` instances are broadcasted as
        per the `broadcasting rules
        <http://pytorch.org/docs/master/notes/broadcasting.html>`_
      - numbers.Number instances (scalars) are upcast to Variables having
        the same size and type as the first tensor passed to `values`. If all the
        values are scalars, then they are upcast to Variables having size
        `(1,)`.

    Args:
        values (list of `numbers.Number`, `torch.autograd.Variable` or
            `torch.Tensor`)

    Raises:
        ValueError: if any of the values is not a `numbers.Number`, `torch.Tensor`
            or `torch.autograd.Variable` instance
    """
    values = list(values)
    scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
    tensor_idxs = [i for i in range(len(values)) if
                   torch.is_tensor(values[i]) or isinstance(values[i], Variable)]
    if len(scalar_idxs) + len(tensor_idxs) != len(values):
        raise ValueError('Input arguments must all be instances of numbers.Number, torch.Tensor or ' +
                         'torch.autograd.Variable.')
    if tensor_idxs:
        broadcast_shape = _broadcast_shape([values[i].size() for i in tensor_idxs])
        for idx in tensor_idxs:
            values[idx] = values[idx].expand(broadcast_shape)
        template = values[tensor_idxs[0]]
        if len(scalar_idxs) > 0 and not isinstance(template, torch.autograd.Variable):
            raise ValueError(('Input arguments containing instances of numbers.Number and torch.Tensor '
                              'are not currently supported. Use torch.autograd.Variable instead of torch.Tensor'))
        for idx in scalar_idxs:
            values[idx] = template.new(template.size()).fill_(values[idx])
    else:
        for idx in scalar_idxs:
            values[idx] = variable(values[idx])
    return values
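
# Example (illustrative sketch): tensor/Variable arguments are broadcast against
# each other, and plain numbers are expanded to the resulting shape.
#
#     >>> mu = Variable(torch.zeros(3, 1))
#     >>> sigma = Variable(torch.ones(4))
#     >>> mu, sigma, p = broadcast_all(mu, sigma, 0.5)
#     >>> mu.size(), sigma.size(), p.size()
#     (torch.Size([3, 4]), torch.Size([3, 4]), torch.Size([3, 4]))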


def softmax(tensor):
    """
    Wrapper around softmax to make it work with both Tensors and Variables.
    TODO: Remove once https://github.com/pytorch/pytorch/issues/2633 is resolved.
    """
    if not isinstance(tensor, Variable):
        return F.softmax(Variable(tensor), -1).data
    return F.softmax(tensor, -1)
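
# Example (illustrative sketch): accepts raw Tensors as well as Variables,
# normalizing along the last dimension.
#
#     >>> softmax(torch.zeros(2, 3))              # Tensor of 1/3's
#     >>> softmax(Variable(torch.zeros(2, 3)))    # Variable of 1/3's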


def log_sum_exp(tensor, keepdim=True):
    """
    Numerically stable implementation for the `LogSumExp` operation. The
    summing is done along the last dimension.

    Args:
        tensor (torch.Tensor or torch.autograd.Variable)
        keepdim (Boolean): Whether to retain the last dimension on summing.
    """
    max_val = tensor.max(dim=-1, keepdim=True)[0]
    return max_val + (tensor - max_val).exp().sum(dim=-1, keepdim=keepdim).log()
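
# Example (illustrative sketch): mathematically equal to
# tensor.exp().sum(dim=-1, keepdim=True).log(), but safe for large magnitudes.
#
#     >>> x = Variable(torch.Tensor([[1000., 1000.]]))
#     >>> x.exp().sum(dim=-1, keepdim=True).log()   # overflows to inf
#     >>> log_sum_exp(x)                            # ~1000.6931 (1000 + log(2))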


def logits_to_probs(logits, is_binary=False):
    """
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    if is_binary:
        return F.sigmoid(logits)
    return softmax(logits)


def clamp_probs(probs):
    eps = _finfo(probs).eps
    return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
    """
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
    return torch.log(ps_clamped)
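
# Example (illustrative sketch): `probs_to_logits` and `logits_to_probs` are
# inverses of each other, up to the clamping applied by `clamp_probs`.
#
#     >>> p = Variable(torch.Tensor([0.25, 0.75]))
#     >>> logits = probs_to_logits(p, is_binary=True)   # log-odds: [-1.0986, 1.0986]
#     >>> logits_to_probs(logits, is_binary=True)       # back to [0.25, 0.75]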


class lazy_property(object):
    """
    Used as a decorator for lazy loading of class attributes. This uses a
    non-data descriptor that calls the wrapped method to compute the property on
    first call; the computed value is then stored as a plain instance attribute,
    which shadows the descriptor on subsequent accesses.
    """
    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value
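
# Example (illustrative sketch, hypothetical class): the wrapped method runs once
# on first access; the result is then cached as an ordinary instance attribute.
#
#     class Gaussian(object):               # hypothetical class for illustration
#         def __init__(self, scale):
#             self.scale = scale
#
#         @lazy_property
#         def log_scale(self):
#             return math.log(self.scale)   # computed only on first access
#
#     g = Gaussian(2.0)
#     g.log_scale   # computes and caches math.log(2.0)
#     g.log_scale   # returns the cached value; __get__ is no longer invoked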