"""ONNX operators as native torch.fx operators.
|
|
|
|
This module provides a set of functions to create ONNX operators in the FX graph
|
|
which are exportable to ONNX.
|
|
"""
|
|
|
|
# flake8: noqa: B950
|
|
from __future__ import annotations
|
|
|
|
|
|

__all__ = [
    "aten_decompositions",
    "symbolic",
    "symbolic_multi_out",
    "rotary_embedding",
    "attention",
]

from typing import TYPE_CHECKING

import torch
from torch.onnx.ops import _impl, _symbolic_impl

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence


# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
_TORCH_DTYPE_TO_ONNX_DTYPE = {
    torch.float32: 1,  # FLOAT
    torch.uint8: 2,  # UINT8
    torch.int8: 3,  # INT8
    torch.uint16: 4,  # UINT16
    torch.int16: 5,  # INT16
    torch.int32: 6,  # INT32
    torch.int64: 7,  # INT64
    str: 8,  # STRING
    torch.bool: 9,  # BOOL
    torch.float16: 10,  # FLOAT16
    torch.double: 11,  # DOUBLE
    torch.uint32: 12,  # UINT32
    torch.uint64: 13,  # UINT64
    torch.complex64: 14,  # COMPLEX64
    torch.complex128: 15,  # COMPLEX128
    torch.bfloat16: 16,  # BFLOAT16
    torch.float8_e4m3fn: 17,  # FLOAT8E4M3FN
    torch.float8_e4m3fnuz: 18,  # FLOAT8E4M3FNUZ
    torch.float8_e5m2: 19,  # FLOAT8E5M2
    torch.float8_e5m2fnuz: 20,  # FLOAT8E5M2FNUZ
    # 21 = UINT4
    # 22 = INT4
    torch.float4_e2m1fn_x2: 23,  # FLOAT4E2M1
}
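
# For reference, the integer codes above follow ONNX's TensorProto.DataType enum
# (see the onnx.proto link above), e.g. _TORCH_DTYPE_TO_ONNX_DTYPE[torch.float32] == 1
# and _TORCH_DTYPE_TO_ONNX_DTYPE[torch.bfloat16] == 16.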


def aten_decompositions() -> dict[torch._ops.OpOverload, Callable]:
    """Return the ONNX to ATen decomp table."""
    return _impl.ONNX_ATEN_DECOMP_TABLE
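
# A usage sketch (illustrative only, not executed here; `model` and `args` are
# placeholder names): the table returned by aten_decompositions() can be passed to
# torch.export's decomposition machinery to lower the ONNX ops defined in this module
# back to ATen ops, e.g.
#
#     ep = torch.export.export(model, args)
#     ep = ep.run_decompositions(torch.onnx.ops.aten_decompositions())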


def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
    split = domain_op.split("::", 1)
    if len(split) == 1:
        domain = ""
        op_type = split[0]
    else:
        domain = split[0]
        op_type = split[1]
    return domain, op_type
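
# For illustration: _parse_domain_op_type("custom_domain::CustomOp") returns
# ("custom_domain", "CustomOp"), and a bare "CustomOp" maps to ("", "CustomOp").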


def symbolic(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtype: torch.dtype | int,
    shape: Sequence[int | torch.SymInt],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> torch.Tensor:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator.

    This function is used to create a symbolic operator with a single output.
    To create an operator with multiple outputs, use :func:`symbolic_multi_out`.

    You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the
    symbolic logic only during ``torch.onnx.export()``.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensor will have the specified dtype and shape
                val = torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtype=x.dtype,
                    shape=x.shape,
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(val)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp".
        inputs: The input tensors to the operator.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtype: The data type of the output tensor. This can be either a torch.dtype
            or an integer representing the ONNX data type.
        shape: The shape of the output tensor. This can be a list of integers or
            SymInt values.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        The output tensor of the operator.
    """
    if not isinstance(dtype, int):
        torch._check(
            dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
        )
        dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
    # TODO: Parse domain
    return _symbolic_impl._symbolic(
        inputs,
        op_type,
        dtype,
        shape=shape,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props.keys() if metadata_props else [],
        metadata_props_values=metadata_props.values() if metadata_props else [],
        domain=domain,
        version=version,
    )


def symbolic_multi_out(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtypes: Sequence[torch.dtype | int],
    shapes: Sequence[Sequence[int | torch.SymInt]],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> Sequence[torch.Tensor]:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.

    You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the
    symbolic logic only during ``torch.onnx.export()``.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensors will have the specified dtypes and shapes
                (out1, out2) = torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtypes=(x.dtype, torch.float32),
                    shapes=(x.shape, [1, 2, 3]),
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(out1 + out2)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp".
        inputs: The input tensors to the operator.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtypes: The data types of the output tensors. This can be a list of
            torch.dtype or integers representing the ONNX data types. The length
            of this list must be the number of outputs.
        shapes: The shapes of the output tensors. This can be a list of lists of
            integers or SymInt values. The length of this list must be the number of outputs.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        A list of output tensors of the operator.
    """
    torch._check(
        len(shapes) == len(dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
    )
    onnx_dtypes = []
    for dtype in dtypes:
        if not isinstance(dtype, int):
            torch._check(
                dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
                lambda: f"Unsupported dtype: {dtype}",
            )
            onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
        else:
            onnx_dtypes.append(dtype)
    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
    # Use the size of dtypes to determine the number of outputs
    return _symbolic_impl._symbolic_multi_out(
        inputs,
        op_type,
        onnx_dtypes,
        shapes=shapes,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props.keys() if metadata_props else [],
        metadata_props_values=metadata_props.values() if metadata_props else [],
        domain=domain,
        version=version,
    )


def rotary_embedding(
    X: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    position_ids: torch.Tensor | None = None,
    *,
    interleaved: bool = False,
    num_heads: int = 0,
    rotary_embedding_dim: int = 0,
) -> torch.Tensor:
    """RotaryEmbedding op in ONNX.

    https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html

    RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864.
    The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances
    between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids).

    The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles.
    For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the
    embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector.
    The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated
    to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism.
    The rotation ensures that the model captures both absolute and relative positional information.
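
    Example (a usage sketch; the shapes below are arbitrary and chosen only for illustration)::

        class Model(torch.nn.Module):
            def forward(self, x, cos_cache, sin_cache, position_ids):
                return torch.onnx.ops.rotary_embedding(
                    x, cos_cache, sin_cache, position_ids
                )

        # X: (batch_size, num_heads, sequence_length, head_size)
        x = torch.randn(2, 8, 16, 64)
        # cos_cache/sin_cache: (max_position_id_plus_1, head_size / 2)
        cos_cache = torch.randn(32, 32)
        sin_cache = torch.randn(32, 32)
        # position_ids: (batch_size, sequence_length)
        position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)

        output = Model()(x, cos_cache, sin_cache, position_ids)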

    Args:
        X: The input tensor representing the token embeddings. 4D tensor with
            shape `(batch_size, num_heads, sequence_length, head_size)` or 3D tensor
            with shape `(batch_size, sequence_length, hidden_size)`. For cases with
            a 4D input tensor, `head_size` has to be even. For cases with a 3D input
            tensor, `num_heads` attribute must be provided and `hidden_size` must
            be an even multiple of `num_heads` where `hidden_size = num_heads * head_size`.
        cos_cache: The cosine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)`
            for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)`
            for partial rotation when `position_ids` are provided. 3D tensor with shape
            `(batch_size, sequence_length, head_size / 2)` for full rotation or
            `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial
            rotation when `position_ids` are not provided. `max_position_id_plus_1`
            is a parameter to the model.
        sin_cache: The sine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)`
            for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)`
            for partial rotation when `position_ids` are provided. 3D tensor with shape
            `(batch_size, sequence_length, head_size / 2)` for full rotation or
            `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial rotation
            when `position_ids` are not provided. `max_position_id_plus_1` is a parameter
            to the model.
        position_ids: The position indices for the tokens. 2D tensor with shape
            `(batch_size, sequence_length)`.
        interleaved: Rotate using the interleaved pattern. Default value is False.
        num_heads: Number of attention heads. Must be provided when input is a 3D tensor.
        rotary_embedding_dim: Rotary embedding dimension used to apply partial rotary embeddings.

    Returns:
        Tensor with same shape as input.
    """
    return _impl.rotary_embedding_23(
        X,
        cos_cache,
        sin_cache,
        position_ids=position_ids,
        interleaved=interleaved,
        num_heads=num_heads,
        rotary_embedding_dim=rotary_embedding_dim,
    )


def attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    past_key: torch.Tensor | None = None,
    past_value: torch.Tensor | None = None,
    *,
    is_causal: bool = False,
    kv_num_heads: int = 0,
    q_num_heads: int = 0,
    qk_matmul_output_mode: int = 0,
    scale: float | None = None,
    softcap: float = 0.0,
    softmax_precision: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Attention op in ONNX.

    https://onnx.ai/onnx/operators/onnx__Attention.html

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.

    This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.

    For self attention, ``kv_sequence_length`` equals ``q_sequence_length``.

    For cross attention, query and key might have different lengths.

    This operator also covers the 3 following variants based on the number of heads:

    1. Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
    2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
    3. Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.

    The attention bias to be added is calculated based on the ``attn_mask`` input and the ``is_causal`` attribute; only one of them can be provided.

    1. If ``is_causal`` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
    2. `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key, value that is added to the attention score.

    Both past and present state key/values are optional. They must be used together; providing only one of them is not allowed.
    The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V based on the sequence lengths and number of heads provided::

        The following pattern is applied by this operator:
                Q          K          V
                |          |          |
         Q*sqrt(scale) K*sqrt(scale)  |
                |          |          |
                |       Transpose     |
                |          |          |
                ---MatMul---          |
                      |               |
          at_mask---Add               |
                      |               |
           softcap (if provided)      |
                      |               |
                   Softmax            |
                      |               |
                   -----MatMul------
                            |
                            Y
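
    Example (a usage sketch; the shapes below are arbitrary and chosen only for illustration)::

        class Model(torch.nn.Module):
            def forward(self, q, k, v):
                output, present_key, present_value, qk_output = torch.onnx.ops.attention(
                    q, k, v, is_causal=True
                )
                return output

        # Q: (batch_size, q_num_heads, q_sequence_length, head_size)
        q = torch.randn(2, 8, 16, 64)
        # K: (batch_size, kv_num_heads, kv_sequence_length, head_size)
        k = torch.randn(2, 8, 16, 64)
        # V: (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
        v = torch.randn(2, 8, 16, 64)

        output = Model()(q, k, v)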

    Args:
        Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor
            with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor,
            `q_hidden_size = q_num_heads * head_size`
        K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor
            with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor,
            `k_hidden_size = kv_num_heads * head_size`
        V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor
            with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor,
            `v_hidden_size = kv_num_heads * v_head_size`
        attn_mask: Attention mask. Shape must be broadcastable to 4D tensor with shape
            `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where
            `total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported.
            A boolean mask where a value of True indicates that the element should take part in attention.
            Also supports a float mask of the same type as query, key, value that is added to the attention score.
        past_key: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)`
        past_value: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)`
        is_causal: If set to True, the attention masking is a lower triangular matrix when the mask is a square matrix.
            The attention masking has the form of the upper left causal bias due to the alignment.
        kv_num_heads: Number of heads of key and value. Must be used with 3D inputs of Q, K and V.
        q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K and V.
        qk_matmul_output_mode: If set to 0, qk_matmul_output is the output of qk matmul. If set to 1,
            qk_matmul_output includes the addition of the attention mask to the output of qk matmul.
            If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3,
            qk_matmul_output is the output after the softmax operation. Default value is 0.
        scale: Scaling factor applied to Q*K^T. Default value is 1/sqrt(head_size). To prevent numerical overflow,
            scale Q, K by sqrt(scale) before matmul.
        softcap: Softcap value for attention weights. Default value is 0.
        softmax_precision: The floating-point precision used in softmax computation. If softmax precision is not provided,
            the same precision as the input of softmax (Q and K) is used.

    Returns:
        A tuple containing:

        - The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor
          with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor,
          `hidden_size = q_num_heads * v_head_size`
        - Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where
          `total_sequence_length = past_sequence_length + kv_sequence_length`.
        - Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where
          `total_sequence_length = past_sequence_length + kv_sequence_length`.
        - The output of QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)`
          where `total_sequence_length = past_sequence_length + kv_sequence_length`.
    """
    return _impl.attention_23(
        Q,
        K,
        V,
        attn_mask=attn_mask,
        past_key=past_key,
        past_value=past_value,
        is_causal=is_causal,
        kv_num_heads=kv_num_heads,
        q_num_heads=q_num_heads,
        qk_matmul_output_mode=qk_matmul_output_mode,
        scale=scale,
        softcap=softcap,
        softmax_precision=softmax_precision,
    )