mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove deprecated tensor constructors in torch.distributions (#19979)
Summary: This removes the deprecated `tensor.new_*` constructors (see #16770) from `torch.distributions` module. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979 Differential Revision: D15195618 Pulled By: soumith fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
This commit is contained in:
committed by
Facebook Github Bot
parent
792bc56ec2
commit
fb40e58f24
@ -55,10 +55,7 @@ class Beta(ExponentialFamily):
|
||||
(total.pow(2) * (total + 1)))
|
||||
|
||||
def rsample(self, sample_shape=()):
|
||||
value = self._dirichlet.rsample(sample_shape).select(-1, 0)
|
||||
if isinstance(value, Number):
|
||||
value = self._dirichlet.concentration.new_tensor(value)
|
||||
return value
|
||||
return self._dirichlet.rsample(sample_shape).select(-1, 0)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
|
@ -93,11 +93,11 @@ class Categorical(Distribution):
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self.probs.new_tensor(nan).expand(self._extended_shape())
|
||||
return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)
|
||||
|
||||
@property
|
||||
def variance(self):
|
||||
return self.probs.new_tensor(nan).expand(self._extended_shape())
|
||||
return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
sample_shape = self._extended_shape(sample_shape)
|
||||
|
@ -47,11 +47,11 @@ class Cauchy(Distribution):
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self.loc.new_tensor(nan).expand(self._extended_shape())
|
||||
return torch.full(self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device)
|
||||
|
||||
@property
|
||||
def variance(self):
|
||||
return self.loc.new_tensor(inf).expand(self._extended_shape())
|
||||
return torch.full(self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device)
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
|
@ -116,7 +116,7 @@ def _infinite_like(tensor):
|
||||
"""
|
||||
Helper function for obtaining infinite KL Divergence throughout
|
||||
"""
|
||||
return tensor.new_tensor(inf).expand_as(tensor)
|
||||
return torch.full_like(tensor, inf)
|
||||
|
||||
|
||||
def _x_log_x(tensor):
|
||||
|
@ -77,7 +77,7 @@ class ExpRelaxedCategorical(Distribution):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
log_scale = (self.temperature.new_tensor(float(K)).lgamma() -
|
||||
log_scale = (torch.full_like(self.temperature, float(K)).lgamma() -
|
||||
self.temperature.log().mul(-(K - 1)))
|
||||
score = logits - value.mul(self.temperature)
|
||||
score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
|
||||
|
@ -434,7 +434,7 @@ class AffineTransform(Transform):
|
||||
shape = x.shape
|
||||
scale = self.scale
|
||||
if isinstance(scale, numbers.Number):
|
||||
result = x.new_empty(shape).fill_(math.log(abs(scale)))
|
||||
result = torch.full_like(x, math.log(abs(scale)))
|
||||
else:
|
||||
result = torch.abs(scale).log()
|
||||
if self.event_dim:
|
||||
|
@ -4,11 +4,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# promote numbers to tensors of dtype torch.get_default_dtype()
|
||||
def _default_promotion(v):
|
||||
return torch.tensor(v, dtype=torch.get_default_dtype())
|
||||
|
||||
|
||||
def broadcast_all(*values):
|
||||
r"""
|
||||
Given a list of values (possibly containing numbers), returns a list where each
|
||||
@ -28,12 +23,13 @@ def broadcast_all(*values):
|
||||
if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
|
||||
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
|
||||
if not all(map(torch.is_tensor, values)):
|
||||
new_tensor = _default_promotion
|
||||
options = dict(dtype=torch.get_default_dtype())
|
||||
for value in values:
|
||||
if torch.is_tensor(value):
|
||||
new_tensor = value.new_tensor
|
||||
options = dict(dtype=value.dtype, device=value.device)
|
||||
break
|
||||
values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
|
||||
values = [v if torch.is_tensor(v) else torch.tensor(v, **options)
|
||||
for v in values]
|
||||
return torch.broadcast_tensors(*values)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user