From 87040af49863d5fd065ebf3feada58d9c2a84c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josef=20Lindman=20H=C3=B6rnlund?= Date: Fri, 24 May 2019 09:07:09 -0700 Subject: [PATCH] Fix documentation for attention mask shape (#20850) Summary: Attention mask should be of shape `(L, S)` since it is added to `attn_output_weights`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20850 Differential Revision: D15495587 Pulled By: ezyang fbshipit-source-id: 61d6801da5291df960daab273e874df28aedbf6e --- torch/nn/functional.py | 2 +- torch/nn/modules/activation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 703e8b962a43..2909f9b36ef3 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3126,7 +3126,7 @@ def multi_head_attention_forward(query, # type: Tensor - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. - - attn_mask: :math:`(L, L)` where L is the target sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 37f0fbaf6307..3c33ca6a2b61 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -764,7 +764,7 @@ class MultiheadAttention(Module): - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. - - attn_mask: :math:`(L, L)` where L is the target sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,