[V5] Remove deprecated transformers.onnx (#41214)
* Remove deprecated transformers.onnx

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* Remove onnx docs

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

---------

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
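With `transformers.onnx` removed, exports go through 🤗 Optimum instead. A minimal sketch of the replacement flow, assuming `optimum>=1.5.0` with the exporters extra is installed (the checkpoint name, task, and output directory are illustrative):

```python
from optimum.exporters.onnx import main_export

# Programmatic equivalent of `python -m optimum.exporters.onnx`;
# "distilbert-base-uncased" and "onnx_out/" are example values.
main_export("distilbert-base-uncased", output="onnx_out/", task="text-classification")
```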
@@ -342,8 +342,6 @@
    title: Models
  - local: main_classes/text_generation
    title: Text Generation
  - local: main_classes/onnx
    title: ONNX
  - local: main_classes/optimizer_schedules
    title: Optimization
  - local: main_classes/output
@@ -1,53 +0,0 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Exporting 🤗 Transformers models to ONNX

🤗 Transformers provides a `transformers.onnx` package that enables you to
convert model checkpoints to an ONNX graph by leveraging configuration objects.

See the [guide](../serialization) on exporting 🤗 Transformers models for more
details.

## ONNX Configurations

We provide three abstract classes that you should inherit from, depending on the
type of model architecture you wish to export:

* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]
* Encoder-decoder models inherit from [`~onnx.config.OnnxSeq2SeqConfigWithPast`]
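For instance, a minimal encoder-style configuration only had to declare which input axes stay dynamic at runtime. A sketch of the historical pattern (`CustomOnnxConfig` is an illustrative name, not part of the library):

```python
from collections import OrderedDict

from transformers.onnx import OnnxConfig


class CustomOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> "OrderedDict[str, dict[int, str]]":
        # Axes 0 and 1 are dynamic: batch size and sequence length
        # are only fixed at inference time.
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )
```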
### OnnxConfig

[[autodoc]] onnx.config.OnnxConfig

### OnnxConfigWithPast

[[autodoc]] onnx.config.OnnxConfigWithPast

### OnnxSeq2SeqConfigWithPast

[[autodoc]] onnx.config.OnnxSeq2SeqConfigWithPast

## ONNX Features

Each ONNX configuration is associated with a set of _features_ that enable you
to export models for different types of topologies or tasks.
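As an illustration — a sketch assuming a pre-v5 transformers installation — the feature set for a given architecture could be inspected through the `FeaturesManager` documented below:

```python
from transformers.onnx.features import FeaturesManager

# Feature strings combine a task with an optional "-with-past" suffix
# for KV-cache export, e.g. "causal-lm-with-past".
distilbert_features = list(FeaturesManager.get_supported_features_for_model_type("distilbert").keys())
print(distilbert_features)
# e.g. ["default", "masked-lm", "sequence-classification", ...]
```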
### FeaturesManager

[[autodoc]] onnx.features.FeaturesManager
setup.py
@@ -125,9 +125,6 @@ _deps = [
    "nltk<=3.8.1",
    "num2words",
    "numpy>=1.17",
    "onnxconverter-common",
    "onnxruntime-tools>=1.4.2",
    "onnxruntime>=1.4.0",
    "openai>=1.98.0",
    "opencv-python",
    "optimum-benchmark>=0.3.0",
@@ -271,8 +268,6 @@ else:
extras["tokenizers"] = deps_list("tokenizers")
extras["ftfy"] = deps_list("ftfy")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter")

extras["sagemaker"] = deps_list("sagemaker")
@@ -376,7 +371,6 @@ extras["dev-torch"] = (
    + extras["ja"]
    + extras["sklearn"]
    + extras["modelcreation"]
    + extras["onnxruntime"]
    + extras["num2words"]
)
@@ -463,7 +457,6 @@ setup(
extras["tests_torch"] = deps_list()
extras["tests_hub"] = deps_list()
extras["tests_pipelines_torch"] = deps_list()
extras["tests_onnx"] = deps_list()
extras["tests_examples_torch"] = deps_list()
extras["tests_custom_tokenizers"] = deps_list()
extras["tests_exotic_models"] = deps_list()
@@ -34,9 +34,6 @@ deps = {
    "nltk": "nltk<=3.8.1",
    "num2words": "num2words",
    "numpy": "numpy>=1.17",
    "onnxconverter-common": "onnxconverter-common",
    "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
    "onnxruntime": "onnxruntime>=1.4.0",
    "openai": "openai>=1.98.0",
    "opencv-python": "opencv-python",
    "optimum-benchmark": "optimum-benchmark>=0.3.0",
@@ -25,8 +25,6 @@ _import_structure = {
        "OnnxSeq2SeqConfigWithPast",
        "PatchingSpec",
    ],
    "convert": ["export", "validate_model_outputs"],
    "features": ["FeaturesManager"],
    "utils": ["ParameterFormat", "compute_serialized_parameters_size"],
}

@@ -39,8 +37,6 @@ if TYPE_CHECKING:
        OnnxSeq2SeqConfigWithPast,
        PatchingSpec,
    )
    from .convert import export, validate_model_outputs
    from .features import FeaturesManager
    from .utils import ParameterFormat, compute_serialized_parameters_size

else:
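The `else:` branch is cut off by the hunk above; in transformers, such branches typically register a lazy module so that heavy imports are deferred until an attribute is actually accessed. A sketch of the usual pattern (not the verbatim removed code):

```python
import sys

from ..utils import _LazyModule

# Defer the real imports until an attribute is first accessed.
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```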
@@ -1,228 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import warnings
from argparse import ArgumentParser
from pathlib import Path

from packaging import version

from .. import AutoFeatureExtractor, AutoImageProcessor, AutoProcessor, AutoTokenizer
from ..utils import logging
from ..utils.import_utils import is_optimum_available
from .convert import export, validate_model_outputs
from .features import FeaturesManager
from .utils import get_preprocessor


MIN_OPTIMUM_VERSION = "1.5.0"

ENCODER_DECODER_MODELS = ["vision-encoder-decoder"]


def export_with_optimum(args):
    if is_optimum_available():
        from optimum.version import __version__ as optimum_version

        parsed_optimum_version = version.parse(optimum_version)
        if parsed_optimum_version < version.parse(MIN_OPTIMUM_VERSION):
            raise RuntimeError(
                f"transformers.onnx requires optimum >= {MIN_OPTIMUM_VERSION} but {optimum_version} is installed. You "
                "can upgrade optimum by running: pip install -U optimum[exporters]"
            )
    else:
        raise RuntimeError(
            "transformers.onnx requires optimum to run, you can install the library by running: pip install "
            "optimum[exporters]"
        )
    # Each flag and its value must be a separate argv entry, otherwise argparse
    # on the optimum side receives e.g. "--model bert" as a single token.
    cmd_line = [
        sys.executable,
        "-m",
        "optimum.exporters.onnx",
        "--model",
        args.model,
        "--task",
        args.feature,
        str(args.output),
    ]
    proc = subprocess.Popen(cmd_line, stdout=subprocess.PIPE)
    proc.wait()

    logger.info(
        "The export was done by optimum.exporters.onnx. We recommend using this package directly in future, as "
        "transformers.onnx is deprecated, and will be removed in v5. You can find more information here: "
        "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model."
    )


def export_with_transformers(args):
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")
    if not args.output.parent.exists():
        args.output.parent.mkdir(parents=True)

    # Allocate the model
    model = FeaturesManager.get_model_from_feature(args.feature, args.model, cache_dir=args.cache_dir)

    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
    onnx_config = model_onnx_config(model.config)

    if model_kind in ENCODER_DECODER_MODELS:
        encoder_model = model.get_encoder()
        decoder_model = model.get_decoder()

        encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
        decoder_onnx_config = onnx_config.get_decoder_config(
            encoder_model.config, decoder_model.config, feature=args.feature
        )

        if args.opset is None:
            args.opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)

        if args.opset < min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset):
            raise ValueError(
                f"Opset {args.opset} is not sufficient to export {model_kind}. At least"
                f" {min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)} is required."
            )

        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)

        onnx_inputs, onnx_outputs = export(
            preprocessor,
            encoder_model,
            encoder_onnx_config,
            args.opset,
            args.output.parent.joinpath("encoder_model.onnx"),
        )

        validate_model_outputs(
            encoder_onnx_config,
            preprocessor,
            encoder_model,
            args.output.parent.joinpath("encoder_model.onnx"),
            onnx_outputs,
            args.atol if args.atol else encoder_onnx_config.atol_for_validation,
        )

        preprocessor = AutoTokenizer.from_pretrained(args.model)

        onnx_inputs, onnx_outputs = export(
            preprocessor,
            decoder_model,
            decoder_onnx_config,
            args.opset,
            args.output.parent.joinpath("decoder_model.onnx"),
        )

        validate_model_outputs(
            decoder_onnx_config,
            preprocessor,
            decoder_model,
            args.output.parent.joinpath("decoder_model.onnx"),
            onnx_outputs,
            args.atol if args.atol else decoder_onnx_config.atol_for_validation,
        )
        logger.info(
            f"All good, model saved at: {args.output.parent.joinpath('encoder_model.onnx').as_posix()},"
            f" {args.output.parent.joinpath('decoder_model.onnx').as_posix()}"
        )

    else:
        # Instantiate the appropriate preprocessor
        if args.preprocessor == "auto":
            preprocessor = get_preprocessor(args.model)
        elif args.preprocessor == "tokenizer":
            preprocessor = AutoTokenizer.from_pretrained(args.model)
        elif args.preprocessor == "image_processor":
            preprocessor = AutoImageProcessor.from_pretrained(args.model)
        elif args.preprocessor == "feature_extractor":
            preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
        elif args.preprocessor == "processor":
            preprocessor = AutoProcessor.from_pretrained(args.model)
        else:
            raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")

        # Ensure the requested opset is sufficient
        if args.opset is None:
            args.opset = onnx_config.default_onnx_opset

        if args.opset < onnx_config.default_onnx_opset:
            raise ValueError(
                f"Opset {args.opset} is not sufficient to export {model_kind}. "
                f"At least {onnx_config.default_onnx_opset} is required."
            )

        onnx_inputs, onnx_outputs = export(
            preprocessor,
            model,
            onnx_config,
            args.opset,
            args.output,
        )

        if args.atol is None:
            args.atol = onnx_config.atol_for_validation

        validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
        logger.info(f"All good, model saved at: {args.output.as_posix()}")
        warnings.warn(
            "The export was done by transformers.onnx which is deprecated and will be removed in v5. We recommend"
            " using optimum.exporters.onnx in future. You can find more information here:"
            " https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model.",
            FutureWarning,
        )


def main():
    parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
    parser.add_argument(
        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
    )
    parser.add_argument(
        "--feature",
        default="default",
        help="The type of features to export the model with.",
    )
    parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
    parser.add_argument(
        "--atol", type=float, default=None, help="Absolute difference tolerance when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
    parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
    parser.add_argument(
        "--preprocessor",
        type=str,
        choices=["auto", "tokenizer", "feature_extractor", "image_processor", "processor"],
        default="auto",
        help="Which type of preprocessor to use. 'auto' tries to automatically detect it.",
    )
    parser.add_argument(
        "--export_with_transformers",
        action="store_true",
        help=(
            "Whether to use transformers.onnx instead of optimum.exporters.onnx to perform the ONNX export. It can be "
            "useful when exporting a model supported in transformers but not in optimum, otherwise it is not "
            "recommended."
        ),
    )

    args = parser.parse_args()
    if args.export_with_transformers or not is_optimum_available():
        export_with_transformers(args)
    else:
        export_with_optimum(args)


if __name__ == "__main__":
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)
    main()
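For the record, this removed entry point was driven as `python -m transformers.onnx`; a sketch of a typical invocation from Python, on a pre-v5 installation (the checkpoint and output directory are example values):

```python
import subprocess
import sys

# Illustrative invocation of the removed CLI (transformers < v5);
# "distilbert-base-uncased" and "onnx/" are example values.
subprocess.run(
    [sys.executable, "-m", "transformers.onnx", "--model", "distilbert-base-uncased", "onnx/"],
    check=True,
)
```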
@@ -1,368 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections.abc import Iterable
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
from packaging.version import Version, parse

from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import (
    is_torch_available,
    logging,
)
from .config import OnnxConfig


if is_torch_available():
    from ..modeling_utils import PreTrainedModel

if TYPE_CHECKING:
    from ..feature_extraction_utils import FeatureExtractionMixin
    from ..processing_utils import ProcessorMixin
    from ..tokenization_utils import PreTrainedTokenizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# This is the minimal required version to support some ONNX Runtime features
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")


def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check that onnxruntime is installed and that the installed version is recent enough.

    Raises:
        ImportError: If onnxruntime is not installed or the installed version is too old.
    """
    try:
        import onnxruntime

        # Parse the version of the installed onnxruntime
        ort_version = parse(onnxruntime.__version__)

        # We require 1.4.0 minimum
        if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
            raise ImportError(
                f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
                f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
                "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
            )

    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )


def export_pytorch(
    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
    model: "PreTrainedModel",
    config: OnnxConfig,
    opset: int,
    output: Path,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    device: str = "cpu",
) -> tuple[list[str], list[str]]:
    """
    Export a PyTorch model to an ONNX Intermediate Representation (IR).

    Args:
        preprocessor ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
            The preprocessor used for encoding the data.
        model ([`PreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Directory to store the exported ONNX model.
        device (`str`, *optional*, defaults to `cpu`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`.

    Returns:
        `tuple[list[str], list[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the ONNX configuration.
    """

    if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
        raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
    if tokenizer is not None:
        warnings.warn(
            "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
            " `preprocessor` instead.",
            FutureWarning,
        )
        logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
        preprocessor = tokenizer

    if issubclass(type(model), PreTrainedModel):
        import torch
        from torch.onnx import export as onnx_export

        with torch.no_grad():
            model.config.return_dict = True
            model.eval()

            # Check if we need to override certain configuration items
            if config.values_override is not None:
                logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
                for override_config_key, override_config_value in config.values_override.items():
                    logger.info(f"\t- {override_config_key} -> {override_config_value}")
                    setattr(model.config, override_config_key, override_config_value)

            # Ensure inputs match
            # TODO: Check when exporting QA we provide "is_pair=True"
            model_inputs = config.generate_dummy_inputs(preprocessor)
            device = torch.device(device)
            if device.type == "cuda" and torch.cuda.is_available():
                model.to(device)
                model_inputs_device = {}
                for k, v in model_inputs.items():
                    if isinstance(v, tuple):
                        model_inputs_device[k] = tuple(
                            x.to(device) if isinstance(x, torch.Tensor) else None for x in v
                        )
                    elif isinstance(v, list):
                        model_inputs_device[k] = [
                            tuple(x.to(device) if isinstance(x, torch.Tensor) else None for x in t) for t in v
                        ]
                    else:
                        model_inputs_device[k] = v.to(device)

                model_inputs = model_inputs_device

            inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
            onnx_outputs = list(config.outputs.keys())

            if not inputs_match:
                raise ValueError("Model and config inputs don't match")

            config.patch_ops()

            onnx_export(
                model,
                (model_inputs,),
                f=output.as_posix(),
                input_names=list(config.inputs.keys()),
                output_names=onnx_outputs,
                dynamic_axes=dict(chain(config.inputs.items(), config.outputs.items())),
                do_constant_folding=True,
                opset_version=opset,
            )

            config.restore_ops()

    return matched_inputs, onnx_outputs


def export(
    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
    model: "PreTrainedModel",
    config: OnnxConfig,
    opset: int,
    output: Path,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    device: str = "cpu",
) -> tuple[list[str], list[str]]:
    """
    Export a PyTorch model to an ONNX Intermediate Representation (IR).

    Args:
        preprocessor ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
            The preprocessor used for encoding the data.
        model ([`PreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Directory to store the exported ONNX model.
        device (`str`, *optional*, defaults to `cpu`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.

    Returns:
        `tuple[list[str], list[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the ONNX configuration.
    """
    if not is_torch_available():
        raise ImportError("Cannot convert because PyTorch is not installed. Please install it first.")

    if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
        raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
    if tokenizer is not None:
        warnings.warn(
            "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
            " `preprocessor` instead.",
            FutureWarning,
        )
        logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
        preprocessor = tokenizer

    from ..utils import get_torch_version

    if not config.is_torch_support_available:
        logger.warning(
            f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
            f" got: {get_torch_version()}"
        )

    if issubclass(type(model), PreTrainedModel):
        return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)


def validate_model_outputs(
    config: OnnxConfig,
    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
    reference_model: "PreTrainedModel",
    onnx_model: Path,
    onnx_named_outputs: list[str],
    atol: float,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
):
    from onnxruntime import InferenceSession, SessionOptions

    logger.info("Validating ONNX model...")

    if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
        raise ValueError("You cannot provide both a tokenizer and a preprocessor to validate the model outputs.")
    if tokenizer is not None:
        warnings.warn(
            "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
            " `preprocessor` instead.",
            FutureWarning,
        )
        logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
        preprocessor = tokenizer

    # Generate inputs with a different batch_size and seq_len than the ones used for conversion to properly test
    # dynamic input shapes.
    if issubclass(type(reference_model), PreTrainedModel):
        reference_model_inputs = config.generate_dummy_inputs(
            preprocessor,
            batch_size=config.default_fixed_batch + 1,
            seq_length=config.default_fixed_sequence + 1,
        )

    # Create ONNX Runtime session
    options = SessionOptions()
    session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])

    # Compute outputs from the reference model
    if issubclass(type(reference_model), PreTrainedModel):
        reference_model.to("cpu")
    ref_outputs = reference_model(**reference_model_inputs)
    ref_outputs_dict = {}

    # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
    for name, value in ref_outputs.items():
        # Overwriting the output name as "present" since it is the name used for the ONNX outputs
        # ("past_key_values" being taken for the ONNX inputs)
        if name == "past_key_values":
            name = "present"
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            ref_outputs_dict.update(value)
        else:
            ref_outputs_dict[name] = value

    # Create onnxruntime inputs from the reference model inputs
    reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs_onnxruntime.items():
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
        else:
            onnx_inputs[name] = value.numpy()

    # Compute outputs from the ONNX model
    onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)

    # Check that the ONNX output names are a subset of the reference model's output names
    ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs)
    if not onnx_outputs_set.issubset(ref_outputs_set):
        logger.info(
            f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}"
        )

        raise ValueError(
            "Outputs don't match between reference model and ONNX exported model: "
            f"{onnx_outputs_set.difference(ref_outputs_set)}"
        )
    else:
        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")

    # Check the shapes and values match
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
        if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
            ref_value = ref_outputs_dict[name].detach().numpy()
        else:
            ref_value = ref_outputs_dict[name].numpy()
        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
        if ort_value.shape != ref_value.shape:
            logger.info(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
            raise ValueError(
                "Output shapes don't match between reference model and ONNX exported model: "
                f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
            )
        else:
            logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")

        # Values
        if not np.allclose(ref_value, ort_value, atol=atol):
            bad_indices = np.logical_not(np.isclose(ref_value, ort_value, atol=atol))
            logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
            raise ValueError(
                "Output values don't match between reference model and ONNX exported model: "
                f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))} for "
                f"{ref_value[bad_indices]} vs {ort_value[bad_indices]}"
            )
        else:
            logger.info(f"\t\t-[✓] all values close (atol: {atol})")


def ensure_model_and_config_inputs_match(
    model: "PreTrainedModel", model_inputs: Iterable[str]
) -> tuple[bool, list[str]]:
    """
    Check that every input generated from the ONNX config is accepted by the model's forward signature.

    Args:
        model: The model whose `forward` signature is inspected.
        model_inputs: The input names produced from the ONNX config.

    Returns:
        `tuple[bool, list[str]]`: Whether the inputs are a subset of the forward parameters, and the matching input
        names ordered as in the forward signature.
    """
    forward_parameters = signature(model.forward).parameters
    model_inputs_set = set(model_inputs)

    # The forward signature may accept more parameters than the config provides inputs for
    forward_inputs_set = set(forward_parameters.keys())
    is_ok = model_inputs_set.issubset(forward_inputs_set)

    # Make sure the input order matches (VERY IMPORTANT !!!!)
    matching_inputs = forward_inputs_set.intersection(model_inputs_set)
    ordered_inputs = [parameter for parameter in forward_parameters if parameter in matching_inputs]
    return is_ok, ordered_inputs
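Taken together, `export` and `validate_model_outputs` supported a programmatic flow like the following sketch (checkpoint name and output path are illustrative; `BertOnnxConfig` shipped under `transformers.models.bert` in the versions that contained this module):

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.bert import BertOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "bert-base-uncased"  # example checkpoint
model = AutoModel.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
onnx_config = BertOnnxConfig(model.config)

onnx_path = Path("model.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

# Re-run the ONNX graph and compare against the eager model within atol
validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```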
@@ -1,635 +0,0 @@
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Optional

import transformers

from .. import PretrainedConfig, is_torch_available
from ..utils import logging
from .config import OnnxConfig


if TYPE_CHECKING:
    from transformers import PreTrainedModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_torch_available():
    from transformers.models.auto import (
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForMaskedImageModeling,
        AutoModelForMaskedLM,
        AutoModelForMultipleChoice,
        AutoModelForObjectDetection,
        AutoModelForQuestionAnswering,
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
    )
else:
    logger.warning(
        "ONNX export is only supported for PyTorch. You will not be able to export models without it installed."
    )


def supported_features_mapping(
    *supported_features: str, onnx_config_cls: Optional[str] = None
) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
    """
    Generate the mapping between the supported features and their corresponding OnnxConfig for a given model.

    Args:
        *supported_features: The names of the supported features.
        onnx_config_cls: The OnnxConfig full name corresponding to the model.

    Returns:
        The dictionary mapping a feature to an OnnxConfig constructor.
    """
    if onnx_config_cls is None:
        raise ValueError("An OnnxConfig class must be provided")

    config_cls = transformers
    for attr_name in onnx_config_cls.split("."):
        config_cls = getattr(config_cls, attr_name)
    mapping = {}
    for feature in supported_features:
        if "-with-past" in feature:
            task = feature.replace("-with-past", "")
            mapping[feature] = partial(config_cls.with_past, task=task)
        else:
            mapping[feature] = partial(config_cls.from_model_config, task=feature)

    return mapping


class FeaturesManager:
    _TASKS_TO_AUTOMODELS = {}
    if is_torch_available():
        _TASKS_TO_AUTOMODELS = {
            "default": AutoModel,
            "masked-lm": AutoModelForMaskedLM,
            "causal-lm": AutoModelForCausalLM,
            "seq2seq-lm": AutoModelForSeq2SeqLM,
            "sequence-classification": AutoModelForSequenceClassification,
            "token-classification": AutoModelForTokenClassification,
            "multiple-choice": AutoModelForMultipleChoice,
            "object-detection": AutoModelForObjectDetection,
            "question-answering": AutoModelForQuestionAnswering,
            "image-classification": AutoModelForImageClassification,
            "image-segmentation": AutoModelForImageSegmentation,
            "masked-im": AutoModelForMaskedImageModeling,
            "semantic-segmentation": AutoModelForSemanticSegmentation,
            "vision2seq-lm": AutoModelForVision2Seq,
            "speech2seq-lm": AutoModelForSpeechSeq2Seq,
        }

    # Set of model topologies we support associated to the features supported by each topology and the factory
    _SUPPORTED_MODEL_TYPE = {
        "albert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.albert.AlbertOnnxConfig",
        ),
        "bart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bart.BartOnnxConfig",
        ),
        # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
        "beit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
        ),
        "bert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.bert.BertOnnxConfig",
        ),
        "big-bird": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
        ),
        "bigbird-pegasus": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
        ),
        "blenderbot": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
        ),
        "blenderbot-small": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
        ),
        "bloom": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.bloom.BloomOnnxConfig",
        ),
        "camembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.camembert.CamembertOnnxConfig",
        ),
        "clip": supported_features_mapping(
            "default",
            onnx_config_cls="models.clip.CLIPOnnxConfig",
        ),
        "codegen": supported_features_mapping(
            "default",
            "causal-lm",
            onnx_config_cls="models.codegen.CodeGenOnnxConfig",
        ),
        "convbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.convbert.ConvBertOnnxConfig",
        ),
        "convnext": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.convnext.ConvNextOnnxConfig",
        ),
        "data2vec-text": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
        ),
        "data2vec-vision": supported_features_mapping(
            "default",
            "image-classification",
            # ONNX doesn't support `adaptive_avg_pool2d` yet
            # "semantic-segmentation",
            onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
        ),
        "deberta": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta.DebertaOnnxConfig",
        ),
        "deberta-v2": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
        ),
        "deit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.deit.DeiTOnnxConfig"
        ),
        "detr": supported_features_mapping(
            "default",
            "object-detection",
            "image-segmentation",
            onnx_config_cls="models.detr.DetrOnnxConfig",
        ),
        "distilbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
        ),
        "electra": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.electra.ElectraOnnxConfig",
        ),
        "flaubert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
        ),
        "gpt2": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.gpt2.GPT2OnnxConfig",
        ),
        "gptj": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "question-answering",
            "sequence-classification",
            onnx_config_cls="models.gptj.GPTJOnnxConfig",
        ),
        "gpt-neo": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
        ),
        "groupvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
        ),
        "ibert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.ibert.IBertOnnxConfig",
        ),
        "imagegpt": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
        ),
        "layoutlm": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
        ),
        "layoutlmv3": supported_features_mapping(
            "default",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
        ),
        "levit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
        ),
        "longt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.longt5.LongT5OnnxConfig",
        ),
        "longformer": supported_features_mapping(
            "default",
            "masked-lm",
            "multiple-choice",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.longformer.LongformerOnnxConfig",
        ),
        "marian": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "causal-lm",
            "causal-lm-with-past",
            onnx_config_cls="models.marian.MarianOnnxConfig",
        ),
        "mbart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.mbart.MBartOnnxConfig",
        ),
        "mobilebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
        ),
        "mobilenet-v1": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
        ),
        "mobilenet-v2": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
        ),
        "mobilevit": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
        ),
        "mt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.mt5.MT5OnnxConfig",
        ),
        "m2m-100": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
        ),
        "owlvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
        ),
        "perceiver": supported_features_mapping(
            "image-classification",
            "masked-lm",
            "sequence-classification",
            onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
        ),
        "poolformer": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
        ),
        "rembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.rembert.RemBertOnnxConfig",
        ),
        "resnet": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.resnet.ResNetOnnxConfig",
        ),
        "roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.roberta.RobertaOnnxConfig",
        ),
        "roformer": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "token-classification",
            "multiple-choice",
            "question-answering",
            onnx_config_cls="models.roformer.RoFormerOnnxConfig",
        ),
        "segformer": supported_features_mapping(
            "default",
            "image-classification",
            "semantic-segmentation",
            onnx_config_cls="models.segformer.SegformerOnnxConfig",
        ),
        "squeezebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
        ),
        "swin": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.swin.SwinOnnxConfig"
        ),
        "t5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.t5.T5OnnxConfig",
        ),
        "vision-encoder-decoder": supported_features_mapping(
            "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
        ),
        "vit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.vit.ViTOnnxConfig"
        ),
        "whisper": supported_features_mapping(
            "default",
            "default-with-past",
            "speech2seq-lm",
            "speech2seq-lm-with-past",
            onnx_config_cls="models.whisper.WhisperOnnxConfig",
        ),
        "xlm": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm.XLMOnnxConfig",
        ),
        "xlm-roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
        ),
        "yolos": supported_features_mapping(
            "default",
            "object-detection",
            onnx_config_cls="models.yolos.YolosOnnxConfig",
        ),
    }

    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

    @staticmethod
    def get_supported_features_for_model_type(
        model_type: str, model_name: Optional[str] = None
    ) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
        """
        Tries to retrieve the feature -> OnnxConfig constructor map from the model type.

        Args:
            model_type (`str`):
                The model type to retrieve the supported features for.
            model_name (`str`, *optional*):
                The name attribute of the model object, only used for the exception message.

        Returns:
            The dictionary mapping each feature to a corresponding OnnxConfig constructor.
        """
        model_type = model_type.lower()
        if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
            model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
            raise KeyError(
                f"{model_type_and_model_name} is not supported yet. "
                f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
                f"If you want to support {model_type} please propose a PR or open up an issue."
            )
        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]

    @staticmethod
    def feature_to_task(feature: str) -> str:
        return feature.replace("-with-past", "")

    @staticmethod
    def get_model_class_for_feature(feature: str) -> type:
        """
        Attempts to retrieve an AutoModel class from a feature name.

        Args:
            feature (`str`):
                The feature required.

        Returns:
            The AutoModel class corresponding to the feature.
        """
        task = FeaturesManager.feature_to_task(feature)
        task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
        if task not in task_to_automodel:
            raise KeyError(
                f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.keys())}"
            )

        return task_to_automodel[task]

    @staticmethod
    def get_model_from_feature(feature: str, model: str, cache_dir: Optional[str] = None) -> "PreTrainedModel":
        """
        Attempts to retrieve a model from a model's name and the feature to be enabled.

        Args:
            feature (`str`):
                The feature required.
            model (`str`):
                The name of the model to export.

        Returns:
            The instance of the model.
        """
        model_class = FeaturesManager.get_model_class_for_feature(feature)
        model = model_class.from_pretrained(model, cache_dir=cache_dir)
        return model

    @staticmethod
    def check_supported_model_or_raise(model: "PreTrainedModel", feature: str = "default") -> tuple[str, Callable]:
        """
        Check whether or not the model has the requested features.

        Args:
            model: The model to export.
            feature: The name of the feature to check if it is available.

        Returns:
            `tuple[str, Callable]`: The type of the model and the OnnxConfig constructor holding the model export
            properties.
        """
        model_type = model.config.model_type.replace("_", "-")
        model_name = getattr(model, "name", "")
        model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
        if feature not in model_features:
            raise ValueError(
                f"{model.config.model_type} doesn't support feature {feature}. Supported values are: {model_features}"
            )

        return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]

    @staticmethod
    def get_config(model_type: str, feature: str) -> OnnxConfig:
        """
        Gets the OnnxConfig for a model_type and feature combination.

        Args:
            model_type (`str`):
                The model type to retrieve the config for.
            feature (`str`):
                The feature to retrieve the config for.

        Returns:
            `OnnxConfig`: The config for the combination.
        """
        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
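For context, a sketch of the lookup flow the `__main__.py` entry point performed with this class on a pre-v5 installation (the checkpoint name is illustrative):

```python
from transformers.onnx.features import FeaturesManager

# Resolve the feature to an AutoModel class, load the checkpoint,
# then fetch the matching OnnxConfig constructor.
model = FeaturesManager.get_model_from_feature("sequence-classification", "distilbert-base-uncased")
model_kind, config_ctor = FeaturesManager.check_supported_model_or_raise(model, feature="sequence-classification")
onnx_config = config_ctor(model.config)
print(model_kind, onnx_config.default_onnx_opset)
```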
@@ -28,7 +28,6 @@ docs/source/en/main_classes/feature_extractor.md
docs/source/en/main_classes/image_processor.md
docs/source/en/main_classes/logging.md
docs/source/en/main_classes/model.md
docs/source/en/main_classes/onnx.md
docs/source/en/main_classes/optimizer_schedules.md
docs/source/en/main_classes/output.md
docs/source/en/main_classes/pipelines.md
@@ -738,11 +737,7 @@ src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
src/transformers/models/yoso/modeling_yoso.py
src/transformers/models/zamba/configuration_zamba.py
src/transformers/models/zamba/modeling_zamba.py
src/transformers/onnx/__main__.py
src/transformers/onnx/config.py
src/transformers/onnx/convert.py
src/transformers/onnx/features.py
src/transformers/onnx/utils.py
src/transformers/optimization.py
src/transformers/pipelines/audio_classification.py
src/transformers/pipelines/audio_utils.py
@@ -815,4 +810,4 @@ src/transformers/utils/peft_utils.py
src/transformers/utils/quantization_config.py
src/transformers/utils/sentencepiece_model_pb2.py
src/transformers/utils/sentencepiece_model_pb2_new.py
src/transformers/utils/versions.py