Add fake_impl for _native_multi_head_attention (#163700)

Test Plan: See added test in test_export.py

Differential Revision: D83099187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163700
Approved by: https://github.com/angelayi
This commit is contained in:
Yidi Wu
2025-09-25 19:01:23 +00:00
committed by PyTorch MergeBot
parent 7bad9c5a64
commit 21a41edd4f
3 changed files with 138 additions and 1 deletions

View File

@ -318,7 +318,7 @@ timm_vovnet,pass,0
torch_multimodal_clip,pass,3
torch_multimodal_clip,pass,0

1 name accuracy graph_breaks
318
319
320
321
322
323
324

View File

@ -1087,6 +1087,93 @@ graph():
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(gm(*args), m(*args))
# stride() is called for an undefined tensor
@testing.expectedFailureCppRuntimeNonStrict
def test_native_multi_attention_head(self):
embed_dim = 64
num_heads = 4
bs = 16
sl = 8
device = "cpu"
q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
k = q
v = q
qkv = torch.nn.Linear(
embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
)
proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
class NativeMHA(torch.nn.Module):
def __init__(
self,
embed_dim,
num_heads,
qkv,
proj,
need_weights,
average_attn_weights,
mask_type,
):
super().__init__()
self.qkv = qkv
self.proj = proj
self.embed_dim = embed_dim
self.num_heads = num_heads
self.need_weights = need_weights
self.average_attn_weights = average_attn_weights
self.mask_type = mask_type
def forward(self, q, k, v, key_padding_mask):
return torch._native_multi_head_attention(
q,
k,
v,
self.embed_dim,
self.num_heads,
self.qkv.weight,
self.qkv.bias,
self.proj.weight,
self.proj.bias,
key_padding_mask,
need_weights=False,
average_attn_weights=False,
mask_type=1, # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
)
for mask_type in (0, 1):
for need_weights in (True, False):
for average_attn_weights in (True, False):
npt = NativeMHA(
embed_dim=embed_dim,
num_heads=num_heads,
qkv=qkv,
proj=proj,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
mask_type=mask_type,
)
sample_input = (q, k, v, None)
ep = export(
npt,
args=sample_input,
dynamic_shapes={
"q": {
0: Dim("dim0_q", max=1024),
},
"k": {
0: Dim("dim0_k", max=1024),
},
"v": {
0: Dim("dim0_v", max=1024),
},
"key_padding_mask": None,
},
)
self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
def test_unused_constant(self):
class M(torch.nn.Module):
def forward(self, x):

View File

@ -7790,6 +7790,56 @@ def _create_unary_float_meta_func(func):
return _f
# Implementation follows cuda implementation native_multi_head_attention_cuda
@register_meta(aten._native_multi_head_attention.default)
def native_multi_head_attention_fake(
query,
key,
value,
embed_dim,
num_head,
qkv_weight,
qkv_bias,
proj_weight,
proj_bias,
mask=None,
need_weights=True,
average_attn_weights=True,
mask_type=None,
):
if query.is_nested or key.is_nested or value.is_nested:
raise NotImplementedError(
"_native_multi_head_attention fake implementation does not support nested tensors"
)
if query.numel() == 0:
return (query.new_empty(query.shape), query.new_empty(0))
B = query.size(0) # B: batch size
T = query.size(1) # T: target sequence length
# In native_multi_head_attention_cuda,
# we have proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query)
# , which does attn_ctx @ proj_weight.T + proj_bias
# so the last dim of output shape is proj_weight.size(0)
output_dim = proj_weight.size(0)
output = query.new_empty(B, T, output_dim)
if need_weights:
if average_attn_weights:
# When averaging attention weights, shape is [B, T, T] (averaged over heads)
# T = query seq len, S = key/value seq len
attn_weights = query.new_empty(B, T, T)
else:
# When not averaging, shape is [B, num_head, T, T]
# T = query seq len, S = key/value seq len
attn_weights = query.new_empty(B, num_head, T, T)
else:
attn_weights = query.new_empty(0)
return (output, attn_weights)
def _create_binary_float_meta_func(func):
@register_meta(func)
@out_wrapper()