mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added an implementation of a multivariate normal distribution (#4950)
This commit is contained in:
@ -163,6 +163,18 @@ def probs_to_logits(probs, is_binary=False):
|
||||
return torch.log(ps_clamped)
|
||||
|
||||
|
||||
def batch_tril(bmat, diagonal=0):
|
||||
"""
|
||||
Given a batch of matrices, returns the lower triangular part of each matrix, with
|
||||
the other entries set to 0. The argument `diagonal` has the same meaning as in
|
||||
`torch.tril`.
|
||||
"""
|
||||
if bmat.dim() == 2:
|
||||
return bmat.tril(diagonal=diagonal)
|
||||
else:
|
||||
return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal)
|
||||
|
||||
|
||||
class lazy_property(object):
|
||||
r"""
|
||||
Used as a decorator for lazy loading of class attributes. This uses a
|
||||
|
Reference in New Issue
Block a user