mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Fix conversion of attention - 4D (#157130)
Fixes a wrong conversion to onnx while investigation #149662. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157130 Approved by: https://github.com/gramalingam, https://github.com/justinchuby, https://github.com/titaiwangms Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
d5d14ee823
commit
0105cd89ab
@ -5,9 +5,13 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import onnx.reference as onnx_ref
|
||||
|
||||
import onnxruntime
|
||||
import pytest
|
||||
import transformers
|
||||
from onnxscript import ir
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter import _testing as onnx_testing
|
||||
@ -15,6 +19,10 @@ from torch.testing._internal import common_utils
|
||||
from torch.utils import _pytree as torch_pytree
|
||||
|
||||
|
||||
def has_onnxruntime_opset_23() -> bool:
|
||||
return version.parse(onnxruntime.__version__) >= version.parse("1.22")
|
||||
|
||||
|
||||
class _WithExport:
|
||||
def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
|
||||
onnx_program = torch.onnx.export(
|
||||
@ -736,11 +744,17 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
|
||||
expected = Model()(query, key, value)
|
||||
|
||||
onnx_program = self.export(Model(), (query, key, value), opset_version=23)
|
||||
self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph])
|
||||
|
||||
@pytest.mark.xfail(reason="Expected to fail until opset 23 is supported by ORT.")
|
||||
ref = onnx_ref.ReferenceEvaluator(onnx_program.model_proto)
|
||||
got = ref.run(
|
||||
None, dict(query=query.numpy(), key=key.numpy(), value=value.numpy())
|
||||
)[0]
|
||||
torch.testing.assert_close(torch.from_numpy(got), expected, atol=1e-2, rtol=1)
|
||||
|
||||
def test_graph_accuracy_attention_opset_23(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, query, key, value):
|
||||
@ -752,8 +766,13 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
|
||||
|
||||
onnx_program = self.export(Model(), (query, key, value), opset_version=23)
|
||||
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
|
||||
onnx_program = self.export(
|
||||
Model(), (query, key, value), opset_version=23, optimize=True
|
||||
)
|
||||
self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph])
|
||||
# onnxruntime inlines any op defined as a function and without any implemented kernel
|
||||
if has_onnxruntime_opset_23():
|
||||
onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -119,19 +119,19 @@ def aten_scaled_dot_product_attention_23(
|
||||
"SDPA (MHA) requires q_num_heads = kv_num_heads"
|
||||
)
|
||||
|
||||
# NOTE: There was extended discussion on whether the num_heads attributes (q_num_heads/kv_num_heads)
|
||||
# should be set as ONNX attributes or inferred from the tensor shape. In ONNX, num_heads is needed
|
||||
# for 3D attention inputs (shape: [B, S, N*H]), but not for 4D ([B, N, S, H]), which is the only
|
||||
# input accepted by this exporter. Thus, the attribute is not strictly necessary here, but adding it
|
||||
# may ease future optimization or conversion to 3D formats (e.g., GQA ops)
|
||||
# NOTE: num_heads attributes (q_num_heads/kv_num_heads) should not be specified for 4D.
|
||||
# They are not populated with 4D inputs because this information directy comes from input shapes:
|
||||
# `q_num_heads=query.shape[1]` and `kv_num_heads=key.shape[1]`.
|
||||
# This dimension is usually static but it could not be dynamic if also given as an attribute.
|
||||
# num_heads attributes are needed for 3D attention inputs:
|
||||
# (shape: [B, S, N*H]), 4D shape is ([B, N, S, H]).
|
||||
|
||||
Y, _, _, _ = op23.Attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
scale=scale,
|
||||
q_num_heads=query.shape[-3],
|
||||
kv_num_heads=key.shape[-3],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
return Y
|
||||
|
Reference in New Issue
Block a user