[ONNX] Supporting different opset versions for torchlib registry (#149901)

- Allows opset_version to determine which onnx decomposition to choose
- Adds a cleanup function to modify the registry after it is built

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149901
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
This commit is contained in:
shubhambhokare1
2025-04-09 16:03:46 +00:00
committed by PyTorch MergeBot
parent 97a5e5c6b3
commit 1a56609e75
10 changed files with 106 additions and 12 deletions

View File

@ -246,6 +246,31 @@ class TestExportAPIDynamo(common_utils.TestCase):
)
)
def test_upgraded_torchlib_impl(self):
class GeluModel(torch.nn.Module):
def forward(self, input):
# Use GELU activation function
return torch.nn.functional.gelu(input, approximate="tanh")
input = torch.randn(1, 3, 4, 4)
onnx_program_op18 = torch.onnx.export(
GeluModel(),
input,
dynamo=True,
)
all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
self.assertIn("Tanh", all_nodes_op18)
self.assertNotIn("Gelu", all_nodes_op18)
onnx_program_op20 = torch.onnx.export(
GeluModel(),
input,
opset_version=20,
dynamo=True,
)
all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
self.assertIn("Gelu", all_nodes_op20)
def test_refine_dynamic_shapes_with_onnx_export(self):
# NOTE: From test/export/test_export.py

View File

