Fix DDP incompatibility issue with nn.MultiheadAttention. (#26826)
Summary: Fix issue https://github.com/pytorch/pytorch/issues/26698. With different query/key/value dimensions, `nn.MultiheadAttention` has a DDP incompatibility issue because, in that case, the `in_proj_weight` attribute is created but never used. Fix it and add a distributed unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26826
Differential Revision: D17583807
Pulled By: zhangguanheng66
fbshipit-source-id: c393584c331ed4f57ebaf2d4015ef04589c973f6
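The failure mode can be reproduced on a single process by checking which parameters receive gradients after a forward/backward pass, which is the same condition DDP's reducer enforces. The sketch below is illustrative only: the sizes are arbitrary, and the gradient check stands in for an actual `DistributedDataParallel` wrap.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes; any kdim/vdim != embed_dim selects the
# separate-projection path (_qkv_same_embed_dim == False).
embed_dim, num_heads, kdim, vdim = 8, 2, 6, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

q = torch.randn(3, 1, embed_dim)  # (tgt_len, batch, embed_dim)
k = torch.randn(5, 1, kdim)       # (src_len, batch, kdim)
v = torch.randn(5, 1, vdim)       # (src_len, batch, vdim)

out, _ = mha(q, k, v)
out.sum().backward()

# DDP expects every registered parameter to receive a gradient (unless
# find_unused_parameters=True). Before this fix, `in_proj_weight` was still
# registered in this configuration but never used in forward(), so its grad
# stayed None and DDP's reducer raised an error; after the fix the packed
# weight is only created when the q/k/v embedding dimensions match.
for name, p in mha.named_parameters():
    print(name, p.grad is not None)
```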
committed by Facebook Github Bot
commit eb93200321
parent f522bde121
test/test_nn.py
@@ -2627,7 +2627,7 @@ class TestNN(NNTestCase):
             result, result_weight = torch.nn.functional.multi_head_attention_forward(
                 _Q, _K, _V,
                 d_model, nheads,
-                multihead_attn_module.in_proj_weight, multihead_attn_module.in_proj_bias,
+                None, multihead_attn_module.in_proj_bias,
                 multihead_attn_module.bias_k, multihead_attn_module.bias_v,
                 multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                 multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
@@ -2638,10 +2638,11 @@ class TestNN(NNTestCase):
 
             result = result.squeeze(0).detach().numpy()
 
-            q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
-            k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)]
-            v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):]
-            if not multihead_attn_module._qkv_same_embed_dim:
+            if multihead_attn_module._qkv_same_embed_dim:
+                q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
+                k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)]
+                v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):]
+            else:
                 q_proj_weight = multihead_attn_module.q_proj_weight
                 k_proj_weight = multihead_attn_module.k_proj_weight
                 v_proj_weight = multihead_attn_module.v_proj_weight
torch/nn/functional.py
@@ -3201,7 +3201,6 @@ def multi_head_attention_forward(query, # type: Tensor
 
     tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
-    assert list(query.size()) == [tgt_len, bsz, embed_dim]
     assert key.size() == value.size()
 
     head_dim = embed_dim // num_heads
torch/nn/modules/activation.py
@@ -687,12 +687,12 @@ class MultiheadAttention(Module):
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
-        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
-
         if self._qkv_same_embed_dim is False:
             self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
             self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
             self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+        else:
+            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
 
         if bias:
             self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
@@ -759,7 +759,7 @@ class MultiheadAttention(Module):
         if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
             return F.multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
-                self.in_proj_weight, self.in_proj_bias,
+                None, self.in_proj_bias,  # set self.in_proj_weight = None
                 self.bias_k, self.bias_v, self.add_zero_attn,
                 self.dropout, self.out_proj.weight, self.out_proj.bias,
                 training=self.training,
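After this change, the set of projection parameters a `MultiheadAttention` module registers depends on whether the query/key/value embedding dimensions match. The quick check below is a sketch with arbitrarily chosen sizes; the expected-output comments reflect the post-fix behavior.

```python
import torch.nn as nn

same = nn.MultiheadAttention(embed_dim=8, num_heads=2)
mixed = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=6, vdim=4)

# Same dims: a single packed in_proj_weight is registered.
print(sorted(n for n, _ in same.named_parameters()))
# expected: ['in_proj_bias', 'in_proj_weight', 'out_proj.bias', 'out_proj.weight']

# Different dims: separate q/k/v projection weights, and no unused in_proj_weight.
print(sorted(n for n, _ in mixed.named_parameters()))
# expected: ['in_proj_bias', 'k_proj_weight', 'out_proj.bias', 'out_proj.weight',
#            'q_proj_weight', 'v_proj_weight']
```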