Fix precision issue with expansion that prefers 'probs' over 'logits' (#18614)
Summary: I have seen cases where both `probs` and `logits` were in `__dict__`, but `expand()` chose to copy `probs`, which loses precision relative to `logits`. This is especially important when training (Bayesian) neural networks or doing other kinds of optimization, since the loss is heavily affected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614
Differential Revision: D14793486
Pulled By: ezyang
fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
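To make the failure mode concrete, here is a minimal sketch (illustration only, not code from this PR; the numbers assume float32). Large-magnitude logits map to probabilities that underflow, so any path that round-trips through `probs` discards information:

import torch
from torch.distributions import Bernoulli

d = Bernoulli(logits=torch.tensor([-200.0]))
_ = d.probs                        # lazily computed; underflows to 0.0 in float32
e = d.expand(torch.Size([3]))
# Pre-fix, expand() copied only the cached probs, and e.logits was then
# re-derived from the saturated probabilities (roughly -87 after clamping
# to float32 tiny). Post-fix, the original logits survive expansion:
print(e.logits)                    # tensor([-200., -200., -200.])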
Committed by: Facebook Github Bot
Parent: b90cbb841d
Commit: 8e1e29124d
torch/distributions/bernoulli.py
@@ -53,7 +53,7 @@ class Bernoulli(ExponentialFamily):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
             new._param = new.logits
         super(Bernoulli, new).__init__(batch_shape, validate_args=False)
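The change is the same in every file: the `else:` becomes an independent `if`, so when both parameterizations are cached the logits branch runs last and `new._param` ends up aliasing the expanded logits, i.e. logits are preferred. The mishandled case is easy to hit, because `lazy_property` writes the computed value back onto the instance; a quick check (illustrative, not from the PR):

import torch
from torch.distributions import Bernoulli

d = Bernoulli(logits=torch.tensor([0.5]))
print('probs' in d.__dict__, 'logits' in d.__dict__)   # False True
_ = d.probs                                            # a single read caches it
print('probs' in d.__dict__, 'logits' in d.__dict__)   # True True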
torch/distributions/binomial.py
@@ -58,7 +58,7 @@ class Binomial(Distribution):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
             new._param = new.logits
         super(Binomial, new).__init__(batch_shape, validate_args=False)
torch/distributions/categorical.py
@@ -64,7 +64,7 @@ class Categorical(Distribution):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(param_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(param_shape)
             new._param = new.logits
         new._num_events = self._num_events
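One wrinkle in `Categorical`: its parameter tensors carry a trailing event dimension, so they are expanded to `param_shape` (the requested batch shape plus the number of events) rather than to `batch_shape` itself. A small shape check (illustration only):

import torch
from torch.distributions import Categorical

c = Categorical(logits=torch.zeros(4))   # 4 events, empty batch shape
ce = c.expand(torch.Size([2, 3]))
print(ce.batch_shape)                    # torch.Size([2, 3])
print(ce.logits.shape)                   # torch.Size([2, 3, 4])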
torch/distributions/geometric.py
@@ -51,7 +51,7 @@ class Geometric(Distribution):
         batch_shape = torch.Size(batch_shape)
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
         super(Geometric, new).__init__(batch_shape, validate_args=False)
         new._validate_args = self._validate_args
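`Geometric` is the only distribution touched here without a `_param` alias, so its `expand()` just copies whichever parameterizations are cached; the precision argument is unchanged. A sketch of the fixed behavior (illustration only):

import torch
from torch.distributions import Geometric

g = Geometric(logits=torch.tensor([-150.0]))
_ = g.probs                       # caches a success probability that underflows
ge = g.expand(torch.Size([2]))
print(ge.logits)                  # tensor([-150., -150.]) -- precision kept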
torch/distributions/negative_binomial.py
@@ -45,7 +45,7 @@ class NegativeBinomial(Distribution):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
             new._param = new.logits
         super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
torch/distributions/relaxed_bernoulli.py
@@ -54,7 +54,7 @@ class LogitRelaxedBernoulli(Distribution):
         if 'probs' in self.__dict__:
             new.probs = self.probs.expand(batch_shape)
             new._param = new.probs
-        else:
+        if 'logits' in self.__dict__:
             new.logits = self.logits.expand(batch_shape)
             new._param = new.logits
         super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
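On the training impact mentioned in the summary: `Bernoulli.log_prob` is computed from the logits (via `binary_cross_entropy_with_logits`), so logits degraded by a round-trip through `probs` feed straight into the loss and its gradients. An end-to-end sketch (illustration only):

import torch
from torch.distributions import Bernoulli

logits = torch.tensor([-200.0], requires_grad=True)
d = Bernoulli(logits=logits)
_ = d.probs                       # cache probs, as a training loop easily might
loss = -d.expand(torch.Size([8])).log_prob(torch.zeros(8)).sum()
loss.backward()
print(logits.grad)                # post-fix, grads reach the original logits tensor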