update doc for multinomial (#17269)

Summary:
Update documentation to raise awareness of the fix in #12490. Thanks matteorr for pointing this out!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17269

Reviewed By: ezyang

Differential Revision: D14138421

Pulled By: ailzhang

fbshipit-source-id: 6433f9807a6ba1d871eba8e9d37aa6b78fa1e1fd
This commit is contained in:
Ailing Zhang
2019-02-19 15:23:27 -08:00
committed by Facebook Github Bot
parent d73e6cb59d
commit f827f9f77a

View File

@ -3110,8 +3110,10 @@ If replacement is ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row. sample index is drawn for a row, it cannot be drawn again for that row.
This implies the constraint that :attr:`num_samples` must be lower than .. note::
:attr:`input` length (or number of columns of :attr:`input` if it is a matrix). When drawn without replacement, :attr:`num_samples` must be lower than
number of non-zero elements in :attr:`input` (or the min number of non-zero
elements in each row of :attr:`input` if it is a matrix).
Args: Args:
input (Tensor): the input tensor containing probabilities input (Tensor): the input tensor containing probabilities
@ -3122,8 +3124,11 @@ Args:
Example:: Example::
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 4) >>> torch.multinomial(weights, 2)
tensor([ 1, 2, 0, 0]) tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True) >>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2, 1, 1, 1]) tensor([ 2, 1, 1, 1])
""") """)