[PyTorch] Exercise MHA fast path in JIT

Tests previously did not exercise this; now they do.

Differential Revision: [D35945821](https://our.internmc.facebook.com/intern/diff/D35945821/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76416

Approved by: https://github.com/ezyang
This commit is contained in:
Scott Wolchok
2022-04-26 16:12:02 -07:00
committed by PyTorch MergeBot
parent 1ea49c68d0
commit b182c22e15

View File

@ -15084,6 +15084,22 @@ dedent """
# print(jit_out / py_out - 1)
self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
def test_torchscript_multi_head_attn_fast_path(self):
src_l = 3
bsz = 5
embed_size = 8
nhead = 2
multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
multi_head_attn = multi_head_attn.eval()
query = key = value = torch.rand((bsz, src_l, embed_size))
with torch.no_grad():
py_out = multi_head_attn(query, key, value)
mha = torch.jit.script(multi_head_attn)
jit_out = mha(query, key, value)
torch.testing.assert_close(jit_out, py_out)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_multi_head_attn_cuda(self):