Fix documentation for attention mask shape (#20850)

Summary:
The attention mask should be of shape `(L, S)` (target sequence length by source sequence length), since it is added to `attn_output_weights`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20850

Differential Revision: D15495587

Pulled By: ezyang

fbshipit-source-id: 61d6801da5291df960daab273e874df28aedbf6e
Author: Josef Lindman Hörnlund
Date: 2019-05-24 09:07:09 -07:00
Committed by: Facebook Github Bot
Parent: a5c90aaf47
Commit: 87040af498
2 changed files with 2 additions and 2 deletions
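
To make the corrected shape concrete, here is a minimal sketch (not part of this patch) of passing an `(L, S)` mask to `nn.MultiheadAttention`; the sizes, `num_heads`, and the masking pattern below are illustrative assumptions:

```python
# Minimal sketch (not part of this patch); sizes, num_heads, and the masking
# pattern are illustrative assumptions.
import torch
import torch.nn as nn

L, S, N, E = 5, 7, 2, 16            # target len, source len, batch size, embed dim
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)

query = torch.randn(L, N, E)        # (L, N, E)
key = torch.randn(S, N, E)          # (S, N, E)
value = torch.randn(S, N, E)        # (S, N, E)

# The additive mask is applied to the raw attention weights, which have one row
# per target position and one column per source position -- hence shape (L, S).
attn_mask = torch.zeros(L, S)
attn_mask[:, -1] = float('-inf')    # e.g. hide the last source position from every query

attn_output, attn_output_weights = mha(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)            # torch.Size([5, 2, 16])  -> (L, N, E)
print(attn_output_weights.shape)    # torch.Size([2, 5, 7])   -> (N, L, S)
```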


@@ -3126,7 +3126,7 @@ def multi_head_attention_forward(query, # type: Tensor
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

         Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,


@@ -764,7 +764,7 @@ class MultiheadAttention(Module):
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

         - Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
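
For reference, the shape requirement follows directly from the scaled dot-product step that the mask feeds into. The toy single-head sketch below (an illustration with assumed sizes, not the library's internal code) shows why an `(L, L)` mask cannot line up with the score matrix when `S != L`:

```python
# Toy single-head illustration (assumed sizes); projections, multiple heads, and
# dropout from the real implementation are omitted.
import torch

L, S, E = 5, 7, 16
q = torch.randn(L, E)                   # queries: one row per target position
k = torch.randn(S, E)                   # keys:    one row per source position
v = torch.randn(S, E)

scores = q @ k.t() / E ** 0.5           # (L, S) score matrix
attn_mask = torch.zeros(L, S)           # additive mask must match the (L, S) scores
weights = torch.softmax(scores + attn_mask, dim=-1)
output = weights @ v                    # (L, E)
print(weights.shape, output.shape)      # torch.Size([5, 7]) torch.Size([5, 16])
```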