Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Disable incremental_state function in MultiheadAttention module. (#20177)
Summary: Fully supporting the incremental_state function requires several additional utils that are only available in fairseq, so there is currently no proper way to unit test it here. The incremental_state function is therefore disabled for now; if it is needed in the future, a feature request can be created. Fixes #20132.

This change also adds unit tests covering the arguments of the MultiheadAttention module, including bias, add_bias_kv, add_zero_attn, key_padding_mask, need_weights, and attn_mask.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20177
Differential Revision: D15304575
Pulled By: cpuhrsch
fbshipit-source-id: ebd8cc0f11a4da0c0998bf0c7e4e341585e5685a
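For orientation, here is a minimal usage sketch of the options the new tests exercise; the sizes and tensors below are made up for illustration and are not taken from the test itself:

import torch
import torch.nn as nn

# Illustrative sizes only; the test randomizes batch size, sequence length and head count.
seq_len, batch_sz, embed_dim, num_heads = 5, 2, 8, 2

# bias/add_bias_kv/add_zero_attn are constructor options; this sketch keeps the
# defaults for the latter two and passes bias explicitly.
attn = nn.MultiheadAttention(embed_dim, num_heads, bias=True)

query = torch.rand(seq_len, batch_sz, embed_dim)        # (L, N, E)
key = value = torch.rand(seq_len, batch_sz, embed_dim)  # (S, N, E)

# key_padding_mask: (N, S) ByteTensor; nonzero entries mark padded key positions.
key_padding_mask = torch.zeros(batch_sz, seq_len, dtype=torch.uint8)
# attn_mask: (L, L) additive float mask; 0.0 keeps a position, -inf blocks it.
attn_mask = torch.zeros(seq_len, seq_len)

# Note that incremental_state and static_kv are no longer accepted by forward().
attn_output, attn_weights = attn(query, key, value,
                                 key_padding_mask=key_padding_mask,
                                 need_weights=True,
                                 attn_mask=attn_mask)
print(attn_output.shape)   # (L, N, E)
print(attn_weights.shape)  # (N, L, S)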
Committed by: Facebook Github Bot
Parent: f8aa6a8f44
Commit: 41673d477c
@@ -3237,29 +3237,36 @@ class TestNN(NNTestCase):
                 output = m(sigmoid(input), target)
                 verify_reduction_scalars(input, reduction, output)

     @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                      "Scipy v1.0 and/or numpy not found")
     def test_multihead_attention(self):
-        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=False, src_lengths=None):
+        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
+                                 attn_mask=None, add_zero_attn=False):
             """ Numpy-based reference implementation of scaled dot attention
             for testing"""
+
             QKT = _batchmatmul(
                 Q,
                 np.transpose(K, axes=[0, 1, 3, 2])
                 / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head)
             )
-            if unseen_mask or src_lengths is not None:
-                b1, b2, s1, s2 = QKT.shape
+            b1, b2, s1, s2 = QKT.shape
+            if unseen_mask is not None or src_lengths is not None:
                 # assert s1 == s2
                 for i in range(b1):
                     for j in range(b2):
                         for m in range(s1):
                             for n in range(s2):
-                                if unseen_mask and n > m:
+                                if unseen_mask[m][n] == 0:
                                     QKT[i, j, m, n] = -np.inf
                                 if src_lengths is not None and n >= src_lengths[i]:
                                     QKT[i, j, m, n] = -np.inf
+
             reference = _softmax(QKT)
+            ref_attn_weight = reference
+            ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2
             reference = _batchmatmul(reference, V)
-            return reference
+            return reference, ref_attn_weight

         def _batchmatmul(a, b): # batchmatmul over 4 dim matrix
             """ Numpy-based batch matrix multiply over 4 dim matrix"""
@@ -3275,7 +3282,8 @@ class TestNN(NNTestCase):

         def _softmax(x): # softmax over 4 dim matrix
             """ Numpy-based reference softmax over 4 dim matrix"""
-            output = np.zeros(x.shape, dtype=np.float32)
+            np.seterr(invalid='ignore')
+            output = np.zeros(x.shape, dtype=np.float64)
             for i in range(x.shape[0]):
                 for j in range(x.shape[1]):
                     for k in range(x.shape[2]):
@@ -3338,7 +3346,7 @@ class TestNN(NNTestCase):
             # returns [batch_size, max_seq_len]
             return (src_indices < src_lengths).int().detach()

-        def _multihead_attn_test_helper(use_src_lengths):
+        def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zero_attn=False):
             for _ in range(100):
                 batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
                 d_head = random.randint(3, 10)
@@ -3348,7 +3356,7 @@ class TestNN(NNTestCase):

                 src_lengths = None
                 src_lengths_tensor = None
-                if use_src_lengths:
+                if add_key_padding_mask:
                     src_lengths, src_lengths_tensor = _generate_src_lengths(
                         batch_size=batch_sz, seq_len=seq_len
                     )
@@ -3357,28 +3365,44 @@ class TestNN(NNTestCase):
                 K = np.random.rand(*dims).astype(np.float64)
                 V = K
                 Q = np.expand_dims(decoder_state, 1)
+                attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
+                attn_mask_tensor = torch.from_numpy(attn_mask).float()
+                attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
+                attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
+                attn_mask_tensor = attn_mask_tensor.double()

                 decoder_state_tensor = torch.from_numpy(decoder_state).double()
                 source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)

-                multihead_attn_module = MultiheadAttention(d_model, nheads)
+                multihead_attn_module = MultiheadAttention(d_model, nheads,
+                                                           add_bias_kv=add_bias_kv,
+                                                           add_zero_attn=add_zero_attn)

+                if add_bias_kv:
+                    bias_k = multihead_attn_module.bias_k.detach().numpy()
+                    bias_v = multihead_attn_module.bias_v.detach().numpy()
+                else:
+                    bias_k = None
+                    bias_v = None

                 _batch_size = decoder_state_tensor.shape[0]
                 _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
                 _V = source_hid_tensor
                 _K = source_hid_tensor
                 src_len_mask = None
-                if src_lengths is not None and use_src_lengths:
+                if src_lengths is not None and add_key_padding_mask:
                     # [batch_size, 1, seq_len]
                     src_len_mask_int = _create_src_lengths_mask(
                         batch_size=_batch_size, src_lengths=src_lengths_tensor
                     )
                     src_len_mask = src_len_mask_int != 1

-                result = multihead_attn_module(
+                result, result_weight = multihead_attn_module(
                     _Q, _K, _V,
                     key_padding_mask=src_len_mask,
-                    need_weights=True)[0].squeeze(0).detach().numpy()
+                    need_weights=True,
+                    attn_mask=attn_mask_tensor)
+
+                result = result.squeeze(0).detach().numpy()

                 Q_fc = _fc(Q, "in_proj_", multihead_attn_module, end=d_model)
                 K_fc = _fc(
@@ -3386,20 +3410,31 @@ class TestNN(NNTestCase):
                 )
                 V_fc = _fc(V, "in_proj_", multihead_attn_module, start=2 * d_model)

+                if add_bias_kv:
+                    K_fc = np.concatenate((K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1)
+                    V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1)
+                    attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
+                    dims[1] += 1
                 Q_split = _split_heads_ref(
                     Q_fc, [batch_sz, 1, d_model], nheads, d_head
                 )
                 K_split = _split_heads_ref(K_fc, dims, nheads, d_head)
                 V_split = _split_heads_ref(V_fc, dims, nheads, d_head)

-                attn_heads = _scaled_dot_attn_ref(
+                if add_zero_attn:
+                    dims[1] += 1
+                    K_split = np.concatenate((K_split, np.zeros([K_split.shape[0], K_split.shape[1], 1, K_split.shape[3]])), axis=2)
+                    V_split = np.concatenate((V_split, np.zeros([V_split.shape[0], V_split.shape[1], 1, V_split.shape[3]])), axis=2)
+                    attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
+
+                attn_heads, ref_attn_weight = _scaled_dot_attn_ref(
                     Q=Q_split,
                     K=K_split,
                     V=V_split,
                     dims=Q_split.shape,
-                    src_lengths=src_lengths,
+                    unseen_mask=attn_mask,
+                    src_lengths=src_lengths
                 )

                 combined_attn_heads = _combine_heads_ref(
                     X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
                 )
@@ -3413,14 +3448,27 @@ class TestNN(NNTestCase):
                 self.assertEqual(tuple(result.shape), (batch_sz, d_model))
                 np.testing.assert_allclose(result, reference, atol=1e-5)

+                # result_weight = ref_attn_weight
+                result_weight = result_weight.detach().numpy()
+                self.assertEqual(tuple(result_weight.shape), tuple(ref_attn_weight.shape))
+                np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5)
+
+        def test_multihead_attn_add_bias_kv():
+            _multihead_attn_test_helper(add_key_padding_mask=None, add_bias_kv=True)
+
+        def test_multihead_attn_add_zero_attn():
+            _multihead_attn_test_helper(add_key_padding_mask=None, add_zero_attn=True)
+
         def test_multihead_attn_no_masking():
-            _multihead_attn_test_helper(use_src_lengths=None)
+            _multihead_attn_test_helper(add_key_padding_mask=None)

-        def test_multihead_attn_with_src_lengths():
-            _multihead_attn_test_helper(use_src_lengths=True)
+        def test_multihead_attn_key_padding_mask():
+            _multihead_attn_test_helper(add_key_padding_mask=True)

+        test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn
+        test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv
         test_multihead_attn_no_masking() # Test MultiheadAttention without masking
-        test_multihead_attn_with_src_lengths() # Test MultiheadAttention with src lengths
+        test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths

     def test_normalize(self):
         inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
@@ -691,6 +691,8 @@ class MultiheadAttention(Module):
     Args:
         embed_dim: total dimension of the model.
         num_heads: parallel attention heads.
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+        bias: add bias as module parameter. Default: True.
+        add_bias_kv: add bias to the key and value sequences at dim=0.
         add_zero_attn: add a new batch of zeros to the key and
                        value sequences at dim=1.
@@ -742,23 +744,19 @@ class MultiheadAttention(Module):
             xavier_normal_(self.bias_v)

     @weak_script_method
-    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
-                need_weights=True, static_kv=False, attn_mask=None):
+    def forward(self, query, key, value, key_padding_mask=None,
+                need_weights=True, attn_mask=None):
         r"""
     Args:
         query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention.
-        incremental_state: if provided, previous time steps are cached.
        need_weights: output attn_output_weights.
-        static_kv: if true, key and value are static. The key and value in previous
-            states will be used.
        attn_mask: mask that prevents attention to certain positions.

-    Shape:
-        - Inputs:
+    Inputs:

        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
@@ -766,11 +764,9 @@ class MultiheadAttention(Module):
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - incremental_state: a dictionary used for storing states.
        - attn_mask: :math:`(L, L)` where L is the target sequence length.

-        - Outputs:
-
+    Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
@@ -784,17 +780,6 @@ class MultiheadAttention(Module):
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         assert key.size() == value.size()

-        if incremental_state is not None:
-            saved_state = self._get_input_buffer(incremental_state)
-            if 'prev_key' in saved_state:
-                # previous time steps are cached - no need to recompute
-                # key and value if they are static
-                if static_kv:
-                    assert kv_same and not qkv_same
-                    key = value = None
-        else:
-            saved_state = None
-
         if qkv_same:
             # self-attention
             q, k, v = self._in_proj_qkv(query)
@@ -828,25 +813,6 @@ class MultiheadAttention(Module):
         if v is not None:
             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

-        if saved_state is not None:
-            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
-            if 'prev_key' in saved_state:
-                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
-                if static_kv:
-                    k = prev_key
-                else:
-                    k = torch.cat((prev_key, k), dim=1)
-            if 'prev_value' in saved_state:
-                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
-                if static_kv:
-                    v = prev_value
-                else:
-                    v = torch.cat((prev_value, v), dim=1)
-            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
-            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
-
-            self._set_input_buffer(incremental_state, saved_state)
-
         src_len = k.size(1)

         if key_padding_mask is not None:
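As a companion to the shapes documented above, here is a small sketch of how the two mask arguments can be constructed, mirroring what the updated test does with src_lengths and masked_fill_; the names and sizes are illustrative only:

import torch

batch_sz, seq_len = 3, 4
src_lengths = torch.tensor([4, 2, 3])  # hypothetical number of valid key positions per batch element

# key_padding_mask: (N, S); a nonzero/True entry hides that key position,
# analogous to `src_len_mask = src_len_mask_int != 1` in the test above.
positions = torch.arange(seq_len).unsqueeze(0).expand(batch_sz, seq_len)
key_padding_mask = positions >= src_lengths.unsqueeze(1)

# attn_mask: additive float mask where 0.0 keeps a position and -inf removes it,
# analogous to the masked_fill_ calls on attn_mask_tensor in the test above.
keep = torch.randint(0, 2, (1, seq_len)).float()
attn_mask = keep.masked_fill(keep == 0, float('-inf')).masked_fill(keep > 0, 0.0)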