mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: This adds tests in tests/test_distributions.py to ensure that all methods of `Distribution` objects are jittable. I've replaced a few samplers with jittable versions: - `.uniform_()` -> `torch.rand()` - `.exponential_()` -> `-(-torch.rand()).log1p()` - `.normal_()` -> `torch.normal(torch.zeros(...), torch.ones(...), ...)` Some jit failures remain, and are marked in test_distributions.py - `Cauchy` and `HalfCauchy` do not support sampling due to missing `.cauchy_()` - `Binomial` does not support `.enumerate_support()` due to `arange` ignoring its first arg. - `MultivariateNormal`, `LowRankMultivariateNormal` do not support `.mean`, `.entropy` - [x] Currently some tests fail (I've skipped those) due to unavailability of `aten::uniform` and `aten::cauchy` in the jit. Can someone suggest how to add these? I tried to add declarations to `torch/csrc/ir.cpp` and `torch/csrc/passes/shape_analysis.cpp`, but that resulted in "Couldn't find operator" errors. - [x] There are still lots of `TracerWarning`s that something doesn't match something. I'm not sure whether these are real. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11560 Differential Revision: D9816327 Pulled By: apaszke fbshipit-source-id: 72ec998ea13fc4c76d1ed003d9502e0fbaf728b8
150 lines
5.2 KiB
Python
150 lines
5.2 KiB
Python
from collections import namedtuple
|
|
from functools import update_wrapper
|
|
from numbers import Number
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Floating point limits per storage type. This follows semantics of numpy.finfo.
_Finfo = namedtuple('_Finfo', ['eps', 'tiny'])

# The CPU and CUDA storage classes of a given precision carry the same limits,
# so every precision is spelled out once and paired with both namespaces.
# Iteration order is: CPU half/float/double, then CUDA half/float/double.
_FINFO = dict(
    (storage_cls, _Finfo(eps=eps, tiny=tiny))
    for namespace in (torch, torch.cuda)
    for storage_cls, eps, tiny in (
        (namespace.HalfStorage, 0.00097656, 6.1035e-05),
        (namespace.FloatStorage, 1.19209e-07, 1.17549e-38),
        (namespace.DoubleStorage, 2.22044604925e-16, 2.22507385851e-308),
    )
)
|
|
|
|
|
|
def _finfo(tensor):
    r"""
    Look up floating point limits for the storage type backing a `Tensor`.

    The returned record has two fields:

    - `.eps` is the smallest number that can be added to 1 without being lost.
    - `.tiny` is the smallest positive number greater than zero
      (much smaller than `.eps`).

    Args:
        tensor (Tensor): tensor of floating point data.

    Returns:
        _Finfo: a `namedtuple` with fields `.eps` and `.tiny`.

    Raises:
        KeyError: if the tensor's storage type has no entry in `_FINFO`.
    """
    storage_cls = tensor.storage_type()
    return _FINFO[storage_cls]
|
|
|
|
|
|
def _default_promotion(v):
    r"""
    Promote a number to a tensor of dtype ``torch.get_default_dtype()``.
    """
    default_dtype = torch.get_default_dtype()
    return torch.tensor(v, dtype=default_dtype)
|
|
|
|
|
|
def broadcast_all(*values):
    r"""
    Broadcast a mix of tensors and plain numbers against one another.

    Each value is handled according to the following rules:
      - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
      - numbers.Number instances (scalars) are first converted with the
        `new_tensor` method of the first tensor found in `values`, and are
        then broadcasted like any other tensor. If all the values are
        scalars, they are instead promoted to scalar Tensors of dtype
        `torch.get_default_dtype()`.

    Args:
        values (list of `numbers.Number` or `torch.*Tensor`)

    Returns:
        The result of `torch.broadcast_tensors` applied to the converted
        values, in the same order as the arguments.

    Raises:
        ValueError: if any of the values is not a `numbers.Number` or
            `torch.*Tensor` instance
    """
    for candidate in values:
        if not (torch.is_tensor(candidate) or isinstance(candidate, Number)):
            raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')

    contains_number = any(not torch.is_tensor(v) for v in values)
    if contains_number:
        # The first tensor argument (if any) decides how numbers are converted.
        first_tensor = next((v for v in values if torch.is_tensor(v)), None)
        if first_tensor is None:
            to_tensor = _default_promotion
        else:
            to_tensor = first_tensor.new_tensor
        values = [v if torch.is_tensor(v) else to_tensor(v) for v in values]
    return torch.broadcast_tensors(*values)
|
|
|
|
|
|
def _standard_normal(shape, dtype, device):
    r"""
    Draw a tensor of standard normal samples.

    Args:
        shape: shape of the returned tensor.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Returns:
        Tensor: samples with the requested shape, dtype and device.
    """
    if not torch._C._get_tracing_state():
        return torch.empty(shape, dtype=dtype, device=device).normal_()
    # [JIT WORKAROUND] lack of support for .normal_()
    loc = torch.zeros(shape, dtype=dtype, device=device)
    scale = torch.ones(shape, dtype=dtype, device=device)
    return torch.normal(loc, scale)
