From fcbde24c1cb54f3e0417e123bdb9ae09da134c8d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 17 Oct 2025 03:25:31 +0000
Subject: [PATCH] [ONNX] Remove common imports from torchlib (#165156)

The Rank and IsScalar functions are no longer used in the torchlib.

Requires onnxscript v0.5.4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156
Approved by: https://github.com/Skylion007, https://github.com/cyyever
---
 .ci/docker/common/install_onnx.sh           |  2 +-
 test/onnx/torchlib/ops_test_common.py       |  1 -
 torch/onnx/_internal/exporter/_building.py  | 39 ---------------------
 torch/onnx/_internal/exporter/_core.py      |  3 --
 torch/onnx/_internal/exporter/_ir_passes.py | 22 ------------
 5 files changed, 1 insertion(+), 66 deletions(-)

diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh
index 183b5b65c90a..b0615b8a84c1 100755
--- a/.ci/docker/common/install_onnx.sh
+++ b/.ci/docker/common/install_onnx.sh
@@ -20,7 +20,7 @@ pip_install \
 
 pip_install coloredlogs packaging
 pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.3
+pip_install onnxscript==0.5.4
 
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py
index 72243faf3b50..d1206da0e07d 100644
--- a/test/onnx/torchlib/ops_test_common.py
+++ b/test/onnx/torchlib/ops_test_common.py
@@ -592,7 +592,6 @@ def graph_executor(
             proto = onnxscript_function.to_function_proto()
             ir_function = ir.serde.deserialize_function(proto)
             onnx_model.functions[identifier] = ir_function
-        _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
         _ir_passes.add_opset_imports(onnx_model)
         # Make sure the model is valid
         model_proto = ir.to_proto(onnx_model)
diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py
index dbe38f81680c..608591ca04c2 100644
--- a/torch/onnx/_internal/exporter/_building.py
+++ b/torch/onnx/_internal/exporter/_building.py
@@ -646,45 +646,6 @@ class OpRecorder(evaluator.Evaluator):
         kwargs: Mapping[str, AllowedArgType],
     ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
         try:
-            # TODO(justinchuby): Remove this once IsScalar and Rank are removed
-            # Special cases for handling IsScalar and Rank
-            if function.name == "IsScalar":
-                if len(args) != 1:
-                    raise TypeError(
-                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
-                    )
-                if isinstance(args[0], _tensors.SymbolicTensor):
-                    if args[0].rank is not None:
-                        return args[0].rank == 0
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                elif isinstance(args[0], Sequence):
-                    return False
-                else:
-                    # Python constants are scalars
-                    return True
-            if function.name == "Rank":
-                if len(args) != 1:
-                    raise TypeError(
-                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
-                    )
-                if isinstance(args[0], _tensors.SymbolicTensor):
-                    if args[0].rank is not None:
-                        return args[0].rank
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                elif isinstance(args[0], Sequence):
-                    if all(isinstance(arg, (int, float)) for arg in args[0]):
-                        return 1
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                else:
-                    # Python constants are scalars
-                    return 0
-
             # NOTE: signature should be written to function in the registration process
             if hasattr(function, "_pt_onnx_signature"):
                 op_signature = function._pt_onnx_signature  # type: ignore[attr-defined]
diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py
index 06b12d8b1931..5696273f7b66 100644
--- a/torch/onnx/_internal/exporter/_core.py
+++ b/torch/onnx/_internal/exporter/_core.py
@@ -1249,9 +1249,6 @@ def _exported_program_to_onnx_program(
 
     # TODO: Decide if we should keep mutated buffers as inputs/outputs
 
-    # TODO(justinchuby): Remove the hack
-    _ir_passes.add_torchlib_common_imports(model)
-
     # Collect and add opset imports to the model
     _ir_passes.add_opset_imports(model)
 
diff --git a/torch/onnx/_internal/exporter/_ir_passes.py b/torch/onnx/_internal/exporter/_ir_passes.py
index 8a715e245597..9391b642b009 100644
--- a/torch/onnx/_internal/exporter/_ir_passes.py
+++ b/torch/onnx/_internal/exporter/_ir_passes.py
@@ -90,28 +90,6 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
         value.shape = ir.Shape(new_shape)
 
 
-def add_torchlib_common_imports(
-    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
-) -> None:
-    """Hack to add torchlib common imports to the model."""
-
-    try:
-        # TODO(justinchuby): Remove this hack and improved onnxscript
-        from onnxscript.function_libs.torch_lib.ops import common as common_ops
-
-        model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
-        rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
-        rank_func.opset_imports[""] = opset_version
-        is_scalar_func = ir.serde.deserialize_function(
-            common_ops.IsScalar.to_function_proto()
-        )
-        is_scalar_func.opset_imports[""] = opset_version
-        model.functions[rank_func.identifier()] = rank_func
-        model.functions[is_scalar_func.identifier()] = is_scalar_func
-    except Exception:
-        logger.exception("Failed to add torchlib common imports to the model.")
-
-
 def _maybe_set_opset_version(
     opset_imports: dict[str, int], domain: str, version: int | None
 ) -> None:
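
With this patch, the exporter no longer injects the pkg.onnxscript.torch_lib.common domain or its Rank/IsScalar helper functions into exported models. A minimal sketch of how one might verify that on an exported model (not part of the patch; assumes a PyTorch build containing this change, onnxscript >= 0.5.4, and the dynamo-based exporter; the model and shapes are arbitrary illustrations):

# Sketch: confirm an exported model carries no torchlib common imports.
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1


onnx_program = torch.onnx.export(TinyModel(), (torch.randn(2, 3),), dynamo=True)
ir_model = onnx_program.model  # ONNX IR representation of the exported graph

# Neither the opset imports nor the registered functions should reference
# the pkg.onnxscript.torch_lib.common domain anymore.
assert "pkg.onnxscript.torch_lib.common" not in ir_model.opset_imports
assert all(
    domain != "pkg.onnxscript.torch_lib.common"
    for domain, _name, _overload in ir_model.functions
)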