mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Add Onnx Config for ImageGPT (#19868)
* add Onnx Config for ImageGPT * add generate_dummy_inputs for onnx config * add TYPE_CHECKING clause * Update doc for generate_dummy_inputs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- GPT-J
|
- GPT-J
|
||||||
- GroupViT
|
- GroupViT
|
||||||
- I-BERT
|
- I-BERT
|
||||||
|
- ImageGPT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
- LayoutLMv3
|
- LayoutLMv3
|
||||||
- LeViT
|
- LeViT
|
||||||
|
@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
|
|||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@ -48,7 +50,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
|
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
@ -14,10 +14,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" OpenAI ImageGPT configuration"""
|
""" OpenAI ImageGPT configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ... import FeatureExtractionMixin, TensorType
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
@ -140,3 +147,56 @@ class ImageGPTConfig(PretrainedConfig):
|
|||||||
self.tie_word_embeddings = tie_word_embeddings
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGPTOnnxConfig(OnnxConfig):
    """ONNX export configuration for ImageGPT models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Names of the model inputs mapped to the names of their dynamic axes.

        Returns:
            An ordered mapping with the single entry `input_ids`, whose axis 0 is named `"batch"` and whose axis 1
            is named `"sequence"`.
        """
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        preprocessor: "FeatureExtractionMixin",
        batch_size: int = 1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        num_channels: int = 3,
        image_width: int = 32,
        image_height: int = 32,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            preprocessor ([`FeatureExtractionMixin`]):
                The preprocessor associated with this model configuration. It is called with the generated images
                as `images` and with `framework` as `return_tensors`.
            batch_size (`int`, *optional*, defaults to 1):
                The batch size forwarded to the dummy image generator.
            seq_length (`int`, *optional*, defaults to -1):
                Not used by this implementation; accepted so existing callers that pass it keep working.
            is_pair (`bool`, *optional*, defaults to `False`):
                Not used by this implementation; accepted so existing callers that pass it keep working.
            framework (`TensorType`, *optional*, defaults to `None`):
                The framework (PyTorch or TensorFlow) that the preprocessor will generate tensors for.
            num_channels (`int`, *optional*, defaults to 3):
                The number of channels of the generated images.
            image_width (`int`, *optional*, defaults to 32):
                The width of the generated images.
            image_height (`int`, *optional*, defaults to 32):
                The height of the generated images.

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """

        # The helper takes (batch, channels, height, width), i.e. height comes before width here even though
        # this method's signature lists width first.
        input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)

        # Bind by keyword so the call does not rely on `return_tensors` being the second positional parameter
        # of the preprocessor's `__call__`.
        inputs = dict(preprocessor(images=input_image, return_tensors=framework))

        return inputs
|
||||||
|
@ -341,6 +341,9 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls="models.ibert.IBertOnnxConfig",
|
onnx_config_cls="models.ibert.IBertOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"imagegpt": supported_features_mapping(
|
||||||
|
"default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
|
||||||
|
),
|
||||||
"layoutlm": supported_features_mapping(
|
"layoutlm": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
@ -193,6 +193,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("detr", "facebook/detr-resnet-50"),
|
("detr", "facebook/detr-resnet-50"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
|
("imagegpt", "openai/imagegpt-small"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("roformer", "junnyu/roformer_chinese_base"),
|
("roformer", "junnyu/roformer_chinese_base"),
|
||||||
|
Reference in New Issue
Block a user