Remove deprecated tensor constructors in torch.distributions (#19979)

Summary:
This removes the deprecated `tensor.new_*` constructors (see #16770) from `torch.distributions` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979

Differential Revision: D15195618

Pulled By: soumith

fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
This commit is contained in:
Neeraj Pradhan
2019-05-02 20:31:00 -07:00
committed by Facebook Github Bot
parent 792bc56ec2
commit fb40e58f24
7 changed files with 12 additions and 19 deletions

View File

@@ -55,10 +55,7 @@ class Beta(ExponentialFamily):
(total.pow(2) * (total + 1)))
def rsample(self, sample_shape=()):
value = self._dirichlet.rsample(sample_shape).select(-1, 0)
if isinstance(value, Number):
value = self._dirichlet.concentration.new_tensor(value)
return value
return self._dirichlet.rsample(sample_shape).select(-1, 0)
def log_prob(self, value):
if self._validate_args:

View File

@@ -93,11 +93,11 @@ class Categorical(Distribution):
@property
def mean(self):
return self.probs.new_tensor(nan).expand(self._extended_shape())
return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)
@property
def variance(self):
return self.probs.new_tensor(nan).expand(self._extended_shape())
return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)
def sample(self, sample_shape=torch.Size()):
sample_shape = self._extended_shape(sample_shape)

View File

@@ -47,11 +47,11 @@ class Cauchy(Distribution):
@property
def mean(self):
return self.loc.new_tensor(nan).expand(self._extended_shape())
return torch.full(self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device)
@property
def variance(self):
return self.loc.new_tensor(inf).expand(self._extended_shape())
return torch.full(self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device)
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)

View File

@@ -116,7 +116,7 @@ def _infinite_like(tensor):
"""
Helper function for obtaining infinite KL Divergence throughout
"""
return tensor.new_tensor(inf).expand_as(tensor)
return torch.full_like(tensor, inf)
def _x_log_x(tensor):

View File

@@ -77,7 +77,7 @@ class ExpRelaxedCategorical(Distribution):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
log_scale = (self.temperature.new_tensor(float(K)).lgamma() -
log_scale = (torch.full_like(self.temperature, float(K)).lgamma() -
self.temperature.log().mul(-(K - 1)))
score = logits - value.mul(self.temperature)
score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)

View File

@@ -434,7 +434,7 @@ class AffineTransform(Transform):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Number):
result = x.new_empty(shape).fill_(math.log(abs(scale)))
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()
if self.event_dim:

View File

@@ -4,11 +4,6 @@ import torch
import torch.nn.functional as F
# promote numbers to tensors of dtype torch.get_default_dtype()
def _default_promotion(v):
return torch.tensor(v, dtype=torch.get_default_dtype())
def broadcast_all(*values):
r"""
Given a list of values (possibly containing numbers), returns a list where each
@@ -28,12 +23,13 @@ def broadcast_all(*values):
if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
if not all(map(torch.is_tensor, values)):
new_tensor = _default_promotion
options = dict(dtype=torch.get_default_dtype())
for value in values:
if torch.is_tensor(value):
new_tensor = value.new_tensor
options = dict(dtype=value.dtype, device=value.device)
break
values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
values = [v if torch.is_tensor(v) else torch.tensor(v, **options)
for v in values]
return torch.broadcast_tensors(*values)