Fix log_prob() in torch.distributions.Uniform, HalfCauchy and Gamma (#23017)

Summary:
This fixes https://github.com/pytorch/pytorch/issues/22970. Specifically, `torch.distributions.uniform.Uniform.log_prob()` now works even if `value` is passed as a python float.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23017

Differential Revision: D16383258

Pulled By: vincentqb

fbshipit-source-id: 26943c33431d6da6f47e0897d6eda1c5f5541d28
This commit is contained in:
Vaibhav Sinha
2019-08-22 08:15:01 -07:00
committed by Facebook Github Bot
parent b9a5188178
commit 632aeb034d
4 changed files with 10 additions and 5 deletions

View File

@ -70,8 +70,8 @@ class Uniform(Distribution):
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
lb = value.ge(self.low).type_as(self.low)
ub = value.lt(self.high).type_as(self.low)
lb = self.low.le(value).type_as(self.low)
ub = self.high.gt(value).type_as(self.low)
return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
def cdf(self, value):