@ -52,6 +52,7 @@ FLOAT_TYPES = (
torch.float64,
)
TEST_OPSET_VERSION = 18
IS_MACOS = sys.platform.startswith("darwin")
IS_WINDOWS = os.name == "nt"
@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -
def graph_executor(
test_name: str,
outputs: Sequence[Any],
opset_version: int = TEST_OPSET_VERSION,
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
"""Eagerly executes a function."""
@ -500,10 +502,10 @@ def graph_executor(
(),
(),
nodes=(),
opset_imports={"": 18, "pkg.torch.onnx": 1},
opset_imports={"": opset_version, "pkg.torch.onnx": 1},
name="main_graph",
)
opset = onnxscript.opset18
opset = onnxscript.values.Opset("", opset_version)
tracer = _building.OpRecorder(opset, {})
ort_inputs = {}
onnxscript_args: list[Any] = []
@ -590,7 +592,7 @@ def graph_executor(
proto = onnxscript_function.to_function_proto()
ir_function = ir.serde.deserialize_function(proto)
onnx_model.functions[identifier] = ir_function
_ir_passes.add_torchlib_common_imports(onnx_model)
_ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
_ir_passes.add_opset_imports(onnx_model)
# Make sure the model is valid
model_proto = ir.to_proto(onnx_model)

View File

@ -46,7 +46,7 @@ import numpy as np
import ops_test_common
import torch
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
from torch.testing._internal import common_methods_invocations
from torch.testing._internal.opinfo import definitions as opinfo_definitions
@ -78,6 +78,12 @@ class TorchLibOpInfo:
compare_shape_only_for_output: tuple[int, ...] = ()
# Whether the function is designed for complex inputs
complex: bool = False
# The ONNX opset version in which the function was introduced.
# Its specifies the minimum ONNX opset version required to use the function.
# It ensures that the function is only used when the target ONNX opset version
# is compatible. For example, if `opset_introduced=20`, the function will only
# be used when exporting to ONNX models targeting opset version 20 or higher.
opset_introduced: int = 18
# The acceptable tolerance of the inference result difference between PyTorch and ORT.
# Format: {dtype: (rtol, atol)}.
# For example: {torch.float16: (1e-3, 1e-3)}
@ -447,8 +453,10 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
)
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
ops_test_common.duplicate_opinfo(
@ -500,6 +508,7 @@ ops_test_common.duplicate_opinfo(
"nn.functional.replication_pad3d",
),
)
ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.scaled_dot_product_attention",

View File

@ -220,7 +220,9 @@ def run_test_output_match(
test_name = test_suite.id()
function_output, model_proto = function_executor(
test_name, reference_torch_outputs
test_name,
reference_torch_outputs,
opset_version=torchlib_op_info.opset_introduced,
)(onnx_function, input_onnx, kwargs_onnx)
# Finally we re-flatten everything
# TODO: add pytree structure comparison.

View File

@ -50,7 +50,7 @@ def export_compat(
verbose: bool | None = None,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
opset_version: int | None = None,
opset_version: int | None = _constants.TORCHLIB_OPSET,
custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
| None = None,
dynamic_axes: Mapping[str, Mapping[int, str]]
@ -105,8 +105,7 @@ def export_compat(
dynamic_shapes_with_export_dim, need_axis_mapping = (
_dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
)
registry = _registration.ONNXRegistry.from_torchlib()
registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version)
if custom_translation_table is not None:
for torch_op, onnx_ops in custom_translation_table.items():
# TODO(justinchuby): Support complex inputs with annotations

View File

@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
value.shape = ir.Shape(new_shape)
def add_torchlib_common_imports(model: ir.Model) -> None:
def add_torchlib_common_imports(
model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
) -> None:
"""Hack to add torchlib common imports to the model."""
try:
@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
rank_func.opset_imports[""] = opset_version
is_scalar_func = ir.serde.deserialize_function(
common_ops.IsScalar.to_function_proto()
)
is_scalar_func.opset_imports[""] = opset_version
model.functions[rank_func.identifier()] = rank_func
model.functions[is_scalar_func.identifier()] = is_scalar_func
except Exception:

View File

@ -42,6 +42,9 @@ class OnnxDecompMeta:
signature: The ONNX signature of the function. When None, the signature is inferred.
is_custom: Whether the function is a custom function.
is_complex: Whether the function is a function that handles complex valued inputs.
opset_introduced:
The ONNX opset version in which the function was introduced.
Its specifies the minimum ONNX opset version required to use the function.
device: The device the function is registered to. If None, it is registered to all devices.
skip_signature_inference: Whether to skip signature inference for the function.
"""
@ -51,6 +54,7 @@ class OnnxDecompMeta:
signature: _schemas.OpSignature | None
is_custom: bool = False
is_complex: bool = False
opset_introduced: int = 18
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
skip_signature_inference: bool = False
@ -150,13 +154,14 @@ class ONNXRegistry:
return self._opset_version
@classmethod
def from_torchlib(cls) -> ONNXRegistry:
def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry:
"""Populates the registry with ATen functions from torchlib.
Args:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
registry = cls()
registry._opset_version = opset_version
for meta in _torchlib_registry.get_torchlib_ops():
registry._register(meta.fx_target, meta)
@ -185,6 +190,7 @@ class ONNXRegistry:
logger.exception("Failed to register '%s'. Skipped", qualified_name)
continue
registry._cleanup_registry_based_on_opset_version()
return registry
def _register(
@ -274,5 +280,24 @@ class ONNXRegistry:
"""
return bool(self.get_decomps(target))
def _cleanup_registry_based_on_opset_version(self) -> None:
"""Pick the implementation with the highest opset version valid until the current opset version."""
cleaned_functions = {}
for target_or_name, decomps in self.functions.items():
# Filter decompositions to only include those with opset_introduced <= opset_version
decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]
# Keep only the decomposition with the highest opset_introduced
if decomps:
# Find the maximum opset_introduced
max_opset = max(d.opset_introduced for d in decomps)
# Keep all decompositions with the maximum opset_introduced
cleaned_functions[target_or_name] = [
d for d in decomps if d.opset_introduced == max_opset
]
self.functions = cleaned_functions
def __repr__(self) -> str:
return f"{self.__class__.__name__}(functions={self.functions})"

View File

@ -30,6 +30,7 @@ def onnx_impl(
*,
trace_only: bool = False,
complex: bool = False,
opset_introduced: int = 18,
no_compile: bool = False,
private: bool = False,
) -> Callable[[_T], _T]:
@ -74,6 +75,7 @@ def onnx_impl(
fx_target=t,
signature=None,
is_complex=complex,
opset_introduced=opset_introduced,
skip_signature_inference=no_compile,
)
)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
__all__ = ["core", "hop", "symbolic"]
__all__ = ["core", "hop", "nn", "symbolic"]
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic

View File

@ -0,0 +1,26 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa
from __future__ import annotations
import math
from onnxscript.onnx_opset import opset20 as op20
import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
aten = torch.ops.aten
@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
def aten_gelu_opset20(
self: TReal,
approximate: str = "none",
) -> TReal:
"""gelu(Tensor self, *, bool approximate=False) -> Tensor"""
return op20.Gelu(self, approximate=approximate)