Split nn.MultiHeadAttention into Module + functional (#20415)
Summary: Moving functions from torch/nn/modules/activation.py to torch/nn/functional.py. For functions not implemented (_get_input_buffer and _set_input_buffer), a TODO is added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20415
Differential Revision: D15318078
Pulled By: jamarshon
fbshipit-source-id: 5ca698e2913821442cf8609cc61ac8190496a3c6
commit 6e82b1c77d (parent b46a630836), committed by Facebook Github Bot
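After this change, nn.MultiheadAttention keeps its module interface while the attention computation itself lives in torch.nn.functional. A minimal usage sketch of the module, following the shapes documented in the docstrings below (the concrete sizes are illustrative assumptions, not taken from the PR):

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 8, 2            # embed_dim must be divisible by num_heads
    tgt_len, src_len, bsz = 5, 7, 3        # illustrative sequence lengths and batch size

    multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
    query = torch.randn(tgt_len, bsz, embed_dim)   # (L, N, E)
    key = torch.randn(src_len, bsz, embed_dim)     # (S, N, E)
    value = torch.randn(src_len, bsz, embed_dim)   # (S, N, E)

    # attn_output: (L, N, E); attn_output_weights: (N, L, S), averaged over heads
    attn_output, attn_output_weights = multihead_attn(query, key, value)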
torch/nn/functional.py

@@ -3077,3 +3077,194 @@ def _pad_circular(input, padding):
        input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)

    return input


@weak_script
def multi_head_attention_forward(query,                # type: Tensor
                                 key,                  # type: Tensor
                                 value,                # type: Tensor
                                 embed_dim_to_check,   # type: int
                                 num_heads,            # type: int
                                 in_proj_weight,       # type: Tensor
                                 in_proj_bias,         # type: Tensor
                                 bias_k,               # type: Tensor
                                 bias_v,               # type: Tensor
                                 add_zero_attn,        # type: bool
                                 dropout_p,            # type: float
                                 out_proj,             # type: Tensor
                                 training=True,        # type: bool
                                 key_padding_mask=None,  # type: Optional[Tensor]
                                 need_weights=True,    # type: bool
                                 attn_mask=None        # type: Optional[Tensor]
                                 ):
    # type: (...) -> Tuple[Tensor, Tensor]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj: the output projection.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
                          be ignored by the attention.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: :math:`(L, L)` where L is the target sequence length.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    @weak_script
    def _in_proj(input, weight, bias, start=0, end=None):
        # type: (Tensor, Tensor, Optional[Tensor], int, Optional[int]) -> Tensor
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return linear(input, weight, bias)

    @weak_script
    def _in_proj_qkv(weight, bias, query):
        # type: (Tensor, Tensor, Tensor) -> Tensor
        return _in_proj(query, weight, bias).chunk(3, dim=-1)

    @weak_script
    def _in_proj_kv(weight, bias, embed_dim, key):
        # type: (Tensor, Tensor, int, Tensor) -> Tensor
        return _in_proj(key, weight, bias, start=embed_dim).chunk(2, dim=-1)

    @weak_script
    def _in_proj_q(weight, bias, embed_dim, query):
        # type: (Tensor, Tensor, int, Tensor) -> Tensor
        return _in_proj(query, weight, bias, end=embed_dim)

    @weak_script
    def _in_proj_k(weight, bias, embed_dim, key):
        # type: (Tensor, Tensor, int, Tensor) -> Tensor
        return _in_proj(key, weight, bias, start=embed_dim, end=2 * embed_dim)

    @weak_script
    def _in_proj_v(weight, bias, embed_dim, value):
        # type: (Tensor, Tensor, int, Tensor) -> Tensor
        return _in_proj(value, weight, bias, start=2 * embed_dim)
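    # The helpers above assume in_proj_weight packs the query, key and value projections
    # into a single (3 * embed_dim, embed_dim) matrix: rows [0:embed_dim] project the
    # query, rows [embed_dim:2 * embed_dim] project the key, and rows [2 * embed_dim:]
    # project the value, with in_proj_bias sliced the same way.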
    qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
    kv_same = key.data_ptr() == value.data_ptr()

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = head_dim ** -0.5

    if qkv_same:
        # self-attention
        q, k, v = _in_proj_qkv(in_proj_weight, in_proj_bias, query)
    elif kv_same:
        # encoder-decoder attention
        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k, v = _in_proj_kv(in_proj_weight, in_proj_bias, embed_dim, key)
    else:
        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
        k = _in_proj_k(in_proj_weight, in_proj_bias, embed_dim, key)
        v = _in_proj_v(in_proj_weight, in_proj_bias, embed_dim, value)
    q *= scaling

    if bias_k is not None:
        assert bias_v is not None
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights.float(), dim=-1,
        dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = out_proj(attn_output)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
    else:
        attn_output_weights = None

    return attn_output, attn_output_weights
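For reference, calling the functional form directly with the positional signature added above might look like the sketch below. The parameter shapes follow the module's __init__ in this diff, the concrete sizes are illustrative assumptions, and later PyTorch releases changed this signature (passing projection weights rather than an out_proj module), so treat it as a sketch against this commit only:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    embed_dim, num_heads, bsz = 8, 2, 3
    tgt_len, src_len = 5, 7

    query = torch.randn(tgt_len, bsz, embed_dim)
    key = torch.randn(src_len, bsz, embed_dim)
    value = torch.randn(src_len, bsz, embed_dim)

    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)   # packed q/k/v projections
    in_proj_bias = torch.zeros(3 * embed_dim)
    out_proj = nn.Linear(embed_dim, embed_dim)               # output projection module

    attn_output, attn_output_weights = F.multi_head_attention_forward(
        query, key, value, embed_dim, num_heads,
        in_proj_weight, in_proj_bias,
        None, None,        # bias_k, bias_v
        False,             # add_zero_attn
        0.0,               # dropout_p
        out_proj,
        training=False)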
torch/nn/modules/activation.py

@@ -694,7 +694,7 @@ class MultiheadAttention(Module):
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.

    Examples::
@@ -708,9 +708,6 @@ class MultiheadAttention(Module):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        if bias:
@@ -748,143 +745,36 @@ class MultiheadAttention(Module):
                need_weights=True, attn_mask=None):
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                              be ignored by the attention.
            need_weights: output attn_output_weights.
            attn_mask: mask that prevents attention to certain positions.

        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
            - attn_mask: :math:`(L, L)` where L is the target sequence length.

            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if qkv_same:
            # self-attention
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights.float(), dim=-1,
            dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def _in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def _in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj, training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask)


@weak_module
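With the implementation and the _in_proj helper methods above removed, MultiheadAttention.forward becomes a thin wrapper that forwards its arguments and the module's own parameters to the new F.multi_head_attention_forward. Under this commit's interface, the two call paths in the sketch below would be expected to agree whenever dropout is inactive (continuing the tensors from the earlier sketch; an illustrative check, not a test from the PR):

    mha = nn.MultiheadAttention(embed_dim, num_heads).eval()

    out_module, _ = mha(query, key, value)
    out_functional, _ = F.multi_head_attention_forward(
        query, key, value, mha.embed_dim, mha.num_heads,
        mha.in_proj_weight, mha.in_proj_bias, mha.bias_k, mha.bias_v,
        mha.add_zero_attn, mha.dropout, mha.out_proj, training=False)

    assert torch.allclose(out_module, out_functional)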