mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
a4ebc61f15
commit
3bf922a6ce
@ -1,12 +1,14 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.independent import Independent
|
||||
from torch.distributions.transforms import ComposeTransform, Transform
|
||||
from torch.distributions.utils import _sum_rightmost
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ['TransformedDistribution']
|
||||
__all__ = ["TransformedDistribution"]
|
||||
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
r"""
|
||||
@ -45,36 +47,51 @@ class TransformedDistribution(Distribution):
|
||||
|
||||
def __init__(self, base_distribution, transforms, validate_args=None):
|
||||
if isinstance(transforms, Transform):
|
||||
self.transforms = [transforms, ]
|
||||
self.transforms = [
|
||||
transforms,
|
||||
]
|
||||
elif isinstance(transforms, list):
|
||||
if not all(isinstance(t, Transform) for t in transforms):
|
||||
raise ValueError("transforms must be a Transform or a list of Transforms")
|
||||
raise ValueError(
|
||||
"transforms must be a Transform or a list of Transforms"
|
||||
)
|
||||
self.transforms = transforms
|
||||
else:
|
||||
raise ValueError(f"transforms must be a Transform or list, but was {transforms}")
|
||||
raise ValueError(
|
||||
f"transforms must be a Transform or list, but was {transforms}"
|
||||
)
|
||||
|
||||
# Reshape base_distribution according to transforms.
|
||||
base_shape = base_distribution.batch_shape + base_distribution.event_shape
|
||||
base_event_dim = len(base_distribution.event_shape)
|
||||
transform = ComposeTransform(self.transforms)
|
||||
if len(base_shape) < transform.domain.event_dim:
|
||||
raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
|
||||
.format(transform.domain.event_dim, base_shape))
|
||||
raise ValueError(
|
||||
"base_distribution needs to have shape with size at least {}, but got {}.".format(
|
||||
transform.domain.event_dim, base_shape
|
||||
)
|
||||
)
|
||||
forward_shape = transform.forward_shape(base_shape)
|
||||
expanded_base_shape = transform.inverse_shape(forward_shape)
|
||||
if base_shape != expanded_base_shape:
|
||||
base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
|
||||
base_batch_shape = expanded_base_shape[
|
||||
: len(expanded_base_shape) - base_event_dim
|
||||
]
|
||||
base_distribution = base_distribution.expand(base_batch_shape)
|
||||
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
|
||||
if reinterpreted_batch_ndims > 0:
|
||||
base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
|
||||
base_distribution = Independent(
|
||||
base_distribution, reinterpreted_batch_ndims
|
||||
)
|
||||
self.base_dist = base_distribution
|
||||
|
||||
# Compute shapes.
|
||||
transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim
|
||||
transform_change_in_event_dim = (
|
||||
transform.codomain.event_dim - transform.domain.event_dim
|
||||
)
|
||||
event_dim = max(
|
||||
transform.codomain.event_dim, # the transform is coupled
|
||||
base_event_dim + transform_change_in_event_dim # the base dist is coupled
|
||||
base_event_dim + transform_change_in_event_dim, # the base dist is coupled
|
||||
)
|
||||
assert len(forward_shape) >= event_dim
|
||||
cut = len(forward_shape) - event_dim
|
||||
@ -88,10 +105,12 @@ class TransformedDistribution(Distribution):
|
||||
shape = batch_shape + self.event_shape
|
||||
for t in reversed(self.transforms):
|
||||
shape = t.inverse_shape(shape)
|
||||
base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)]
|
||||
base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
|
||||
new.base_dist = self.base_dist.expand(base_batch_shape)
|
||||
new.transforms = self.transforms
|
||||
super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
|
||||
super(TransformedDistribution, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@ -101,7 +120,9 @@ class TransformedDistribution(Distribution):
|
||||
return self.base_dist.support
|
||||
support = self.transforms[-1].codomain
|
||||
if len(self.event_shape) > support.event_dim:
|
||||
support = constraints.independent(support, len(self.event_shape) - support.event_dim)
|
||||
support = constraints.independent(
|
||||
support, len(self.event_shape) - support.event_dim
|
||||
)
|
||||
return support
|
||||
|
||||
@property
|
||||
@ -146,12 +167,15 @@ class TransformedDistribution(Distribution):
|
||||
for transform in reversed(self.transforms):
|
||||
x = transform.inv(y)
|
||||
event_dim += transform.domain.event_dim - transform.codomain.event_dim
|
||||
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
|
||||
event_dim - transform.domain.event_dim)
|
||||
log_prob = log_prob - _sum_rightmost(
|
||||
transform.log_abs_det_jacobian(x, y),
|
||||
event_dim - transform.domain.event_dim,
|
||||
)
|
||||
y = x
|
||||
|
||||
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
|
||||
event_dim - len(self.base_dist.event_shape))
|
||||
log_prob = log_prob + _sum_rightmost(
|
||||
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
|
||||
)
|
||||
return log_prob
|
||||
|
||||
def _monotonize_cdf(self, value):
|
||||
|
||||
Reference in New Issue
Block a user