mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Remove common imports from torchlib (#165156)
The Rank and IsScalar functions are no longer used in the torchlib. Requires onnxscript v0.5.4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156 Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
861cdb887b
commit
fcbde24c1c
@ -20,7 +20,7 @@ pip_install \
|
||||
|
||||
pip_install coloredlogs packaging
|
||||
pip_install onnxruntime==1.23.0
|
||||
pip_install onnxscript==0.5.3
|
||||
pip_install onnxscript==0.5.4
|
||||
|
||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
||||
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
||||
|
@ -592,7 +592,6 @@ def graph_executor(
|
||||
proto = onnxscript_function.to_function_proto()
|
||||
ir_function = ir.serde.deserialize_function(proto)
|
||||
onnx_model.functions[identifier] = ir_function
|
||||
_ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
|
||||
_ir_passes.add_opset_imports(onnx_model)
|
||||
# Make sure the model is valid
|
||||
model_proto = ir.to_proto(onnx_model)
|
||||
|
@ -646,45 +646,6 @@ class OpRecorder(evaluator.Evaluator):
|
||||
kwargs: Mapping[str, AllowedArgType],
|
||||
) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
|
||||
try:
|
||||
# TODO(justinchuby): Remove this once IsScalar and Rank are removed
|
||||
# Special cases for handling IsScalar and Rank
|
||||
if function.name == "IsScalar":
|
||||
if len(args) != 1:
|
||||
raise TypeError(
|
||||
f"Expected 1 positional argument for function '{function}', got {len(args)}."
|
||||
)
|
||||
if isinstance(args[0], _tensors.SymbolicTensor):
|
||||
if args[0].rank is not None:
|
||||
return args[0].rank == 0
|
||||
else:
|
||||
# Fall to call add_function_call
|
||||
pass
|
||||
elif isinstance(args[0], Sequence):
|
||||
return False
|
||||
else:
|
||||
# Python constants are scalars
|
||||
return True
|
||||
if function.name == "Rank":
|
||||
if len(args) != 1:
|
||||
raise TypeError(
|
||||
f"Expected 1 positional argument for function '{function}', got {len(args)}."
|
||||
)
|
||||
if isinstance(args[0], _tensors.SymbolicTensor):
|
||||
if args[0].rank is not None:
|
||||
return args[0].rank
|
||||
else:
|
||||
# Fall to call add_function_call
|
||||
pass
|
||||
elif isinstance(args[0], Sequence):
|
||||
if all(isinstance(arg, (int, float)) for arg in args[0]):
|
||||
return 1
|
||||
else:
|
||||
# Fall to call add_function_call
|
||||
pass
|
||||
else:
|
||||
# Python constants are scalars
|
||||
return 0
|
||||
|
||||
# NOTE: signature should be written to function in the registration process
|
||||
if hasattr(function, "_pt_onnx_signature"):
|
||||
op_signature = function._pt_onnx_signature # type: ignore[attr-defined]
|
||||
|
@ -1249,9 +1249,6 @@ def _exported_program_to_onnx_program(
|
||||
|
||||
# TODO: Decide if we should keep mutated buffers as inputs/outputs
|
||||
|
||||
# TODO(justinchuby): Remove the hack
|
||||
_ir_passes.add_torchlib_common_imports(model)
|
||||
|
||||
# Collect and add opset imports to the model
|
||||
_ir_passes.add_opset_imports(model)
|
||||
|
||||
|
@ -90,28 +90,6 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
|
||||
value.shape = ir.Shape(new_shape)
|
||||
|
||||
|
||||
def add_torchlib_common_imports(
|
||||
model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
|
||||
) -> None:
|
||||
"""Hack to add torchlib common imports to the model."""
|
||||
|
||||
try:
|
||||
# TODO(justinchuby): Remove this hack and improved onnxscript
|
||||
from onnxscript.function_libs.torch_lib.ops import common as common_ops
|
||||
|
||||
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
|
||||
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
|
||||
rank_func.opset_imports[""] = opset_version
|
||||
is_scalar_func = ir.serde.deserialize_function(
|
||||
common_ops.IsScalar.to_function_proto()
|
||||
)
|
||||
is_scalar_func.opset_imports[""] = opset_version
|
||||
model.functions[rank_func.identifier()] = rank_func
|
||||
model.functions[is_scalar_func.identifier()] = is_scalar_func
|
||||
except Exception:
|
||||
logger.exception("Failed to add torchlib common imports to the model.")
|
||||
|
||||
|
||||
def _maybe_set_opset_version(
|
||||
opset_imports: dict[str, int], domain: str, version: int | None
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user