mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[ONNX] Supporting different opset versions for torchlib registry (#149901)
- Allows opset_version to determine which onnx decomposition to choose - Adds a cleanup function to modify the registry after it is built Pull Request resolved: https://github.com/pytorch/pytorch/pull/149901 Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							97a5e5c6b3
						
					
				
				
					commit
					1a56609e75
				
			@ -246,6 +246,31 @@ class TestExportAPIDynamo(common_utils.TestCase):
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_upgraded_torchlib_impl(self):
 | 
			
		||||
        class GeluModel(torch.nn.Module):
 | 
			
		||||
            def forward(self, input):
 | 
			
		||||
                # Use GELU activation function
 | 
			
		||||
                return torch.nn.functional.gelu(input, approximate="tanh")
 | 
			
		||||
 | 
			
		||||
        input = torch.randn(1, 3, 4, 4)
 | 
			
		||||
        onnx_program_op18 = torch.onnx.export(
 | 
			
		||||
            GeluModel(),
 | 
			
		||||
            input,
 | 
			
		||||
            dynamo=True,
 | 
			
		||||
        )
 | 
			
		||||
        all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
 | 
			
		||||
        self.assertIn("Tanh", all_nodes_op18)
 | 
			
		||||
        self.assertNotIn("Gelu", all_nodes_op18)
 | 
			
		||||
 | 
			
		||||
        onnx_program_op20 = torch.onnx.export(
 | 
			
		||||
            GeluModel(),
 | 
			
		||||
            input,
 | 
			
		||||
            opset_version=20,
 | 
			
		||||
            dynamo=True,
 | 
			
		||||
        )
 | 
			
		||||
        all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
 | 
			
		||||
        self.assertIn("Gelu", all_nodes_op20)
 | 
			
		||||
 | 
			
		||||
    def test_refine_dynamic_shapes_with_onnx_export(self):
 | 
			
		||||
        # NOTE: From test/export/test_export.py
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -52,6 +52,7 @@ FLOAT_TYPES = (
 | 
			
		||||
    torch.float64,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TEST_OPSET_VERSION = 18
 | 
			
		||||
IS_MACOS = sys.platform.startswith("darwin")
 | 
			
		||||
IS_WINDOWS = os.name == "nt"
 | 
			
		||||
@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -
 | 
			
		||||
def graph_executor(
 | 
			
		||||
    test_name: str,
 | 
			
		||||
    outputs: Sequence[Any],
 | 
			
		||||
    opset_version: int = TEST_OPSET_VERSION,
 | 
			
		||||
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
 | 
			
		||||
    """Eagerly executes a function."""
 | 
			
		||||
 | 
			
		||||
@ -500,10 +502,10 @@ def graph_executor(
 | 
			
		||||
            (),
 | 
			
		||||
            (),
 | 
			
		||||
            nodes=(),
 | 
			
		||||
            opset_imports={"": 18, "pkg.torch.onnx": 1},
 | 
			
		||||
            opset_imports={"": opset_version, "pkg.torch.onnx": 1},
 | 
			
		||||
            name="main_graph",
 | 
			
		||||
        )
 | 
			
		||||
        opset = onnxscript.opset18
 | 
			
		||||
        opset = onnxscript.values.Opset("", opset_version)
 | 
			
		||||
        tracer = _building.OpRecorder(opset, {})
 | 
			
		||||
        ort_inputs = {}
 | 
			
		||||
        onnxscript_args: list[Any] = []
 | 
			
		||||
@ -590,7 +592,7 @@ def graph_executor(
 | 
			
		||||
                proto = onnxscript_function.to_function_proto()
 | 
			
		||||
                ir_function = ir.serde.deserialize_function(proto)
 | 
			
		||||
            onnx_model.functions[identifier] = ir_function
 | 
			
		||||
        _ir_passes.add_torchlib_common_imports(onnx_model)
 | 
			
		||||
        _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
 | 
			
		||||
        _ir_passes.add_opset_imports(onnx_model)
 | 
			
		||||
        # Make sure the model is valid
 | 
			
		||||
        model_proto = ir.to_proto(onnx_model)
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,7 @@ import numpy as np
 | 
			
		||||
import ops_test_common
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
 | 
			
		||||
from torch.testing._internal import common_methods_invocations
 | 
			
		||||
from torch.testing._internal.opinfo import definitions as opinfo_definitions
 | 
			
		||||
 | 
			
		||||
@ -78,6 +78,12 @@ class TorchLibOpInfo:
 | 
			
		||||
    compare_shape_only_for_output: tuple[int, ...] = ()
 | 
			
		||||
    # Whether the function is designed for complex inputs
 | 
			
		||||
    complex: bool = False
 | 
			
		||||
    # The ONNX opset version in which the function was introduced.
 | 
			
		||||
    # Its specifies the minimum ONNX opset version required to use the function.
 | 
			
		||||
    # It ensures that the function is only used when the target ONNX opset version
 | 
			
		||||
    # is compatible. For example, if `opset_introduced=20`, the function will only
 | 
			
		||||
    # be used when exporting to ONNX models targeting opset version 20 or higher.
 | 
			
		||||
    opset_introduced: int = 18
 | 
			
		||||
    # The acceptable tolerance of the inference result difference between PyTorch and ORT.
 | 
			
		||||
    # Format: {dtype: (rtol, atol)}.
 | 
			
		||||
    # For example: {torch.float16: (1e-3, 1e-3)}
 | 
			
		||||
@ -447,8 +453,10 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
 | 
			
		||||
    TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
 | 
			
		||||
    TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
 | 
			
		||||
    TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
 | 
			
		||||
    TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
 | 
			
		||||
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 | 
			
		||||
ops_test_common.duplicate_opinfo(
 | 
			
		||||
@ -500,6 +508,7 @@ ops_test_common.duplicate_opinfo(
 | 
			
		||||
        "nn.functional.replication_pad3d",
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
 | 
			
		||||
ops_test_common.duplicate_opinfo(
 | 
			
		||||
    OPS_DB,
 | 
			
		||||
    "nn.functional.scaled_dot_product_attention",
 | 
			
		||||
 | 
			
		||||
@ -220,7 +220,9 @@ def run_test_output_match(
 | 
			
		||||
 | 
			
		||||
                test_name = test_suite.id()
 | 
			
		||||
                function_output, model_proto = function_executor(
 | 
			
		||||
                    test_name, reference_torch_outputs
 | 
			
		||||
                    test_name,
 | 
			
		||||
                    reference_torch_outputs,
 | 
			
		||||
                    opset_version=torchlib_op_info.opset_introduced,
 | 
			
		||||
                )(onnx_function, input_onnx, kwargs_onnx)
 | 
			
		||||
                # Finally we re-flatten everything
 | 
			
		||||
                # TODO: add pytree structure comparison.
 | 
			
		||||
 | 
			
		||||
@ -50,7 +50,7 @@ def export_compat(
 | 
			
		||||
    verbose: bool | None = None,
 | 
			
		||||
    input_names: Sequence[str] | None = None,
 | 
			
		||||
    output_names: Sequence[str] | None = None,
 | 
			
		||||
    opset_version: int | None = None,
 | 
			
		||||
    opset_version: int | None = _constants.TORCHLIB_OPSET,
 | 
			
		||||
    custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
 | 
			
		||||
    | None = None,
 | 
			
		||||
    dynamic_axes: Mapping[str, Mapping[int, str]]
 | 
			
		||||
@ -105,8 +105,7 @@ def export_compat(
 | 
			
		||||
    dynamic_shapes_with_export_dim, need_axis_mapping = (
 | 
			
		||||
        _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    registry = _registration.ONNXRegistry.from_torchlib()
 | 
			
		||||
    registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version)
 | 
			
		||||
    if custom_translation_table is not None:
 | 
			
		||||
        for torch_op, onnx_ops in custom_translation_table.items():
 | 
			
		||||
            # TODO(justinchuby): Support complex inputs with annotations
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
 | 
			
		||||
            value.shape = ir.Shape(new_shape)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def add_torchlib_common_imports(model: ir.Model) -> None:
 | 
			
		||||
def add_torchlib_common_imports(
 | 
			
		||||
    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Hack to add torchlib common imports to the model."""
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
 | 
			
		||||
 | 
			
		||||
        model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
 | 
			
		||||
        rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
 | 
			
		||||
        rank_func.opset_imports[""] = opset_version
 | 
			
		||||
        is_scalar_func = ir.serde.deserialize_function(
 | 
			
		||||
            common_ops.IsScalar.to_function_proto()
 | 
			
		||||
        )
 | 
			
		||||
        is_scalar_func.opset_imports[""] = opset_version
 | 
			
		||||
        model.functions[rank_func.identifier()] = rank_func
 | 
			
		||||
        model.functions[is_scalar_func.identifier()] = is_scalar_func
 | 
			
		||||
    except Exception:
 | 
			
		||||
 | 
			
		||||
@ -42,6 +42,9 @@ class OnnxDecompMeta:
 | 
			
		||||
    signature: The ONNX signature of the function. When None, the signature is inferred.
 | 
			
		||||
    is_custom: Whether the function is a custom function.
 | 
			
		||||
    is_complex: Whether the function is a function that handles complex valued inputs.
 | 
			
		||||
    opset_introduced:
 | 
			
		||||
        The ONNX opset version in which the function was introduced.
 | 
			
		||||
        Its specifies the minimum ONNX opset version required to use the function.
 | 
			
		||||
    device: The device the function is registered to. If None, it is registered to all devices.
 | 
			
		||||
    skip_signature_inference: Whether to skip signature inference for the function.
 | 
			
		||||
    """
 | 
			
		||||
@ -51,6 +54,7 @@ class OnnxDecompMeta:
 | 
			
		||||
    signature: _schemas.OpSignature | None
 | 
			
		||||
    is_custom: bool = False
 | 
			
		||||
    is_complex: bool = False
 | 
			
		||||
    opset_introduced: int = 18
 | 
			
		||||
    device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
 | 
			
		||||
    skip_signature_inference: bool = False
 | 
			
		||||
 | 
			
		||||
@ -150,13 +154,14 @@ class ONNXRegistry:
 | 
			
		||||
        return self._opset_version
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_torchlib(cls) -> ONNXRegistry:
 | 
			
		||||
    def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry:
 | 
			
		||||
        """Populates the registry with ATen functions from torchlib.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            torchlib_registry: The torchlib registry to use for populating the registry.
 | 
			
		||||
        """
 | 
			
		||||
        registry = cls()
 | 
			
		||||
        registry._opset_version = opset_version
 | 
			
		||||
        for meta in _torchlib_registry.get_torchlib_ops():
 | 
			
		||||
            registry._register(meta.fx_target, meta)
 | 
			
		||||
 | 
			
		||||
@ -185,6 +190,7 @@ class ONNXRegistry:
 | 
			
		||||
                logger.exception("Failed to register '%s'. Skipped", qualified_name)
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
        registry._cleanup_registry_based_on_opset_version()
 | 
			
		||||
        return registry
 | 
			
		||||
 | 
			
		||||
    def _register(
 | 
			
		||||
@ -274,5 +280,24 @@ class ONNXRegistry:
 | 
			
		||||
        """
 | 
			
		||||
        return bool(self.get_decomps(target))
 | 
			
		||||
 | 
			
		||||
    def _cleanup_registry_based_on_opset_version(self) -> None:
 | 
			
		||||
        """Pick the implementation with the highest opset version valid until the current opset version."""
 | 
			
		||||
        cleaned_functions = {}
 | 
			
		||||
        for target_or_name, decomps in self.functions.items():
 | 
			
		||||
            # Filter decompositions to only include those with opset_introduced <= opset_version
 | 
			
		||||
            decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]
 | 
			
		||||
 | 
			
		||||
            # Keep only the decomposition with the highest opset_introduced
 | 
			
		||||
            if decomps:
 | 
			
		||||
                # Find the maximum opset_introduced
 | 
			
		||||
                max_opset = max(d.opset_introduced for d in decomps)
 | 
			
		||||
 | 
			
		||||
                # Keep all decompositions with the maximum opset_introduced
 | 
			
		||||
                cleaned_functions[target_or_name] = [
 | 
			
		||||
                    d for d in decomps if d.opset_introduced == max_opset
 | 
			
		||||
                ]
 | 
			
		||||
 | 
			
		||||
        self.functions = cleaned_functions
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return f"{self.__class__.__name__}(functions={self.functions})"
 | 
			
		||||
 | 
			
		||||
@ -30,6 +30,7 @@ def onnx_impl(
 | 
			
		||||
    *,
 | 
			
		||||
    trace_only: bool = False,
 | 
			
		||||
    complex: bool = False,
 | 
			
		||||
    opset_introduced: int = 18,
 | 
			
		||||
    no_compile: bool = False,
 | 
			
		||||
    private: bool = False,
 | 
			
		||||
) -> Callable[[_T], _T]:
 | 
			
		||||
@ -74,6 +75,7 @@ def onnx_impl(
 | 
			
		||||
                        fx_target=t,
 | 
			
		||||
                        signature=None,
 | 
			
		||||
                        is_complex=complex,
 | 
			
		||||
                        opset_introduced=opset_introduced,
 | 
			
		||||
                        skip_signature_inference=no_compile,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ["core", "hop", "symbolic"]
 | 
			
		||||
__all__ = ["core", "hop", "nn", "symbolic"]
 | 
			
		||||
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										26
									
								
								torch/onnx/_internal/exporter/_torchlib/ops/nn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								torch/onnx/_internal/exporter/_torchlib/ops/nn.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
			
		||||
"""torch.ops.aten operators under the `core` module."""
 | 
			
		||||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
 | 
			
		||||
# ruff: noqa: TCH001,TCH002
 | 
			
		||||
# flake8: noqa
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
from onnxscript.onnx_opset import opset20 as op20
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal
 | 
			
		||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
aten = torch.ops.aten
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
 | 
			
		||||
def aten_gelu_opset20(
 | 
			
		||||
    self: TReal,
 | 
			
		||||
    approximate: str = "none",
 | 
			
		||||
) -> TReal:
 | 
			
		||||
    """gelu(Tensor self, *, bool approximate=False) -> Tensor"""
 | 
			
		||||
    return op20.Gelu(self, approximate=approximate)
 | 
			
		||||
		Reference in New Issue
	
	Block a user