Fix documentation for attention mask shape (#20850)
Summary: Attention mask should be of shape `(L, S)` since it is added to `attn_output_weights`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20850

Differential Revision: D15495587

Pulled By: ezyang

fbshipit-source-id: 61d6801da5291df960daab273e874df28aedbf6e
Committed by Facebook Github Bot
parent a5c90aaf47
commit 87040af498
@@ -3126,7 +3126,7 @@ def multi_head_attention_forward(query,  # type: Tensor
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

     Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
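The (L, S) requirement follows from how the raw weights are formed inside multi_head_attention_forward: they come out with shape (N * num_heads, L, S), so an additive mask has to broadcast against (L, S). The sketch below is illustrative only, with hypothetical sizes, and mimics that shape arithmetic rather than reproducing the code touched by this patch.

import torch

# Hypothetical sizes for illustration.
N, num_heads, L, S, head_dim = 2, 4, 3, 5, 8
q = torch.randn(N * num_heads, L, head_dim)
k = torch.randn(N * num_heads, S, head_dim)

# Raw attention weights, shape (N * num_heads, L, S).
attn_output_weights = torch.bmm(q, k.transpose(1, 2))

# Additive mask of shape (L, S); broadcasts over the leading batch dimension.
attn_mask = torch.zeros(L, S)
attn_mask[:, -1] = float('-inf')      # e.g. block attention to the last source position
attn_output_weights = attn_output_weights + attn_mask

# An (L, L) mask would only happen to work when L == S; here L != S and the
# addition above would fail with a size mismatch.
print(attn_output_weights.shape)      # torch.Size([8, 3, 5])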
@@ -764,7 +764,7 @@ class MultiheadAttention(Module):
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

     Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
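For the module API, a minimal usage sketch (hypothetical sizes, not part of this patch) with an attn_mask of shape (L, S) passed to nn.MultiheadAttention, where the target and source lengths differ:

import torch
import torch.nn as nn

embed_dim, num_heads, N, L, S = 16, 4, 2, 3, 5
mha = nn.MultiheadAttention(embed_dim, num_heads)

query = torch.randn(L, N, embed_dim)   # (L, N, E)
key = torch.randn(S, N, embed_dim)     # (S, N, E)
value = torch.randn(S, N, embed_dim)   # (S, N, E)

# Additive float mask of shape (L, S).
attn_mask = torch.zeros(L, S)
attn_mask[:, 0] = float('-inf')        # e.g. prevent attending to the first source position

attn_output, attn_weights = mha(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)               # torch.Size([3, 2, 16]) -> (L, N, E)
print(attn_weights.shape)              # torch.Size([2, 3, 5])  -> (N, L, S)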