[ONNX] Remove common imports from torchlib (#165156)

The Rank and IsScalar functions are no longer used in torchlib. This change requires onnxscript v0.5.4.
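
For reference, Rank and IsScalar come from onnxscript's torch_lib common function library, where they are defined in terms of standard ONNX operators. A minimal sketch (approximating, not quoting, the library's definitions) of what they compute:

from onnxscript import BOOL, FLOAT, INT64, script
from onnxscript import opset18 as op

@script()
def Rank(input: FLOAT[...]) -> INT64:
    # The rank of a tensor is its number of dimensions: Size(Shape(x)).
    return op.Size(op.Shape(input))

@script()
def IsScalar(input: FLOAT[...]) -> BOOL:
    # A scalar is a tensor of rank 0.
    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))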

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156
Approved by: https://github.com/Skylion007, https://github.com/cyyever
Author: Justin Chu
Date: 2025-10-17 03:25:31 +00:00
Committed by: PyTorch MergeBot
Parent: 861cdb887b
Commit: fcbde24c1c

5 changed files with 1 addition and 66 deletions

@@ -20,7 +20,7 @@ pip_install \
 pip_install coloredlogs packaging
 pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.3
+pip_install onnxscript==0.5.4
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

@@ -592,7 +592,6 @@ def graph_executor(
             proto = onnxscript_function.to_function_proto()
             ir_function = ir.serde.deserialize_function(proto)
             onnx_model.functions[identifier] = ir_function
-        _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
         _ir_passes.add_opset_imports(onnx_model)
         # Make sure the model is valid
         model_proto = ir.to_proto(onnx_model)

@@ -646,45 +646,6 @@ class OpRecorder(evaluator.Evaluator):
         kwargs: Mapping[str, AllowedArgType],
     ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
         try:
-            # TODO(justinchuby): Remove this once IsScalar and Rank are removed
-            # Special cases for handling IsScalar and Rank
-            if function.name == "IsScalar":
-                if len(args) != 1:
-                    raise TypeError(
-                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
-                    )
-                if isinstance(args[0], _tensors.SymbolicTensor):
-                    if args[0].rank is not None:
-                        return args[0].rank == 0
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                elif isinstance(args[0], Sequence):
-                    return False
-                else:
-                    # Python constants are scalars
-                    return True
-            if function.name == "Rank":
-                if len(args) != 1:
-                    raise TypeError(
-                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
-                    )
-                if isinstance(args[0], _tensors.SymbolicTensor):
-                    if args[0].rank is not None:
-                        return args[0].rank
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                elif isinstance(args[0], Sequence):
-                    if all(isinstance(arg, (int, float)) for arg in args[0]):
-                        return 1
-                    else:
-                        # Fall to call add_function_call
-                        pass
-                else:
-                    # Python constants are scalars
-                    return 0
             # NOTE: signature should be written to function in the registration process
             if hasattr(function, "_pt_onnx_signature"):
                 op_signature = function._pt_onnx_signature  # type: ignore[attr-defined]
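
The block removed here special-cased IsScalar and Rank so the evaluator could answer them statically, from a symbolic tensor's already-known rank, instead of recording an ONNX function call. A condensed sketch of that folding rule (hypothetical helper name, not code from this commit):

from collections.abc import Sequence

def try_static_rank(arg) -> int | None:
    """Return the rank when it is statically known; None means emit a call."""
    if isinstance(arg, (bool, int, float)):
        return 0  # Python constants are scalars
    if isinstance(arg, Sequence):
        return 1  # a flat sequence of Python numbers acts as a 1-D tensor
    return arg.rank  # a SymbolicTensor's rank; None when not yet known

With this, Rank folds to the returned value and IsScalar folds to a comparison with 0 whenever the result is not None; otherwise the evaluator falls back to emitting the function call.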

@@ -1249,9 +1249,6 @@ def _exported_program_to_onnx_program(
     # TODO: Decide if we should keep mutated buffers as inputs/outputs
-    # TODO(justinchuby): Remove the hack
-    _ir_passes.add_torchlib_common_imports(model)
-
     # Collect and add opset imports to the model
     _ir_passes.add_opset_imports(model)

@@ -90,28 +90,6 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
             value.shape = ir.Shape(new_shape)
 
 
-def add_torchlib_common_imports(
-    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
-) -> None:
-    """Hack to add torchlib common imports to the model."""
-    try:
-        # TODO(justinchuby): Remove this hack and improved onnxscript
-        from onnxscript.function_libs.torch_lib.ops import common as common_ops
-
-        model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
-        rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
-        rank_func.opset_imports[""] = opset_version
-        is_scalar_func = ir.serde.deserialize_function(
-            common_ops.IsScalar.to_function_proto()
-        )
-        is_scalar_func.opset_imports[""] = opset_version
-        model.functions[rank_func.identifier()] = rank_func
-        model.functions[is_scalar_func.identifier()] = is_scalar_func
-    except Exception:
-        logger.exception("Failed to add torchlib common imports to the model.")
-
-
 def _maybe_set_opset_version(
     opset_imports: dict[str, int], domain: str, version: int | None
 ) -> None: