mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix log_prob() in torch.distributions.Uniform, HalfCauchy and Gamma (#23017)
Summary: This fixes https://github.com/pytorch/pytorch/issues/22970. Specifically, `torch.distributions.uniform.Uniform.log_prob()` now works even if `value` is passed as a python float. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23017 Differential Revision: D16383258 Pulled By: vincentqb fbshipit-source-id: 26943c33431d6da6f47e0897d6eda1c5f5541d28
This commit is contained in:
committed by
Facebook Github Bot
parent
b9a5188178
commit
632aeb034d
@ -70,8 +70,8 @@ class Uniform(Distribution):
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
lb = value.ge(self.low).type_as(self.low)
|
||||
ub = value.lt(self.high).type_as(self.low)
|
||||
lb = self.low.le(value).type_as(self.low)
|
||||
ub = self.high.gt(value).type_as(self.low)
|
||||
return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
|
||||
|
||||
def cdf(self, value):
|
||||
|
||||
Reference in New Issue
Block a user