mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add fake_impl for _native_multi_head_attention (#163700)
Test Plan: See added test in test_export.py Differential Revision: D83099187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163700 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
7bad9c5a64
commit
21a41edd4f
@ -318,7 +318,7 @@ timm_vovnet,pass,0
|
||||
|
||||
|
||||
|
||||
torch_multimodal_clip,pass,3
|
||||
torch_multimodal_clip,pass,0
|
||||
|
||||
|
||||
|
||||
|
|
@ -1087,6 +1087,93 @@ graph():
|
||||
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
|
||||
self.assertEqual(gm(*args), m(*args))
|
||||
|
||||
# stride() is called for an undefined tensor
|
||||
@testing.expectedFailureCppRuntimeNonStrict
|
||||
def test_native_multi_attention_head(self):
|
||||
embed_dim = 64
|
||||
num_heads = 4
|
||||
bs = 16
|
||||
sl = 8
|
||||
device = "cpu"
|
||||
|
||||
q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
|
||||
k = q
|
||||
v = q
|
||||
|
||||
qkv = torch.nn.Linear(
|
||||
embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
|
||||
)
|
||||
proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
|
||||
|
||||
class NativeMHA(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
qkv,
|
||||
proj,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
mask_type,
|
||||
):
|
||||
super().__init__()
|
||||
self.qkv = qkv
|
||||
self.proj = proj
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.need_weights = need_weights
|
||||
self.average_attn_weights = average_attn_weights
|
||||
self.mask_type = mask_type
|
||||
|
||||
def forward(self, q, k, v, key_padding_mask):
|
||||
return torch._native_multi_head_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.qkv.weight,
|
||||
self.qkv.bias,
|
||||
self.proj.weight,
|
||||
self.proj.bias,
|
||||
key_padding_mask,
|
||||
need_weights=False,
|
||||
average_attn_weights=False,
|
||||
mask_type=1, # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
|
||||
)
|
||||
|
||||
for mask_type in (0, 1):
|
||||
for need_weights in (True, False):
|
||||
for average_attn_weights in (True, False):
|
||||
npt = NativeMHA(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
qkv=qkv,
|
||||
proj=proj,
|
||||
need_weights=need_weights,
|
||||
average_attn_weights=average_attn_weights,
|
||||
mask_type=mask_type,
|
||||
)
|
||||
sample_input = (q, k, v, None)
|
||||
|
||||
ep = export(
|
||||
npt,
|
||||
args=sample_input,
|
||||
dynamic_shapes={
|
||||
"q": {
|
||||
0: Dim("dim0_q", max=1024),
|
||||
},
|
||||
"k": {
|
||||
0: Dim("dim0_k", max=1024),
|
||||
},
|
||||
"v": {
|
||||
0: Dim("dim0_v", max=1024),
|
||||
},
|
||||
"key_padding_mask": None,
|
||||
},
|
||||
)
|
||||
self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
|
||||
|
||||
def test_unused_constant(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -7790,6 +7790,56 @@ def _create_unary_float_meta_func(func):
|
||||
return _f
|
||||
|
||||
|
||||
# Implementation follows cuda implementation native_multi_head_attention_cuda
|
||||
@register_meta(aten._native_multi_head_attention.default)
|
||||
def native_multi_head_attention_fake(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
embed_dim,
|
||||
num_head,
|
||||
qkv_weight,
|
||||
qkv_bias,
|
||||
proj_weight,
|
||||
proj_bias,
|
||||
mask=None,
|
||||
need_weights=True,
|
||||
average_attn_weights=True,
|
||||
mask_type=None,
|
||||
):
|
||||
if query.is_nested or key.is_nested or value.is_nested:
|
||||
raise NotImplementedError(
|
||||
"_native_multi_head_attention fake implementation does not support nested tensors"
|
||||
)
|
||||
|
||||
if query.numel() == 0:
|
||||
return (query.new_empty(query.shape), query.new_empty(0))
|
||||
|
||||
B = query.size(0) # B: batch size
|
||||
T = query.size(1) # T: target sequence length
|
||||
|
||||
# In native_multi_head_attention_cuda,
|
||||
# we have proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query)
|
||||
# , which does attn_ctx @ proj_weight.T + proj_bias
|
||||
# so the last dim of output shape is proj_weight.size(0)
|
||||
output_dim = proj_weight.size(0)
|
||||
output = query.new_empty(B, T, output_dim)
|
||||
|
||||
if need_weights:
|
||||
if average_attn_weights:
|
||||
# When averaging attention weights, shape is [B, T, T] (averaged over heads)
|
||||
# T = query seq len, S = key/value seq len
|
||||
attn_weights = query.new_empty(B, T, T)
|
||||
else:
|
||||
# When not averaging, shape is [B, num_head, T, T]
|
||||
# T = query seq len, S = key/value seq len
|
||||
attn_weights = query.new_empty(B, num_head, T, T)
|
||||
else:
|
||||
attn_weights = query.new_empty(0)
|
||||
|
||||
return (output, attn_weights)
|
||||
|
||||
|
||||
def _create_binary_float_meta_func(func):
|
||||
@register_meta(func)
|
||||
@out_wrapper()
|
||||
|
Reference in New Issue
Block a user