[PyTorch] Exercise MHA fast path in JIT
Tests previously did not exercise this; now they do.

Differential Revision: [D35945821](https://our.internmc.facebook.com/intern/diff/D35945821/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76416
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 1ea49c68d0
Commit: b182c22e15
@@ -15084,6 +15084,22 @@ dedent """
             # print(jit_out / py_out - 1)
             self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
 
+    def test_torchscript_multi_head_attn_fast_path(self):
+        src_l = 3
+        bsz = 5
+        embed_size = 8
+        nhead = 2
+        multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
+        multi_head_attn = multi_head_attn.eval()
+
+        query = key = value = torch.rand((bsz, src_l, embed_size))
+
+        with torch.no_grad():
+            py_out = multi_head_attn(query, key, value)
+            mha = torch.jit.script(multi_head_attn)
+            jit_out = mha(query, key, value)
+        torch.testing.assert_close(jit_out, py_out)
+
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_multi_head_attn_cuda(self):
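For reference, a minimal standalone sketch of the pattern the new test exercises, lifted out of the test harness: script an eval-mode `nn.MultiheadAttention` and check that the scripted module matches eager execution under `torch.no_grad()`. The sizes are illustrative and the variable names only mirror the test; this sketch is not part of the diff above.

```python
import torch

# Illustrative sizes; the test puts the module in eval mode and uses
# batch_first=True, so inputs are laid out as (batch, seq, embed).
embed_size, nhead, bsz, src_l = 8, 2, 5, 3
mha = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True).eval()

# Self-attention: the same tensor serves as query, key, and value.
query = key = value = torch.rand((bsz, src_l, embed_size))

with torch.no_grad():
    eager_out = mha(query, key, value)       # eager-mode forward
    scripted = torch.jit.script(mha)         # compile the module with TorchScript
    jit_out = scripted(query, key, value)    # scripted forward

# Both calls return (attn_output, attn_weights); compare the whole tuple.
torch.testing.assert_close(jit_out, eager_out)
```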