mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Remove common imports from torchlib (#165156)
The Rank and IsScalar functions are no longer used in the torchlib. Requires onnxscript v0.5.4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156 Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
861cdb887b
commit
fcbde24c1c
@ -20,7 +20,7 @@ pip_install \
|
|||||||
|
|
||||||
pip_install coloredlogs packaging
|
pip_install coloredlogs packaging
|
||||||
pip_install onnxruntime==1.23.0
|
pip_install onnxruntime==1.23.0
|
||||||
pip_install onnxscript==0.5.3
|
pip_install onnxscript==0.5.4
|
||||||
|
|
||||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
||||||
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
||||||
|
@ -592,7 +592,6 @@ def graph_executor(
|
|||||||
proto = onnxscript_function.to_function_proto()
|
proto = onnxscript_function.to_function_proto()
|
||||||
ir_function = ir.serde.deserialize_function(proto)
|
ir_function = ir.serde.deserialize_function(proto)
|
||||||
onnx_model.functions[identifier] = ir_function
|
onnx_model.functions[identifier] = ir_function
|
||||||
_ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
|
|
||||||
_ir_passes.add_opset_imports(onnx_model)
|
_ir_passes.add_opset_imports(onnx_model)
|
||||||
# Make sure the model is valid
|
# Make sure the model is valid
|
||||||
model_proto = ir.to_proto(onnx_model)
|
model_proto = ir.to_proto(onnx_model)
|
||||||
|
@ -646,45 +646,6 @@ class OpRecorder(evaluator.Evaluator):
|
|||||||
kwargs: Mapping[str, AllowedArgType],
|
kwargs: Mapping[str, AllowedArgType],
|
||||||
) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
|
) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int:
|
||||||
try:
|
try:
|
||||||
# TODO(justinchuby): Remove this once IsScalar and Rank are removed
|
|
||||||
# Special cases for handling IsScalar and Rank
|
|
||||||
if function.name == "IsScalar":
|
|
||||||
if len(args) != 1:
|
|
||||||
raise TypeError(
|
|
||||||
f"Expected 1 positional argument for function '{function}', got {len(args)}."
|
|
||||||
)
|
|
||||||
if isinstance(args[0], _tensors.SymbolicTensor):
|
|
||||||
if args[0].rank is not None:
|
|
||||||
return args[0].rank == 0
|
|
||||||
else:
|
|
||||||
# Fall to call add_function_call
|
|
||||||
pass
|
|
||||||
elif isinstance(args[0], Sequence):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
# Python constants are scalars
|
|
||||||
return True
|
|
||||||
if function.name == "Rank":
|
|
||||||
if len(args) != 1:
|
|
||||||
raise TypeError(
|
|
||||||
f"Expected 1 positional argument for function '{function}', got {len(args)}."
|
|
||||||
)
|
|
||||||
if isinstance(args[0], _tensors.SymbolicTensor):
|
|
||||||
if args[0].rank is not None:
|
|
||||||
return args[0].rank
|
|
||||||
else:
|
|
||||||
# Fall to call add_function_call
|
|
||||||
pass
|
|
||||||
elif isinstance(args[0], Sequence):
|
|
||||||
if all(isinstance(arg, (int, float)) for arg in args[0]):
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
# Fall to call add_function_call
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Python constants are scalars
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# NOTE: signature should be written to function in the registration process
|
# NOTE: signature should be written to function in the registration process
|
||||||
if hasattr(function, "_pt_onnx_signature"):
|
if hasattr(function, "_pt_onnx_signature"):
|
||||||
op_signature = function._pt_onnx_signature # type: ignore[attr-defined]
|
op_signature = function._pt_onnx_signature # type: ignore[attr-defined]
|
||||||
|
@ -1249,9 +1249,6 @@ def _exported_program_to_onnx_program(
|
|||||||
|
|
||||||
# TODO: Decide if we should keep mutated buffers as inputs/outputs
|
# TODO: Decide if we should keep mutated buffers as inputs/outputs
|
||||||
|
|
||||||
# TODO(justinchuby): Remove the hack
|
|
||||||
_ir_passes.add_torchlib_common_imports(model)
|
|
||||||
|
|
||||||
# Collect and add opset imports to the model
|
# Collect and add opset imports to the model
|
||||||
_ir_passes.add_opset_imports(model)
|
_ir_passes.add_opset_imports(model)
|
||||||
|
|
||||||
|
@ -90,28 +90,6 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
|
|||||||
value.shape = ir.Shape(new_shape)
|
value.shape = ir.Shape(new_shape)
|
||||||
|
|
||||||
|
|
||||||
def add_torchlib_common_imports(
|
|
||||||
model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
|
|
||||||
) -> None:
|
|
||||||
"""Hack to add torchlib common imports to the model."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# TODO(justinchuby): Remove this hack and improved onnxscript
|
|
||||||
from onnxscript.function_libs.torch_lib.ops import common as common_ops
|
|
||||||
|
|
||||||
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
|
|
||||||
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
|
|
||||||
rank_func.opset_imports[""] = opset_version
|
|
||||||
is_scalar_func = ir.serde.deserialize_function(
|
|
||||||
common_ops.IsScalar.to_function_proto()
|
|
||||||
)
|
|
||||||
is_scalar_func.opset_imports[""] = opset_version
|
|
||||||
model.functions[rank_func.identifier()] = rank_func
|
|
||||||
model.functions[is_scalar_func.identifier()] = is_scalar_func
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to add torchlib common imports to the model.")
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_set_opset_version(
|
def _maybe_set_opset_version(
|
||||||
opset_imports: dict[str, int], domain: str, version: int | None
|
opset_imports: dict[str, int], domain: str, version: int | None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Reference in New Issue
Block a user