Added an implementation of a multivariate normal distribution (#4950)

This commit is contained in:
Brooks
2018-03-19 22:22:46 +00:00
committed by Adam Paszke
parent 7e13138eb6
commit 1936753708
6 changed files with 316 additions and 5 deletions

View File

@ -163,6 +163,18 @@ def probs_to_logits(probs, is_binary=False):
return torch.log(ps_clamped)
def batch_tril(bmat, diagonal=0):
"""
Given a batch of matrices, returns the lower triangular part of each matrix, with
the other entries set to 0. The argument `diagonal` has the same meaning as in
`torch.tril`.
"""
if bmat.dim() == 2:
return bmat.tril(diagonal=diagonal)
else:
return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal)
class lazy_property(object):
r"""
Used as a decorator for lazy loading of class attributes. This uses a