Check for internal memory overlap in some indexing-type functions (#43423)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43423

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D23298652

Pulled By: zou3519

fbshipit-source-id: c13c59aec0c6967ef0d6365d782c1f4c98c04227
This commit is contained in:
Peter Bell
2020-09-02 08:44:13 -07:00
committed by Facebook GitHub Bot
parent 5807bb92d3
commit c88ac25679
9 changed files with 109 additions and 4 deletions

View File

@ -101,7 +101,8 @@ class Multinomial(Distribution):
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits.clone(memory_format=torch.contiguous_format), value)
logits, value = broadcast_all(self.logits, value)
logits = logits.clone(memory_format=torch.contiguous_format)
log_factorial_n = torch.lgamma(value.sum(-1) + 1)
log_factorial_xs = torch.lgamma(value + 1).sum(-1)
logits[(value == 0) & (logits == -inf)] = 0