"""Implementation of symbolic FX ops to represent arbitrary ONNX ops.
|
|
|
|
This module provides a way to create symbolic FX operators that can represent
|
|
arbitrary ONNX operators.
|
|
|
|
The operators are called "symbolic" because they don't do any actual computation
|
|
but instead serve as placeholders in the computation graph.
|
|
|
|
Each implementation contains two parts: A "real" implementation that produce all
|
|
zeros based on the input shape and dtype, and a "fake" implementation that does more
|
|
or less the same thing but is required by the `torch.library.custom_op` interface.
|
|
"""
|
|
|
|
# flake8: noqa: B950
|
|
import dataclasses
from collections.abc import Sequence
from typing import Optional, Union

import torch
from torch.onnx.ops import _dtype_mappings


_INT_TYPE = "i"
_FLOAT_TYPE = "f"
_STRING_TYPE = "s"
_INT_SEQ_TYPE = "is"
_FLOAT_SEQ_TYPE = "fs"
_STRING_SEQ_TYPE = "ss"


@dataclasses.dataclass
class EncodedAttrs:
    """Encodes attributes from a dictionary into lists of FX-compatible attributes.

    Since FX does not support dictionaries, we need to encode the attributes into
    lists. This class provides a way to encode and decode the attributes.

    Attributes:
        attr_keys: List of attribute keys.
        attr_types: List of attribute types. Values can be "i" (int), "f" (float),
            "s" (string), "is" (int sequence), "fs" (float sequence), or
            "ss" (string sequence).
        attr_pos: List of tuples representing the start and end positions of each
            attribute in the corresponding value list.
        attr_ints: List of integer attributes.
        attr_floats: List of float attributes.
        attr_strs: List of string attributes.
    """

    attr_keys: list[str]
    attr_types: list[str]
    attr_pos: list[tuple[int, int]]
    attr_ints: list[int]
    attr_floats: list[float]
    attr_strs: list[str]

    @classmethod
    def from_dict(
        cls,
        attrs: dict[
            str,
            Union[
                int,
                float,
                str,
                bool,
                Sequence[int],
                Sequence[float],
                Sequence[str],
                Sequence[bool],
            ],
        ],
    ) -> "EncodedAttrs":
        encoded = cls(
            attr_keys=[],
            attr_types=[],
            attr_pos=[],
            attr_ints=[],
            attr_floats=[],
            attr_strs=[],
        )
        for k, v in attrs.items():
            encoded.attr_keys.append(k)
            if isinstance(v, int):
                # bool values also take this branch because bool is a subclass of int
                start_pos = len(encoded.attr_ints)
                encoded.attr_ints.append(v)
                encoded.attr_pos.append((start_pos, start_pos + 1))
                encoded.attr_types.append(_INT_TYPE)
            elif isinstance(v, float):
                start_pos = len(encoded.attr_floats)
                encoded.attr_floats.append(v)
                encoded.attr_pos.append((start_pos, start_pos + 1))
                encoded.attr_types.append(_FLOAT_TYPE)
            elif isinstance(v, str):
                start_pos = len(encoded.attr_strs)
                encoded.attr_strs.append(v)
                encoded.attr_pos.append((start_pos, start_pos + 1))
                encoded.attr_types.append(_STRING_TYPE)
            elif isinstance(v, Sequence):
                if len(v) == 0:
                    raise ValueError(f"Empty sequence for attribute {k}")
                if any(isinstance(elem, float) for elem in v):
                    # A sequence containing any float is coerced to a float sequence
                    start_pos = len(encoded.attr_floats)
                    encoded.attr_floats.extend([float(elem) for elem in v])
                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
                    encoded.attr_types.append(_FLOAT_SEQ_TYPE)
                elif isinstance(v[0], int):
                    start_pos = len(encoded.attr_ints)
                    encoded.attr_ints.extend([int(elem) for elem in v])
                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
                    encoded.attr_types.append(_INT_SEQ_TYPE)
                elif isinstance(v[0], str):
                    start_pos = len(encoded.attr_strs)
                    encoded.attr_strs.extend([str(elem) for elem in v])
                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
                    encoded.attr_types.append(_STRING_SEQ_TYPE)
                else:
                    raise ValueError(f"Unsupported sequence type for attribute {k}")
            else:
                raise ValueError(f"Unsupported attribute type for {k}: {type(v)}")
        assert len(encoded.attr_keys) == len(encoded.attr_types), (
            f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}"
        )
        assert len(encoded.attr_keys) == len(encoded.attr_pos), (
            f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}"
        )
        return encoded

    def to_dict(
        self,
    ) -> dict[
        str,
        Union[
            int,
            float,
            str,
            list[int],
            list[float],
            list[str],
        ],
    ]:
        """Convert the encoded attributes back to a dictionary for creating an ONNX node."""
        attrs: dict[
            str,
            Union[
                int,
                float,
                str,
                list[int],
                list[float],
                list[str],
            ],
        ] = {}
        for i, key in enumerate(self.attr_keys):
            attr_type = self.attr_types[i]
            if attr_type == _INT_TYPE:
                attrs[key] = self.attr_ints[self.attr_pos[i][0]]
            elif attr_type == _FLOAT_TYPE:
                attrs[key] = self.attr_floats[self.attr_pos[i][0]]
            elif attr_type == _STRING_TYPE:
                attrs[key] = self.attr_strs[self.attr_pos[i][0]]
            elif attr_type == _FLOAT_SEQ_TYPE:
                attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]]
            elif attr_type == _INT_SEQ_TYPE:
                attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]]
            elif attr_type == _STRING_SEQ_TYPE:
                attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]]
            else:
                raise ValueError(f"Unsupported attribute type: {attr_type}")
        return attrs
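

# Example round trip (illustrative only; the attribute names and values below
# are made up for the sketch and are not executed on import):
#
#     encoded = EncodedAttrs.from_dict(
#         {"axis": 1, "alpha": 0.5, "mode": "linear", "pads": [0, 1]}
#     )
#     # Scalars occupy a single slot; sequences record a (start, end) range.
#     assert encoded.attr_types == ["i", "f", "s", "is"]
#     assert encoded.to_dict() == {
#         "axis": 1,
#         "alpha": 0.5,
#         "mode": "linear",
#         "pads": [0, 1],
#     }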


@torch.library.custom_op(
    "onnx_symbolic::_symbolic",
    mutates_args=(),
    schema=(
        "(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
        " SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
        " str[] metadata_props_values, str domain='', int? version=None"
        ") -> Tensor"
    ),
)
def _symbolic(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtype: int,
    *,
    shape: Sequence[Union[int, torch.SymInt]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )


@_symbolic.register_fake
def _(
    inputs: Sequence[torch.Tensor],
    op_type: str,
    onnx_dtype: int,
    *,
    shape: Sequence[Union[int, torch.SymInt]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
    # out how it can handle empty shapes
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )
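

# Example call (illustrative; the "CustomFoo" op name, "custom.domain", and all
# values are hypothetical). ONNX dtype 1 is TensorProto.FLOAT, so this yields a
# float32 placeholder of the requested shape, regardless of what the op would
# really compute:
#
#     encoded = EncodedAttrs.from_dict({"alpha": 2.0})
#     out = torch.ops.onnx_symbolic._symbolic(
#         [torch.randn(2, 3)],
#         "CustomFoo",
#         1,  # ONNX TensorProto.FLOAT
#         shape=[2, 3],
#         attr_keys=encoded.attr_keys,
#         attr_types=encoded.attr_types,
#         attr_pos=encoded.attr_pos,
#         attr_ints=encoded.attr_ints,
#         attr_floats=encoded.attr_floats,
#         attr_strs=encoded.attr_strs,
#         domain="custom.domain",
#     )
#     assert out.shape == (2, 3) and out.dtype == torch.float32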


@torch.library.custom_op(
    "onnx_symbolic::_symbolic_multi_out",
    mutates_args=(),
    schema=(
        "(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
        " SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
        " str[] metadata_props_values, str domain='', int? version=None"
        ") -> Tensor[]"
    ),
)
def _symbolic_multi_out(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtypes: Sequence[int],
    *,
    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> list[torch.Tensor]:
    outputs = []
    torch._check(
        len(shapes) == len(onnx_dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
    return outputs


@_symbolic_multi_out.register_fake
def _(
    inputs: Sequence[torch.Tensor],
    op_type: str,
    onnx_dtypes: Sequence[int],
    *,
    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> list[torch.Tensor]:
    outputs = []
    torch._check(
        len(shapes) == len(onnx_dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
        # out how it can handle empty shapes
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
    return outputs
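

# Example multi-output call (illustrative; the TopK-style framing, shapes, and
# values are hypothetical). ONNX dtypes 1 and 7 are TensorProto.FLOAT and
# TensorProto.INT64, so this yields one float32 and one int64 placeholder:
#
#     encoded = EncodedAttrs.from_dict({"axis": -1})
#     values, indices = torch.ops.onnx_symbolic._symbolic_multi_out(
#         [torch.randn(4, 8)],
#         "TopK",
#         [1, 7],  # ONNX TensorProto.FLOAT, TensorProto.INT64
#         shapes=[[4, 3], [4, 3]],
#         attr_keys=encoded.attr_keys,
#         attr_types=encoded.attr_types,
#         attr_pos=encoded.attr_pos,
#         attr_ints=encoded.attr_ints,
#         attr_floats=encoded.attr_floats,
#         attr_strs=encoded.attr_strs,
#     )
#     assert values.dtype == torch.float32 and indices.dtype == torch.int64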