|
|
|
|
|
|
def _sum_rightmost(value, dim):
    r"""
    Sum out ``dim`` many rightmost dimensions of a given tensor.

    Args:
        value (Tensor): A tensor of ``.dim()`` at least ``dim``.
        dim (int): The number of rightmost dims to sum out.

    Returns:
        Tensor: ``value`` itself when ``dim`` is 0, otherwise a tensor whose
        shape is ``value.shape[:-dim]`` holding the sums over the removed
        dimensions.
    """
    if dim == 0:
        return value
    # Spell out the extent of the flattened axis rather than passing -1:
    # reshape cannot infer a -1 extent when one of the kept leading
    # dimensions has size 0 (e.g. an empty batch), and raises instead.
    flat_size = 1
    for extent in value.shape[-dim:]:
        flat_size *= extent
    required_shape = tuple(value.shape[:-dim]) + (flat_size,)
    return value.reshape(required_shape).sum(-1)
|
|
|
|
|
|
def logits_to_probs(logits, is_binary=False):
    r"""
    Turn a tensor of logits into a tensor of probabilities.

    In the binary case every entry is a log odds value and is mapped
    through a sigmoid. In the multi-dimensional case the entries along the
    last dimension are the (possibly unnormalized) log probabilities of the
    events and are normalized with a softmax over that dimension.

    Args:
        logits (Tensor): the logits to convert.
        is_binary (bool): whether to treat the entries as log odds.
    """
    if not is_binary:
        return F.softmax(logits, dim=-1)
    return torch.sigmoid(logits)
|
|
|
|
|
|
def clamp_probs(probs):
    r"""
    Clamp a tensor of probabilities into the closed interval
    ``[eps, 1 - eps]``, where ``eps`` is the `.eps` field that `_finfo`
    reports for ``probs``.

    Args:
        probs (Tensor): the probabilities to clamp.

    Returns:
        Tensor: the clamped probabilities.
    """
    margin = _finfo(probs).eps
    lower, upper = margin, 1 - margin
    return probs.clamp(min=lower, max=upper)
|
|
|
|
|
|
def probs_to_logits(probs, is_binary=False):
    r"""
    Turn a tensor of probabilities into a tensor of logits.

    In the binary case each entry is the probability of occurrence of the
    event indexed by `1`, and the result is its log odds. In the
    multi-dimensional case the entries along the last dimension are the
    probabilities of occurrence of each of the events, and the result is
    their logarithm.

    The probabilities are passed through `clamp_probs` before any
    logarithm is taken.

    Args:
        probs (Tensor): the probabilities to convert.
        is_binary (bool): whether to return log odds instead of log
            probabilities.
    """
    safe_probs = clamp_probs(probs)
    log_p = torch.log(safe_probs)
    if not is_binary:
        return log_p
    return log_p - torch.log1p(-safe_probs)
|
|
|
|
|
|
def batch_tril(bmat, diagonal=0):
    """
    Given a batch of matrices, returns the lower triangular part of each matrix, with
    the other entries set to 0. The argument `diagonal` has the same meaning as in
    `torch.tril`.

    Args:
        bmat (Tensor): a single matrix or a batch of matrices; the last two
            dimensions are the matrix dimensions.
        diagonal (int): the diagonal to consider, as in `torch.tril`.

    Returns:
        Tensor: a tensor with the same shape as ``bmat``.
    """
    if bmat.dim() == 2:
        return bmat.tril(diagonal=diagonal)
    # Build the triangular pattern once on a single matrix of ones and let it
    # broadcast over the batch dimensions.
    keep = torch.tril(bmat.new_ones(bmat.shape[-2:]), diagonal=diagonal)
    # Overwrite the discarded entries instead of multiplying by a 0/1 mask:
    # multiplication would turn nan/inf entries above the diagonal into nan
    # rather than 0, unlike the 2-D branch above.
    return bmat.masked_fill(keep == 0, 0)
|
|
|
|
|
|
class lazy_property(object):
    r"""
    Decorator for lazily computed attributes.

    This is a non-data descriptor: the first access on an instance calls
    the wrapped method and then stores the result on the instance under the
    method's name, so the stored instance attribute is what later accesses
    find. Accessing the attribute on the class returns the descriptor
    itself.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        # Carry over the wrapped method's metadata to the descriptor.
        update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is not None:
            # The wrapped method is run with gradient tracking enabled.
            with torch.enable_grad():
                computed = self.wrapped(instance)
            # Store the result on the instance under the method's name.
            setattr(instance, self.wrapped.__name__, computed)
            return computed
        return self
|