mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable distribution validation if __debug__ (#48743)
Summary: Fixes https://github.com/pytorch/pytorch/issues/47123 Follows https://github.com/pyro-ppl/pyro/pull/2701 This turns on `Distribution` validation by default. The motivation is to favor beginners by providing helpful error messages. Advanced users focused on speed can disable validation by calling ```py torch.distributions.Distribution.set_default_validate_args(False) ``` or by disabling individual distribution validation via `MyDistribution(..., validate_args=False)`. In practice I have found many beginners forget or do not know about validation. Therefore I have [enabled it by default](https://github.com/pyro-ppl/pyro/pull/2701) in Pyro. I believe PyTorch could also benefit from this change. Indeed validation caught a number of bugs in `.icdf()` methods, in tests, and in PPL benchmarks, all of which have been fixed in this PR. ## Release concerns - This may slightly slow down some models. Concerned users may disable validation. - This may cause new `ValueErrors` in models that rely on unsupported behavior, e.g. `Categorical.log_prob()` applied to continuous-valued tensors (only {0,1}-valued tensors are supported). We should clearly note this change in release notes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/48743 Reviewed By: heitorschueroff Differential Revision: D25304247 Pulled By: neerajprad fbshipit-source-id: 8d50f28441321ae691f848c55f71aa80cb356b41
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e3c56ddde6
commit
093aca082e
@ -68,8 +68,6 @@ class Exponential(ExponentialFamily):
|
||||
return 1 - torch.exp(-self.rate * value)
|
||||
|
||||
def icdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return -torch.log(1 - value) / self.rate
|
||||
|
||||
def entropy(self):
|
||||
|
||||
Reference in New Issue
Block a user