diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 703e8b962a43..2909f9b36ef3 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -3126,7 +3126,7 @@ def multi_head_attention_forward(query,                         # type: Tensor
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
 
         Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 37f0fbaf6307..3c33ca6a2b61 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -764,7 +764,7 @@ class MultiheadAttention(Module):
         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
 
         - Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
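
The sketch below is not part of the patch; it is a minimal example illustrating the shape convention the corrected docstring describes: when the target sequence length L differs from the source sequence length S, attn_mask must be (L, S), not (L, L). The dimension values and the use of an all-zeros additive float mask are illustrative assumptions.

# Minimal sketch of the (L, S) attn_mask convention documented above.
import torch
import torch.nn as nn

L, S, N, E = 5, 7, 2, 16          # target len, source len, batch size, embedding dim
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)

query = torch.randn(L, N, E)      # (L, N, E)
key = torch.randn(S, N, E)        # (S, N, E)
value = torch.randn(S, N, E)      # (S, N, E)

# One mask entry per (target position, source position) pair: shape (L, S).
# Zeros mean "no masking" for an additive float mask.
attn_mask = torch.zeros(L, S)

attn_output, attn_weights = mha(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)          # torch.Size([5, 2, 16])  -> (L, N, E)
print(attn_weights.shape)         # torch.Size([2, 5, 7])   -> (N, L, S)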