mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Supporting different opset versions for torchlib registry (#149901)
- Allows opset_version to determine which onnx decomposition to choose - Adds a cleanup function to modify the registry after it is built Pull Request resolved: https://github.com/pytorch/pytorch/pull/149901 Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
97a5e5c6b3
commit
1a56609e75
@ -246,6 +246,31 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_upgraded_torchlib_impl(self):
|
||||
class GeluModel(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
# Use GELU activation function
|
||||
return torch.nn.functional.gelu(input, approximate="tanh")
|
||||
|
||||
input = torch.randn(1, 3, 4, 4)
|
||||
onnx_program_op18 = torch.onnx.export(
|
||||
GeluModel(),
|
||||
input,
|
||||
dynamo=True,
|
||||
)
|
||||
all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
|
||||
self.assertIn("Tanh", all_nodes_op18)
|
||||
self.assertNotIn("Gelu", all_nodes_op18)
|
||||
|
||||
onnx_program_op20 = torch.onnx.export(
|
||||
GeluModel(),
|
||||
input,
|
||||
opset_version=20,
|
||||
dynamo=True,
|
||||
)
|
||||
all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
|
||||
self.assertIn("Gelu", all_nodes_op20)
|
||||
|
||||
def test_refine_dynamic_shapes_with_onnx_export(self):
|
||||
# NOTE: From test/export/test_export.py
|
||||
|
||||
|
||||
@ -52,6 +52,7 @@ FLOAT_TYPES = (
|
||||
torch.float64,
|
||||
)
|
||||
|
||||
|
||||
TEST_OPSET_VERSION = 18
|
||||
IS_MACOS = sys.platform.startswith("darwin")
|
||||
IS_WINDOWS = os.name == "nt"
|
||||
@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -
|
||||
def graph_executor(
|
||||
test_name: str,
|
||||
outputs: Sequence[Any],
|
||||
opset_version: int = TEST_OPSET_VERSION,
|
||||
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
|
||||
"""Eagerly executes a function."""
|
||||
|
||||
@ -500,10 +502,10 @@ def graph_executor(
|
||||
(),
|
||||
(),
|
||||
nodes=(),
|
||||
opset_imports={"": 18, "pkg.torch.onnx": 1},
|
||||
opset_imports={"": opset_version, "pkg.torch.onnx": 1},
|
||||
name="main_graph",
|
||||
)
|
||||
opset = onnxscript.opset18
|
||||
opset = onnxscript.values.Opset("", opset_version)
|
||||
tracer = _building.OpRecorder(opset, {})
|
||||
ort_inputs = {}
|
||||
onnxscript_args: list[Any] = []
|
||||
@ -590,7 +592,7 @@ def graph_executor(
|
||||
proto = onnxscript_function.to_function_proto()
|
||||
ir_function = ir.serde.deserialize_function(proto)
|
||||
onnx_model.functions[identifier] = ir_function
|
||||
_ir_passes.add_torchlib_common_imports(onnx_model)
|
||||
_ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
|
||||
_ir_passes.add_opset_imports(onnx_model)
|
||||
# Make sure the model is valid
|
||||
model_proto = ir.to_proto(onnx_model)
|
||||
|
||||
@ -46,7 +46,7 @@ import numpy as np
|
||||
import ops_test_common
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
|
||||
from torch.testing._internal import common_methods_invocations
|
||||
from torch.testing._internal.opinfo import definitions as opinfo_definitions
|
||||
|
||||
@ -78,6 +78,12 @@ class TorchLibOpInfo:
|
||||
compare_shape_only_for_output: tuple[int, ...] = ()
|
||||
# Whether the function is designed for complex inputs
|
||||
complex: bool = False
|
||||
# The ONNX opset version in which the function was introduced.
|
||||
# Its specifies the minimum ONNX opset version required to use the function.
|
||||
# It ensures that the function is only used when the target ONNX opset version
|
||||
# is compatible. For example, if `opset_introduced=20`, the function will only
|
||||
# be used when exporting to ONNX models targeting opset version 20 or higher.
|
||||
opset_introduced: int = 18
|
||||
# The acceptable tolerance of the inference result difference between PyTorch and ORT.
|
||||
# Format: {dtype: (rtol, atol)}.
|
||||
# For example: {torch.float16: (1e-3, 1e-3)}
|
||||
@ -447,8 +453,10 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
|
||||
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
|
||||
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
|
||||
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
|
||||
TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
|
||||
)
|
||||
|
||||
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
|
||||
ops_test_common.duplicate_opinfo(
|
||||
@ -500,6 +508,7 @@ ops_test_common.duplicate_opinfo(
|
||||
"nn.functional.replication_pad3d",
|
||||
),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"nn.functional.scaled_dot_product_attention",
|
||||
|
||||
@ -220,7 +220,9 @@ def run_test_output_match(
|
||||
|
||||
test_name = test_suite.id()
|
||||
function_output, model_proto = function_executor(
|
||||
test_name, reference_torch_outputs
|
||||
test_name,
|
||||
reference_torch_outputs,
|
||||
opset_version=torchlib_op_info.opset_introduced,
|
||||
)(onnx_function, input_onnx, kwargs_onnx)
|
||||
# Finally we re-flatten everything
|
||||
# TODO: add pytree structure comparison.
|
||||
|
||||
@ -50,7 +50,7 @@ def export_compat(
|
||||
verbose: bool | None = None,
|
||||
input_names: Sequence[str] | None = None,
|
||||
output_names: Sequence[str] | None = None,
|
||||
opset_version: int | None = None,
|
||||
opset_version: int | None = _constants.TORCHLIB_OPSET,
|
||||
custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
|
||||
| None = None,
|
||||
dynamic_axes: Mapping[str, Mapping[int, str]]
|
||||
@ -105,8 +105,7 @@ def export_compat(
|
||||
dynamic_shapes_with_export_dim, need_axis_mapping = (
|
||||
_dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
|
||||
)
|
||||
|
||||
registry = _registration.ONNXRegistry.from_torchlib()
|
||||
registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version)
|
||||
if custom_translation_table is not None:
|
||||
for torch_op, onnx_ops in custom_translation_table.items():
|
||||
# TODO(justinchuby): Support complex inputs with annotations
|
||||
|
||||
@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
|
||||
value.shape = ir.Shape(new_shape)
|
||||
|
||||
|
||||
def add_torchlib_common_imports(model: ir.Model) -> None:
|
||||
def add_torchlib_common_imports(
|
||||
model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
|
||||
) -> None:
|
||||
"""Hack to add torchlib common imports to the model."""
|
||||
|
||||
try:
|
||||
@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
|
||||
|
||||
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
|
||||
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
|
||||
rank_func.opset_imports[""] = opset_version
|
||||
is_scalar_func = ir.serde.deserialize_function(
|
||||
common_ops.IsScalar.to_function_proto()
|
||||
)
|
||||
is_scalar_func.opset_imports[""] = opset_version
|
||||
model.functions[rank_func.identifier()] = rank_func
|
||||
model.functions[is_scalar_func.identifier()] = is_scalar_func
|
||||
except Exception:
|
||||
|
||||
@ -42,6 +42,9 @@ class OnnxDecompMeta:
|
||||
signature: The ONNX signature of the function. When None, the signature is inferred.
|
||||
is_custom: Whether the function is a custom function.
|
||||
is_complex: Whether the function is a function that handles complex valued inputs.
|
||||
opset_introduced:
|
||||
The ONNX opset version in which the function was introduced.
|
||||
Its specifies the minimum ONNX opset version required to use the function.
|
||||
device: The device the function is registered to. If None, it is registered to all devices.
|
||||
skip_signature_inference: Whether to skip signature inference for the function.
|
||||
"""
|
||||
@ -51,6 +54,7 @@ class OnnxDecompMeta:
|
||||
signature: _schemas.OpSignature | None
|
||||
is_custom: bool = False
|
||||
is_complex: bool = False
|
||||
opset_introduced: int = 18
|
||||
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
|
||||
skip_signature_inference: bool = False
|
||||
|
||||
@ -150,13 +154,14 @@ class ONNXRegistry:
|
||||
return self._opset_version
|
||||
|
||||
@classmethod
|
||||
def from_torchlib(cls) -> ONNXRegistry:
|
||||
def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry:
|
||||
"""Populates the registry with ATen functions from torchlib.
|
||||
|
||||
Args:
|
||||
torchlib_registry: The torchlib registry to use for populating the registry.
|
||||
"""
|
||||
registry = cls()
|
||||
registry._opset_version = opset_version
|
||||
for meta in _torchlib_registry.get_torchlib_ops():
|
||||
registry._register(meta.fx_target, meta)
|
||||
|
||||
@ -185,6 +190,7 @@ class ONNXRegistry:
|
||||
logger.exception("Failed to register '%s'. Skipped", qualified_name)
|
||||
continue
|
||||
|
||||
registry._cleanup_registry_based_on_opset_version()
|
||||
return registry
|
||||
|
||||
def _register(
|
||||
@ -274,5 +280,24 @@ class ONNXRegistry:
|
||||
"""
|
||||
return bool(self.get_decomps(target))
|
||||
|
||||
def _cleanup_registry_based_on_opset_version(self) -> None:
|
||||
"""Pick the implementation with the highest opset version valid until the current opset version."""
|
||||
cleaned_functions = {}
|
||||
for target_or_name, decomps in self.functions.items():
|
||||
# Filter decompositions to only include those with opset_introduced <= opset_version
|
||||
decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]
|
||||
|
||||
# Keep only the decomposition with the highest opset_introduced
|
||||
if decomps:
|
||||
# Find the maximum opset_introduced
|
||||
max_opset = max(d.opset_introduced for d in decomps)
|
||||
|
||||
# Keep all decompositions with the maximum opset_introduced
|
||||
cleaned_functions[target_or_name] = [
|
||||
d for d in decomps if d.opset_introduced == max_opset
|
||||
]
|
||||
|
||||
self.functions = cleaned_functions
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(functions={self.functions})"
|
||||
|
||||
@ -30,6 +30,7 @@ def onnx_impl(
|
||||
*,
|
||||
trace_only: bool = False,
|
||||
complex: bool = False,
|
||||
opset_introduced: int = 18,
|
||||
no_compile: bool = False,
|
||||
private: bool = False,
|
||||
) -> Callable[[_T], _T]:
|
||||
@ -74,6 +75,7 @@ def onnx_impl(
|
||||
fx_target=t,
|
||||
signature=None,
|
||||
is_complex=complex,
|
||||
opset_introduced=opset_introduced,
|
||||
skip_signature_inference=no_compile,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = ["core", "hop", "symbolic"]
|
||||
__all__ = ["core", "hop", "nn", "symbolic"]
|
||||
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic
|
||||
|
||||
26
torch/onnx/_internal/exporter/_torchlib/ops/nn.py
Normal file
26
torch/onnx/_internal/exporter/_torchlib/ops/nn.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""torch.ops.aten operators under the `core` module."""
|
||||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
|
||||
# ruff: noqa: TCH001,TCH002
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from onnxscript.onnx_opset import opset20 as op20
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal
|
||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
|
||||
def aten_gelu_opset20(
|
||||
self: TReal,
|
||||
approximate: str = "none",
|
||||
) -> TReal:
|
||||
"""gelu(Tensor self, *, bool approximate=False) -> Tensor"""
|
||||
return op20.Gelu(self, approximate=approximate)
|
||||
Reference in New Issue
Block a user