[ONNX] Support aten::scaled_dot_product_attention in torchscript exporter (#99658)
Fixes #97262

### <samp>🤖 Generated by Copilot at d06d195</samp>

### Summary 🆕🚀📝

This pull request adds ONNX opset 14 support for the `nn.functional.scaled_dot_product_attention` operator, which is used for self-attention in transformer models. It does so by adding tests and annotations in `test/onnx/test_op_consistency.py`, and by adding a symbolic function in `torch/onnx/symbolic_opset14.py` that reuses an existing implementation.

> _To export `scaled_dot_product_attention`_
> _To ONNX opset 14, we need some extension_
> _We import some modules and types_
> _And add a symbolic that pipes_
> _The existing code with some annotation_

### Walkthrough

* Implement the `nn.functional.scaled_dot_product_attention` operator for ONNX opset 14 ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-244955d820ec138d5ddffb20ee6f517cc4c5d281f19ccb53d8db47043b5ac46fR122-R292))
* Add imports for modules and types needed for the operator implementation ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-244955d820ec138d5ddffb20ee6f517cc4c5d281f19ccb53d8db47043b5ac46fL17-R23))
* Add a command to run the pytest module for testing the operator consistency ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R13))
* Add the operator to the list of operators tested for consistency ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R311))
* Add annotations to indicate the operator's limitations and issues ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L333-R339), [link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753R354-R358))
* Remove an empty line at the end of `test/onnx/test_op_consistency.py` ([link](https://github.com/pytorch/pytorch/pull/99658/files?diff=unified&w=0#diff-e968c9cb6fc6631cab526cb3a9fe66358c4c6e757e2a223a224b976471bcb753L441))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99658
Approved by: https://github.com/justinchuby
Committed by: PyTorch MergeBot
Parent: 6585d76f0f
Commit: e5664c652a
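Not part of the diff below — the following is a minimal sketch of how the newly supported operator might be exercised through the TorchScript-based exporter once this change lands. The example module, tensor shapes, and output filename are illustrative assumptions, not taken from the PR.

```python
# Minimal sketch (illustrative): export a module that calls
# F.scaled_dot_product_attention via the TorchScript exporter at opset 14.
import torch
import torch.nn.functional as F


class SelfAttention(torch.nn.Module):  # hypothetical example module
    def forward(self, query, key, value):
        # No attn_mask or dropout here; the symbolic also handles is_causal
        # and boolean/float masks.
        return F.scaled_dot_product_attention(query, key, value)


# Arbitrary [batch, heads, seq, embed] shapes chosen for the example.
q = k = v = torch.randn(2, 4, 8, 16)

torch.onnx.export(
    SelfAttention(),
    (q, k, v),
    "sdpa.onnx",       # illustrative output path
    opset_version=14,  # the opset targeted by this PR
)
```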
@@ -10,6 +10,7 @@ Usage:
To run tests on a specific operator (e.g. torch.ceil):

pytest test/onnx/test_op_consistency.py -k ceil
pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

Read more on Running and writing tests:
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
@@ -307,6 +308,7 @@ TESTED_OPS: frozenset[str] = frozenset(
"ceil",
"flatten",
"logical_not",
"nn.functional.scaled_dot_product_attention",
"sqrt",
"stft",
"t",
@@ -330,9 +332,11 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
reason=reason_onnx_does_not_support("Ceil")
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
dont_care("nn.functional.scaled_dot_product_attention", opsets=[opsets_before(14)], reason="Need Trilu."),
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
dont_care("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
dont_care("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
fixme("unflatten", opsets=[opsets_before(13)], reason="helper function is needed to support legacy ops."),
fixme("unflatten", opsets=[opsets_before(13)], reason="Helper function is needed to support legacy ops."),
)
# fmt: on

@@ -347,6 +351,11 @@ SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
reason="Logic not implemented for size 0 inputs in op.Reshape",
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
),
dont_care(
"nn.functional.scaled_dot_product_attention",
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
reason="dropout is random so the results do not match",
),
)

# END OF SECTION TO MODIFY #####################################################
@@ -438,7 +447,6 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
# Cannot use self.skip because pytest would skip the entire test
warnings.warn(f"skipped sample {i}. Reason: {skip_reason}")
continue

model = SingleOpModel(op, cpu_sample.kwargs)
model.eval()

@@ -463,7 +463,8 @@ def _reduce_with_dtype(onnx_op, name):
return reduce


# Ported from https://github.com/microsoft/onnx-script/blob/main/onnxscript/function_libs/torch_aten/ops/core.py aten_unflatten
# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ...
@_onnx_symbolic("aten::unflatten")
@_beartype.beartype
@@ -14,14 +14,26 @@ Updated operators:

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations

import functools
from typing import Optional

import torch
from torch.onnx import symbolic_helper
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

__all__ = [
    "hardswish",
    "tril",
    "triu",
    "reshape",
    "batch_norm",
    "quantized_hardswish",
    "scaled_dot_product_attention",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@@ -117,3 +129,153 @@ def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    output = hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Need op.Trilu
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v")
@_beartype.beartype
def scaled_dot_product_attention(
    g: jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: Optional[torch._C.Value] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[torch._C.Value] = None,
):
    assert (not is_causal) or (
        is_causal and symbolic_helper._is_none(attn_mask)
    ), "is_causal and attn_mask cannot be set at the same time"

    scale = symbolic_helper._maybe_get_const(scale, "f")
    if symbolic_helper._is_none(scale):
        scale = _attention_scale(g, query)

    if is_causal:
        attn_mask = _causal_attention_mask(g, query, key)

    # Swap the last two axes of key
    # NOTE: onnx-script has different logic here, because the attribute perms in
    # transpose needs list of ints
    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
    key_transposed_axes = list(range(key_shape_builtin))
    key_transposed_axes[-1], key_transposed_axes[-2] = (
        key_transposed_axes[-2],
        key_transposed_axes[-1],
    )
    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
    key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

    if symbolic_helper._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif (
        _type_utils.JitScalarType.from_value(attn_mask)
        == _type_utils.JitScalarType.BOOL
    ):
        # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif (
        _type_utils.JitScalarType.from_value(attn_mask)
        == _type_utils.JitScalarType.FLOAT
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

    if dropout_p != 0:
        attn_weight = g.op(
            "Dropout",
            attn_weight,
            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
        )

    return g.op("MatMul", attn_weight, value)


@_beartype.beartype
def _attention_scale(
    g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
    query_shape = g.op("Shape", query)
    query_shape_last = g.op(
        "Slice",
        query_shape,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
        g.op(
            "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
        ),
    )
    embedding_size = g.op(
        "Cast",
        query_shape_last,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
    scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
    return scale


@_beartype.beartype
def _causal_attention_mask(
    g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(not mask, -float('inf'))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """

    query_shape = g.op("Shape", query)
    key_shape = g.op("Shape", key)

    last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
    target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
    source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
    # attn_mask = torch.ones(L, S) := {
    size = g.op("Concat", target_length, source_length, axis_i=0)
    const_one = g.op("Constant", value_t=torch.tensor([1.0]))
    attn_mask = g.op("Expand", const_one, size)
    # }
    attn_mask = g.op("Trilu", attn_mask, upper_i=0)
    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
    const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
    const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
    attn_mask = g.op(
        "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
    )
    return attn_mask
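For readers comparing against eager PyTorch, here is a rough reference sketch (an assumption for illustration, not code from this PR) of what the exported graph above computes: query and key are each scaled by sqrt(scale) before the matmul, an optional additive float mask is applied, softmax is taken over the last axis, dropout is optionally applied, and the result is matmul'd with value. The boolean-mask conversion branch is omitted here.

```python
# Rough eager-mode reference of the exported computation (sketch only).
# Assumes a float additive attn_mask (or None); ignores is_causal and the
# boolean-mask branch handled by the symbolic above.
import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_reference(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))  # mirrors _attention_scale
    # Scale q and k by sqrt(scale) before the matmul, as the symbolic does.
    q = query * math.sqrt(scale)
    k_t = (key * math.sqrt(scale)).transpose(-2, -1)
    scores = q @ k_t
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    if dropout_p != 0.0:
        weights = F.dropout(weights, p=dropout_p, training=True)
    return weights @ value
```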