Fix DDP incompatibility issue with nn.MultiheadAttention. (#26826)

Summary:
Fix issue https://github.com/pytorch/pytorch/issues/26698.

With different query/keys/value dimensions, `nn.MultiheadAttention` has DDP incompatibility issue because in that case `in_proj_weight` attribute is created but not used. Fix it and add a distributed unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26826

Differential Revision: D17583807

Pulled By: zhangguanheng66

fbshipit-source-id: c393584c331ed4f57ebaf2d4015ef04589c973f6
This commit is contained in:
Guanheng Zhang
2019-10-08 11:57:34 -07:00
committed by Facebook Github Bot
parent f522bde121
commit eb93200321
3 changed files with 9 additions and 9 deletions

View File

@ -2627,7 +2627,7 @@ class TestNN(NNTestCase):
result, result_weight = torch.nn.functional.multi_head_attention_forward(
_Q, _K, _V,
d_model, nheads,
multihead_attn_module.in_proj_weight, multihead_attn_module.in_proj_bias,
None, multihead_attn_module.in_proj_bias,
multihead_attn_module.bias_k, multihead_attn_module.bias_v,
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
@ -2638,10 +2638,11 @@ class TestNN(NNTestCase):
result = result.squeeze(0).detach().numpy()
q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)]
v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):]
if not multihead_attn_module._qkv_same_embed_dim:
if multihead_attn_module._qkv_same_embed_dim:
q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)]
v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):]
else:
q_proj_weight = multihead_attn_module.q_proj_weight
k_proj_weight = multihead_attn_module.k_proj_weight
v_proj_weight = multihead_attn_module.v_proj_weight

View File

@ -3201,7 +3201,6 @@ def multi_head_attention_forward(query, # type: Tensor
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
head_dim = embed_dim // num_heads

View File

@ -687,12 +687,12 @@ class MultiheadAttention(Module):
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
if self._qkv_same_embed_dim is False:
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
else:
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
@ -759,7 +759,7 @@ class MultiheadAttention(Module):
if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
return F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
None, self.in_proj_bias, # set self.in_proj_weight = None
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,