Split nn.MultiheadAttention into Module + functional (#20415)

Summary:
Moves the attention computation from torch/nn/modules/activation.py into torch/nn/functional.py as multi_head_attention_forward; the MultiheadAttention module's forward now delegates to it. For functions not yet ported (_get_input_buffer and _set_input_buffer), a TODO is added.
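
For context, a minimal usage sketch of the new functional entry point, using the signature as it appears in this diff (the argument order and the out_proj module argument are taken from the code below; illustrative only, not part of the commit, and the public API may differ in later releases):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    embed_dim, num_heads, bsz, tgt_len = 8, 2, 3, 5
    # Self-attention: query, key and value are the same (L, N, E) tensor.
    x = torch.rand(tgt_len, bsz, embed_dim)
    in_proj_weight = torch.empty(3 * embed_dim, embed_dim)  # packed q/k/v projection
    nn.init.xavier_uniform_(in_proj_weight)
    in_proj_bias = torch.zeros(3 * embed_dim)
    out_proj = nn.Linear(embed_dim, embed_dim)

    attn_output, attn_weights = F.multi_head_attention_forward(
        x, x, x, embed_dim, num_heads,
        in_proj_weight, in_proj_bias,
        None, None, False,   # bias_k, bias_v, add_zero_attn
        0.0, out_proj,       # dropout_p, out_proj
        training=False)
    # attn_output: (tgt_len, bsz, embed_dim); attn_weights: (bsz, tgt_len, tgt_len)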
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20415

Differential Revision: D15318078

Pulled By: jamarshon

fbshipit-source-id: 5ca698e2913821442cf8609cc61ac8190496a3c6
Jason Lian
2019-05-14 08:37:48 -07:00
committed by Facebook Github Bot
parent b46a630836
commit 6e82b1c77d
2 changed files with 205 additions and 124 deletions


@@ -3077,3 +3077,194 @@ def _pad_circular(input, padding):
input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)
return input
@weak_script
def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor
value, # type: Tensor
embed_dim_to_check, # type: int
num_heads, # type: int
in_proj_weight, # type: Tensor
in_proj_bias, # type: Tensor
bias_k, # type: Tensor
bias_v, # type: Tensor
add_zero_attn, # type: bool
dropout_p, # type: float
out_proj, # type: Tensor
training=True, # type: bool
key_padding_mask=None, # type: Optional[Tensor]
need_weights=True, # type: bool
attn_mask=None # type: Optional[Tensor]
):
# type: (...) -> Tuple[Tensor, Tensor]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj: the output projection.
training: apply dropout if ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, L)` where L is the target sequence length.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
@weak_script
def _in_proj(input, weight, bias, start=0, end=None):
# type: (Tensor, Tensor, Optional[Tensor], int, Optional[int]) -> Tensor
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return linear(input, weight, bias)
@weak_script
def _in_proj_qkv(weight, bias, query):
# type: (Tensor, Tensor, Tensor) -> Tensor
return _in_proj(query, weight, bias).chunk(3, dim=-1)
@weak_script
def _in_proj_kv(weight, bias, embed_dim, key):
# type: (Tensor, Tensor, int, Tensor) -> Tensor
return _in_proj(key, weight, bias, start=embed_dim).chunk(2, dim=-1)
@weak_script
def _in_proj_q(weight, bias, embed_dim, query):
# type: (Tensor, Tensor, int, Tensor) -> Tensor
return _in_proj(query, weight, bias, end=embed_dim)
@weak_script
def _in_proj_k(weight, bias, embed_dim, key):
# type: (Tensor, Tensor, int, Tensor) -> Tensor
return _in_proj(key, weight, bias, start=embed_dim, end=2 * embed_dim)
@weak_script
def _in_proj_v(weight, bias, embed_dim, value):
# type: (Tensor, Tensor, int, Tensor) -> Tensor
return _in_proj(value, weight, bias, start=2 * embed_dim)
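# Identify self-attention (query/key/value share storage) and encoder-decoder
# attention (key/value share storage) to pick the cheapest projection path.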
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
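# Scaled dot-product attention factor: queries are multiplied by 1/sqrt(head_dim).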
scaling = head_dim ** -0.5
if qkv_same:
# self-attention
q, k, v = _in_proj_qkv(in_proj_weight, in_proj_bias, query)
elif kv_same:
# encoder-decoder attention
q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
if key is None:
assert value is None
k = v = None
else:
k, v = _in_proj_kv(in_proj_weight, in_proj_bias, embed_dim, key)
else:
q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
k = _in_proj_k(in_proj_weight, in_proj_bias, embed_dim, key)
v = _in_proj_v(in_proj_weight, in_proj_bias, embed_dim, value)
q *= scaling
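# Optional learned bias_k/bias_v append one extra key/value position per batch;
# the masks are widened by one column to cover it.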
if bias_k is not None:
assert bias_v is not None
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
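# Fold the heads into the batch dimension: (len, bsz, embed_dim) ->
# (bsz * num_heads, len, head_dim), so each head is a separate bmm batch entry.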
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
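# add_zero_attn appends an all-zero key/value slot, again widening the masks by one column.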
if add_zero_attn:
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
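# Raw attention scores, shape (bsz * num_heads, tgt_len, src_len); attn_mask is
# added to the scores and key_padding_mask positions are set to -inf before softmax.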
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
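# Softmax over the source dimension; half-precision inputs are upcast to float32
# for numerical stability.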
attn_output_weights = softmax(
attn_output_weights.float(), dim=-1,
dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights


@@ -694,7 +694,7 @@ class MultiheadAttention(Module):
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
Examples::
@@ -708,9 +708,6 @@ class MultiheadAttention(Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
if bias:
@@ -748,143 +745,36 @@ class MultiheadAttention(Module):
need_weights=True, attn_mask=None):
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, L)` where L is the target sequence length.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if qkv_same:
# self-attention
q, k, v = self._in_proj_qkv(query)
elif kv_same:
# encoder-decoder attention
q = self._in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k, v = self._in_proj_kv(key)
else:
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_output_weights = F.softmax(
attn_output_weights.float(), dim=-1,
dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def _in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
def _in_proj_q(self, query):
return self._in_proj(query, end=self.embed_dim)
def _in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
def _in_proj_v(self, value):
return self._in_proj(value, start=2 * self.embed_dim)
def _in_proj(self, input, start=0, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return F.linear(input, weight, bias)
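# Delegate the attention computation to torch.nn.functional.multi_head_attention_forward.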
return F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj, training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask)
@weak_module