Fix precision issue with expansion that prefers 'probs' over 'logits' (#18614)

Summary:
I have observed that sometimes both `probs` and `logits` are present in `__dict__`, but `expand` chose to copy `probs`, which loses precision compared to `logits`. This matters especially when training (Bayesian) neural networks or doing other kinds of optimization, since the loss is heavily affected.
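
To make the precision concern concrete, here is a minimal sketch (not part of the PR; values chosen purely for illustration): converting moderately large logits to probabilities in float32 saturates at 1.0, so logits rederived from `probs` are unrecoverable.

```python
import torch

# Illustration only (not from the PR): logits -> probs -> logits round trip in float32.
logits = torch.tensor([20.0, 40.0])          # float32 by default
probs = torch.sigmoid(logits)                # both saturate to exactly 1.0 in float32
recovered = torch.log(probs / (1 - probs))   # rederived logits

print(probs)      # tensor([1., 1.])
print(recovered)  # tensor([inf, inf]) -- the original 20 and 40 are lost
```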
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614

Differential Revision: D14793486

Pulled By: ezyang

fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
Author: Ahmad Salim Al-Sibahi
Date: 2019-04-05 12:45:37 -07:00
Committed by: Facebook Github Bot
Parent: b90cbb841d
Commit: 8e1e29124d
6 changed files with 6 additions and 6 deletions


@@ -45,7 +45,7 @@ class NegativeBinomial(Distribution):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
             new._param = new.logits
         super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
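
A small usage sketch of the intended behavior after this change (illustrative values; assumes the standard `torch.distributions` lazy caching of `probs`/`logits`):

```python
import torch
from torch.distributions import NegativeBinomial

# Built from logits; accessing .probs also caches 'probs' in __dict__,
# so both keys are present when expand() runs.
d = NegativeBinomial(total_count=torch.tensor(10.0), logits=torch.tensor(30.0))
_ = d.probs                                  # sigmoid(30.) saturates to 1.0 in float32

e = d.expand(torch.Size([4]))

# Expected True: the logits are carried over exactly rather than
# being rederived from the saturated probs.
print(torch.equal(e.logits, d.logits.expand(4)))
```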