mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Check for internal memory overlap in some indexing-type functions (#43423)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43423 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D23298652 Pulled By: zou3519 fbshipit-source-id: c13c59aec0c6967ef0d6365d782c1f4c98c04227
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5807bb92d3
commit
c88ac25679
@ -101,7 +101,8 @@ class Multinomial(Distribution):
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits.clone(memory_format=torch.contiguous_format), value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
logits = logits.clone(memory_format=torch.contiguous_format)
|
||||
log_factorial_n = torch.lgamma(value.sum(-1) + 1)
|
||||
log_factorial_xs = torch.lgamma(value + 1).sum(-1)
|
||||
logits[(value == 0) & (logits == -inf)] = 0
|
||||
|
||||
Reference in New Issue
Block a user