Disable incremental_state function in MultiheadAttention module. (#20177)

Summary:
Fully supporting the incremental_state functionality would require several additional utilities that are only available in fairseq, and we currently have no proper way to unit-test it. Therefore, the incremental_state function is disabled for now. If it is needed in the future, a feature request can be created. Fixes #20132

Also add unit tests to cover the arguments of the MultiheadAttention module, including bias, add_bias_kv, add_zero_attn, key_padding_mask, need_weights, and attn_mask.
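As a quick illustration (a minimal sketch with made-up shapes and values, not code from this PR), the forward interface after this change takes only key_padding_mask, need_weights, and attn_mask as keyword arguments:

```python
import torch
from torch.nn import MultiheadAttention

S, L, N, E, H = 5, 2, 3, 8, 2            # source len, target len, batch, embed dim, heads
attn = MultiheadAttention(E, H, bias=True, add_bias_kv=False, add_zero_attn=False)

query = torch.rand(L, N, E)              # (L, N, E)
key = value = torch.rand(S, N, E)        # (S, N, E)

# (N, S) ByteTensor; nonzero entries mark padded key positions that attention ignores
key_padding_mask = torch.zeros(N, S, dtype=torch.uint8)
key_padding_mask[:, -1] = 1

# incremental_state and static_kv are no longer accepted by forward()
attn_output, attn_output_weights = attn(query, key, value,
                                        key_padding_mask=key_padding_mask,
                                        need_weights=True)
print(attn_output.shape)                 # torch.Size([2, 3, 8])  == (L, N, E)
print(attn_output_weights.shape)         # torch.Size([3, 2, 5])  == (N, L, S)
```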
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20177

Differential Revision: D15304575

Pulled By: cpuhrsch

fbshipit-source-id: ebd8cc0f11a4da0c0998bf0c7e4e341585e5685a
Author: Guanheng Zhang
Date: 2019-05-13 08:16:04 -07:00
Committed by: Facebook Github Bot
Parent: f8aa6a8f44
Commit: 41673d477c
2 changed files with 74 additions and 60 deletions


@@ -3237,29 +3237,36 @@ class TestNN(NNTestCase):
output = m(sigmoid(input), target)
verify_reduction_scalars(input, reduction, output)
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
def test_multihead_attention(self):
def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=False, src_lengths=None):
def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
attn_mask=None, add_zero_attn=False):
""" Numpy-based reference implementation of scaled dot attention
for testing"""
QKT = _batchmatmul(
Q,
np.transpose(K, axes=[0, 1, 3, 2])
/ np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head)
)
if unseen_mask or src_lengths is not None:
b1, b2, s1, s2 = QKT.shape
b1, b2, s1, s2 = QKT.shape
if unseen_mask is not None or src_lengths is not None:
# assert s1 == s2
for i in range(b1):
for j in range(b2):
for m in range(s1):
for n in range(s2):
if unseen_mask and n > m:
if unseen_mask[m][n] == 0:
QKT[i, j, m, n] = -np.inf
if src_lengths is not None and n >= src_lengths[i]:
QKT[i, j, m, n] = -np.inf
reference = _softmax(QKT)
ref_attn_weight = reference
ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2
reference = _batchmatmul(reference, V)
return reference
return reference, ref_attn_weight
def _batchmatmul(a, b): # batchmatmul over 4 dim matrix
""" Numpy-based batch matrix multiply over 4 dim matrix"""
@@ -3275,7 +3282,8 @@ class TestNN(NNTestCase):
def _softmax(x): # softmax over 4 dim matrix
""" Numpy-based reference softmax over 4 dim matrix"""
output = np.zeros(x.shape, dtype=np.float32)
np.seterr(invalid='ignore')
output = np.zeros(x.shape, dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
for k in range(x.shape[2]):
@@ -3338,7 +3346,7 @@ class TestNN(NNTestCase):
# returns [batch_size, max_seq_len]
return (src_indices < src_lengths).int().detach()
def _multihead_attn_test_helper(use_src_lengths):
def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zero_attn=False):
for _ in range(100):
batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
d_head = random.randint(3, 10)
@@ -3348,7 +3356,7 @@ class TestNN(NNTestCase):
src_lengths = None
src_lengths_tensor = None
if use_src_lengths:
if add_key_padding_mask:
src_lengths, src_lengths_tensor = _generate_src_lengths(
batch_size=batch_sz, seq_len=seq_len
)
@@ -3357,28 +3365,44 @@ class TestNN(NNTestCase):
K = np.random.rand(*dims).astype(np.float64)
V = K
Q = np.expand_dims(decoder_state, 1)
attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
attn_mask_tensor = torch.from_numpy(attn_mask).float()
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()
decoder_state_tensor = torch.from_numpy(decoder_state).double()
source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)
multihead_attn_module = MultiheadAttention(d_model, nheads)
multihead_attn_module = MultiheadAttention(d_model, nheads,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn)
if add_bias_kv:
bias_k = multihead_attn_module.bias_k.detach().numpy()
bias_v = multihead_attn_module.bias_v.detach().numpy()
else:
bias_k = None
bias_v = None
_batch_size = decoder_state_tensor.shape[0]
_Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
_V = source_hid_tensor
_K = source_hid_tensor
src_len_mask = None
if src_lengths is not None and use_src_lengths:
if src_lengths is not None and add_key_padding_mask:
# [batch_size, 1, seq_len]
src_len_mask_int = _create_src_lengths_mask(
batch_size=_batch_size, src_lengths=src_lengths_tensor
)
src_len_mask = src_len_mask_int != 1
result = multihead_attn_module(
result, result_weight = multihead_attn_module(
_Q, _K, _V,
key_padding_mask=src_len_mask,
need_weights=True)[0].squeeze(0).detach().numpy()
need_weights=True,
attn_mask=attn_mask_tensor)
result = result.squeeze(0).detach().numpy()
Q_fc = _fc(Q, "in_proj_", multihead_attn_module, end=d_model)
K_fc = _fc(
@@ -3386,20 +3410,31 @@ class TestNN(NNTestCase):
)
V_fc = _fc(V, "in_proj_", multihead_attn_module, start=2 * d_model)
if add_bias_kv:
K_fc = np.concatenate((K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1)
V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1)
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
dims[1] += 1
Q_split = _split_heads_ref(
Q_fc, [batch_sz, 1, d_model], nheads, d_head
)
K_split = _split_heads_ref(K_fc, dims, nheads, d_head)
V_split = _split_heads_ref(V_fc, dims, nheads, d_head)
attn_heads = _scaled_dot_attn_ref(
if add_zero_attn:
dims[1] += 1
K_split = np.concatenate((K_split, np.zeros([K_split.shape[0], K_split.shape[1], 1, K_split.shape[3]])), axis=2)
V_split = np.concatenate((V_split, np.zeros([V_split.shape[0], V_split.shape[1], 1, V_split.shape[3]])), axis=2)
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
attn_heads, ref_attn_weight = _scaled_dot_attn_ref(
Q=Q_split,
K=K_split,
V=V_split,
dims=Q_split.shape,
src_lengths=src_lengths,
unseen_mask=attn_mask,
src_lengths=src_lengths
)
combined_attn_heads = _combine_heads_ref(
X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
)
@@ -3413,14 +3448,27 @@ class TestNN(NNTestCase):
self.assertEqual(tuple(result.shape), (batch_sz, d_model))
np.testing.assert_allclose(result, reference, atol=1e-5)
# result_weight = ref_attn_weight
result_weight = result_weight.detach().numpy()
self.assertEqual(tuple(result_weight.shape), tuple(ref_attn_weight.shape))
np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5)
def test_multihead_attn_add_bias_kv():
_multihead_attn_test_helper(add_key_padding_mask=None, add_bias_kv=True)
def test_multihead_attn_add_zero_attn():
_multihead_attn_test_helper(add_key_padding_mask=None, add_zero_attn=True)
def test_multihead_attn_no_masking():
_multihead_attn_test_helper(use_src_lengths=None)
_multihead_attn_test_helper(add_key_padding_mask=None)
def test_multihead_attn_with_src_lengths():
_multihead_attn_test_helper(use_src_lengths=True)
def test_multihead_attn_key_padding_mask():
_multihead_attn_test_helper(add_key_padding_mask=True)
test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn
test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv
test_multihead_attn_no_masking() # Test MultiheadAttention without masking
test_multihead_attn_with_src_lengths() # Test MultiheadAttention with src lengths
test_multihead_attn_key_padding_mask()  # Test MultiheadAttention with key_padding_mask
def test_normalize(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)


@@ -691,6 +691,8 @@ class MultiheadAttention(Module):
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
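For instance (a hypothetical configuration, not taken from this PR), the options documented above can be combined like this; with add_bias_kv and add_zero_attn the key/value sequences gain two extra positions, which also shows up in the returned attention weights:

```python
import torch
from torch.nn import MultiheadAttention

# bias_k / bias_v are learned parameters appended along the source dimension;
# add_zero_attn additionally appends an all-zero key/value slot
mha = MultiheadAttention(embed_dim=8, num_heads=2, dropout=0.0,
                         bias=True, add_bias_kv=True, add_zero_attn=True)

q = k = v = torch.rand(4, 3, 8)          # (L, N, E) self-attention, so S == L == 4
out, w = mha(q, k, v, need_weights=True)
print(out.shape)                         # torch.Size([4, 3, 8])  == (L, N, E)
print(w.shape)                           # torch.Size([3, 4, 6])  == (N, L, S + 2)
```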
@@ -742,23 +744,19 @@ class MultiheadAttention(Module):
xavier_normal_(self.bias_v)
@weak_script_method
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None):
def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention.
incremental_state: if provided, previous time steps are cached.
need_weights: output attn_output_weights.
static_kv: if true, key and value are static. The key and value in previous
states will be used.
attn_mask: mask that prevents attention to certain positions.
Shape:
- Inputs:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
@@ -766,11 +764,9 @@ class MultiheadAttention(Module):
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- incremental_state: a dictionary used for storing states.
- attn_mask: :math:`(L, L)` where L is the target sequence length.
- Outputs:
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
@@ -784,17 +780,6 @@ class MultiheadAttention(Module):
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert kv_same and not qkv_same
key = value = None
else:
saved_state = None
if qkv_same:
# self-attention
q, k, v = self._in_proj_qkv(query)
@@ -828,25 +813,6 @@ class MultiheadAttention(Module):
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if 'prev_key' in saved_state:
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
k = torch.cat((prev_key, k), dim=1)
if 'prev_value' in saved_state:
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(1)
if key_padding_mask is not None:
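Putting the documented shapes together, a hypothetical self-attention call against the updated forward() (so that attn_mask has the documented (L, L) shape) would look like the sketch below; the additive mask uses -inf to block attention to certain positions, matching the convention used in the unit test above.

```python
import torch
from torch.nn import MultiheadAttention

L, N, E, H = 4, 2, 16, 4
mha = MultiheadAttention(E, H)
x = torch.rand(L, N, E)                   # query = key = value: (L, N, E)

# causal additive mask of shape (L, L): -inf above the diagonal blocks future positions
causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)

out, weights = mha(x, x, x, attn_mask=causal, need_weights=True)
assert out.shape == (L, N, E)             # attn_output: (L, N, E)
assert weights.shape == (N, L, L)         # attn_output_weights: (N, L, S) with S == L
```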