Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang
2023-07-29 10:51:26 -04:00
committed by PyTorch MergeBot
parent a4ebc61f15
commit 3bf922a6ce
163 changed files with 8472 additions and 4412 deletions

View File

@ -1,12 +1,14 @@
from typing import Dict
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost
from typing import Dict
__all__ = ['TransformedDistribution']
__all__ = ["TransformedDistribution"]
class TransformedDistribution(Distribution):
r"""
@ -45,36 +47,51 @@ class TransformedDistribution(Distribution):
def __init__(self, base_distribution, transforms, validate_args=None):
if isinstance(transforms, Transform):
self.transforms = [transforms, ]
self.transforms = [
transforms,
]
elif isinstance(transforms, list):
if not all(isinstance(t, Transform) for t in transforms):
raise ValueError("transforms must be a Transform or a list of Transforms")
raise ValueError(
"transforms must be a Transform or a list of Transforms"
)
self.transforms = transforms
else:
raise ValueError(f"transforms must be a Transform or list, but was {transforms}")
raise ValueError(
f"transforms must be a Transform or list, but was {transforms}"
)
# Reshape base_distribution according to transforms.
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape)
transform = ComposeTransform(self.transforms)
if len(base_shape) < transform.domain.event_dim:
raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
.format(transform.domain.event_dim, base_shape))
raise ValueError(
"base_distribution needs to have shape with size at least {}, but got {}.".format(
transform.domain.event_dim, base_shape
)
)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape:
base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
base_batch_shape = expanded_base_shape[
: len(expanded_base_shape) - base_event_dim
]
base_distribution = base_distribution.expand(base_batch_shape)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
base_distribution = Independent(
base_distribution, reinterpreted_batch_ndims
)
self.base_dist = base_distribution
# Compute shapes.
transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim
transform_change_in_event_dim = (
transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(
transform.codomain.event_dim, # the transform is coupled
base_event_dim + transform_change_in_event_dim # the base dist is coupled
base_event_dim + transform_change_in_event_dim, # the base dist is coupled
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim
@ -88,10 +105,12 @@ class TransformedDistribution(Distribution):
shape = batch_shape + self.event_shape
for t in reversed(self.transforms):
shape = t.inverse_shape(shape)
base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)]
base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
new.base_dist = self.base_dist.expand(base_batch_shape)
new.transforms = self.transforms
super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
super(TransformedDistribution, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
@ -101,7 +120,9 @@ class TransformedDistribution(Distribution):
return self.base_dist.support
support = self.transforms[-1].codomain
if len(self.event_shape) > support.event_dim:
support = constraints.independent(support, len(self.event_shape) - support.event_dim)
support = constraints.independent(
support, len(self.event_shape) - support.event_dim
)
return support
@property
@ -146,12 +167,15 @@ class TransformedDistribution(Distribution):
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim)
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
log_prob = log_prob + _sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
return log_prob
def _monotonize_cdf(self, value):