mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249 Approved by: https://github.com/Skylion007
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							a4ebc61f15
						
					
				
				
					commit
					3bf922a6ce
				
			| @ -1,12 +1,14 @@ | ||||
| from typing import Dict | ||||
|  | ||||
| import torch | ||||
| from torch.distributions import constraints | ||||
| from torch.distributions.distribution import Distribution | ||||
| from torch.distributions.independent import Independent | ||||
| from torch.distributions.transforms import ComposeTransform, Transform | ||||
| from torch.distributions.utils import _sum_rightmost | ||||
| from typing import Dict | ||||
|  | ||||
| __all__ = ['TransformedDistribution'] | ||||
| __all__ = ["TransformedDistribution"] | ||||
|  | ||||
|  | ||||
| class TransformedDistribution(Distribution): | ||||
|     r""" | ||||
| @ -45,36 +47,51 @@ class TransformedDistribution(Distribution): | ||||
|  | ||||
|     def __init__(self, base_distribution, transforms, validate_args=None): | ||||
|         if isinstance(transforms, Transform): | ||||
|             self.transforms = [transforms, ] | ||||
|             self.transforms = [ | ||||
|                 transforms, | ||||
|             ] | ||||
|         elif isinstance(transforms, list): | ||||
|             if not all(isinstance(t, Transform) for t in transforms): | ||||
|                 raise ValueError("transforms must be a Transform or a list of Transforms") | ||||
|                 raise ValueError( | ||||
|                     "transforms must be a Transform or a list of Transforms" | ||||
|                 ) | ||||
|             self.transforms = transforms | ||||
|         else: | ||||
|             raise ValueError(f"transforms must be a Transform or list, but was {transforms}") | ||||
|             raise ValueError( | ||||
|                 f"transforms must be a Transform or list, but was {transforms}" | ||||
|             ) | ||||
|  | ||||
|         # Reshape base_distribution according to transforms. | ||||
|         base_shape = base_distribution.batch_shape + base_distribution.event_shape | ||||
|         base_event_dim = len(base_distribution.event_shape) | ||||
|         transform = ComposeTransform(self.transforms) | ||||
|         if len(base_shape) < transform.domain.event_dim: | ||||
|             raise ValueError("base_distribution needs to have shape with size at least {}, but got {}." | ||||
|                              .format(transform.domain.event_dim, base_shape)) | ||||
|             raise ValueError( | ||||
|                 "base_distribution needs to have shape with size at least {}, but got {}.".format( | ||||
|                     transform.domain.event_dim, base_shape | ||||
|                 ) | ||||
|             ) | ||||
|         forward_shape = transform.forward_shape(base_shape) | ||||
|         expanded_base_shape = transform.inverse_shape(forward_shape) | ||||
|         if base_shape != expanded_base_shape: | ||||
|             base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim] | ||||
|             base_batch_shape = expanded_base_shape[ | ||||
|                 : len(expanded_base_shape) - base_event_dim | ||||
|             ] | ||||
|             base_distribution = base_distribution.expand(base_batch_shape) | ||||
|         reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim | ||||
|         if reinterpreted_batch_ndims > 0: | ||||
|             base_distribution = Independent(base_distribution, reinterpreted_batch_ndims) | ||||
|             base_distribution = Independent( | ||||
|                 base_distribution, reinterpreted_batch_ndims | ||||
|             ) | ||||
|         self.base_dist = base_distribution | ||||
|  | ||||
|         # Compute shapes. | ||||
|         transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim | ||||
|         transform_change_in_event_dim = ( | ||||
|             transform.codomain.event_dim - transform.domain.event_dim | ||||
|         ) | ||||
|         event_dim = max( | ||||
|             transform.codomain.event_dim,  # the transform is coupled | ||||
|             base_event_dim + transform_change_in_event_dim  # the base dist is coupled | ||||
|             base_event_dim + transform_change_in_event_dim,  # the base dist is coupled | ||||
|         ) | ||||
|         assert len(forward_shape) >= event_dim | ||||
|         cut = len(forward_shape) - event_dim | ||||
| @ -88,10 +105,12 @@ class TransformedDistribution(Distribution): | ||||
|         shape = batch_shape + self.event_shape | ||||
|         for t in reversed(self.transforms): | ||||
|             shape = t.inverse_shape(shape) | ||||
|         base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)] | ||||
|         base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] | ||||
|         new.base_dist = self.base_dist.expand(base_batch_shape) | ||||
|         new.transforms = self.transforms | ||||
|         super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False) | ||||
|         super(TransformedDistribution, new).__init__( | ||||
|             batch_shape, self.event_shape, validate_args=False | ||||
|         ) | ||||
|         new._validate_args = self._validate_args | ||||
|         return new | ||||
|  | ||||
| @ -101,7 +120,9 @@ class TransformedDistribution(Distribution): | ||||
|             return self.base_dist.support | ||||
|         support = self.transforms[-1].codomain | ||||
|         if len(self.event_shape) > support.event_dim: | ||||
|             support = constraints.independent(support, len(self.event_shape) - support.event_dim) | ||||
|             support = constraints.independent( | ||||
|                 support, len(self.event_shape) - support.event_dim | ||||
|             ) | ||||
|         return support | ||||
|  | ||||
|     @property | ||||
| @ -146,12 +167,15 @@ class TransformedDistribution(Distribution): | ||||
|         for transform in reversed(self.transforms): | ||||
|             x = transform.inv(y) | ||||
|             event_dim += transform.domain.event_dim - transform.codomain.event_dim | ||||
|             log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y), | ||||
|                                                  event_dim - transform.domain.event_dim) | ||||
|             log_prob = log_prob - _sum_rightmost( | ||||
|                 transform.log_abs_det_jacobian(x, y), | ||||
|                 event_dim - transform.domain.event_dim, | ||||
|             ) | ||||
|             y = x | ||||
|  | ||||
|         log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y), | ||||
|                                              event_dim - len(self.base_dist.event_shape)) | ||||
|         log_prob = log_prob + _sum_rightmost( | ||||
|             self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) | ||||
|         ) | ||||
|         return log_prob | ||||
|  | ||||
|     def _monotonize_cdf(self, value): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user