Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Add Onnx Config for ImageGPT (#19868)
* add Onnx Config for ImageGPT
* add generate_dummy_inputs for onnx config
* add TYPE_CHECKING clause
* Update doc for generate_dummy_inputs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
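As context for the diff below, here is a minimal sketch of how the new config could drive an ONNX export through the transformers.onnx Python helpers of this release; the checkpoint, output path, and opset choice are illustrative, not part of this commit.

from pathlib import Path

from transformers import AutoFeatureExtractor, ImageGPTForImageClassification
from transformers.models.imagegpt import ImageGPTOnnxConfig
from transformers.onnx import export

checkpoint = "openai/imagegpt-small"  # same checkpoint the new export test uses
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = ImageGPTForImageClassification.from_pretrained(checkpoint)

# Build the ONNX config added in this commit from the model's configuration.
onnx_config = ImageGPTOnnxConfig(model.config, task="image-classification")

# Arguments: preprocessor, model, config, opset, output path.
# export() returns the matched input and output names of the ONNX graph.
onnx_inputs, onnx_outputs = export(
    feature_extractor, model, onnx_config, onnx_config.default_onnx_opset, Path("imagegpt.onnx")
)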
@@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
 - GPT-J
 - GroupViT
 - I-BERT
+- ImageGPT
 - LayoutLM
 - LayoutLMv3
 - LeViT
@@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
 from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


-_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
+_import_structure = {
+    "configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"]
+}

 try:
     if not is_vision_available():
@@ -48,7 +50,7 @@ else:


 if TYPE_CHECKING:
-    from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
+    from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig

     try:
         if not is_vision_available():
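With the `_import_structure` entry and the `TYPE_CHECKING` import above, the new class becomes reachable from the imagegpt subpackage. A quick illustrative check; the printed mapping is the `inputs` property defined later in this commit.

from transformers.models.imagegpt import ImageGPTConfig, ImageGPTOnnxConfig

onnx_config = ImageGPTOnnxConfig(ImageGPTConfig())
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'})])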
@@ -14,10 +14,17 @@
 # limitations under the License.
 """ OpenAI ImageGPT configuration"""

+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging


+if TYPE_CHECKING:
+    from ... import FeatureExtractionMixin, TensorType
+
 logger = logging.get_logger(__name__)

 IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
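The `if TYPE_CHECKING:` block added here is only read by static type checkers; at runtime the names stay unimported, which avoids pulling in the top-level `transformers` package while this module is being imported. A generic sketch of the pattern, with illustrative names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never executed at runtime.
    from transformers import FeatureExtractionMixin, TensorType


def describe(preprocessor: "FeatureExtractionMixin", framework: "TensorType") -> str:
    # Quoted annotations defer name resolution, so the guard above is enough.
    return f"{type(preprocessor).__name__} -> {framework}"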
@@ -140,3 +147,56 @@ class ImageGPTConfig(PretrainedConfig):
         self.tie_word_embeddings = tie_word_embeddings

         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class ImageGPTOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    def generate_dummy_inputs(
+        self,
+        preprocessor: "FeatureExtractionMixin",
+        batch_size: int = 1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+        num_channels: int = 3,
+        image_width: int = 32,
+        image_height: int = 32,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework
+
+        Args:
+            preprocessor ([`FeatureExtractionMixin`]):
+                The preprocessor associated with this model configuration.
+            batch_size (`int`, *optional*, defaults to 1):
+                The batch size to export the model for (-1 means dynamic axis).
+            seq_length (`int`, *optional*, defaults to -1):
+                The sequence length to export the model for (-1 means dynamic axis).
+            is_pair (`bool`, *optional*, defaults to `False`):
+                Indicate if the input is a pair (sentence 1, sentence 2).
+            framework (`TensorType`, *optional*, defaults to `None`):
+                The framework (PyTorch or TensorFlow) that the preprocessor will generate tensors for.
+            num_channels (`int`, *optional*, defaults to 3):
+                The number of channels of the generated images.
+            image_width (`int`, *optional*, defaults to 32):
+                The width of the generated images.
+            image_height (`int`, *optional*, defaults to 32):
+                The height of the generated images.
+
+        Returns:
+            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
+        """
+
+        input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+        inputs = dict(preprocessor(images=input_image, return_tensors=framework))
+
+        return inputs
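A hedged sketch of exercising the new `generate_dummy_inputs` override directly; the feature extractor converts the randomly generated images into the `input_ids` ImageGPT expects, and the checkpoint name below is illustrative.

from transformers import AutoFeatureExtractor, ImageGPTConfig, TensorType
from transformers.models.imagegpt import ImageGPTOnnxConfig

feature_extractor = AutoFeatureExtractor.from_pretrained("openai/imagegpt-small")
onnx_config = ImageGPTOnnxConfig(ImageGPTConfig())

# Dummy batch of two 32x32 RGB images, returned as PyTorch tensors.
dummy = onnx_config.generate_dummy_inputs(
    preprocessor=feature_extractor,
    batch_size=2,
    framework=TensorType.PYTORCH,
)
print(dummy["input_ids"].shape)  # torch.Size([2, 1024]): one color-cluster token per pixel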
@@ -341,6 +341,9 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls="models.ibert.IBertOnnxConfig",
         ),
+        "imagegpt": supported_features_mapping(
+            "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
+        ),
         "layoutlm": supported_features_mapping(
             "default",
             "masked-lm",
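Registering the config with `FeaturesManager` is what lets the exporter look it up by model type and feature; a brief sketch of that lookup, assuming the `transformers.onnx` API of this release:

from transformers.onnx import FeaturesManager

# Features the exporter now advertises for ImageGPT.
print(list(FeaturesManager.get_supported_features_for_model_type("imagegpt")))
# ['default', 'image-classification']

# Constructor resolved from the string "models.imagegpt.ImageGPTOnnxConfig".
onnx_config_ctor = FeaturesManager.get_config("imagegpt", "image-classification")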
@@ -193,6 +193,7 @@ PYTORCH_EXPORT_MODELS = {
     ("detr", "facebook/detr-resnet-50"),
     ("distilbert", "distilbert-base-cased"),
     ("electra", "google/electra-base-generator"),
+    ("imagegpt", "openai/imagegpt-small"),
     ("resnet", "microsoft/resnet-50"),
     ("roberta", "roberta-base"),
     ("roformer", "junnyu/roformer_chinese_base"),
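The new test entry drives the generic export-and-validate test over `openai/imagegpt-small`. A standalone smoke check in the same spirit, assuming the `imagegpt.onnx` file from the export sketch above; the tolerance is illustrative.

import numpy as np
import onnxruntime as ort
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, ImageGPTForImageClassification

checkpoint = "openai/imagegpt-small"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = ImageGPTForImageClassification.from_pretrained(checkpoint).eval()

image = Image.new("RGB", (32, 32))
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    reference = model(**inputs).logits.numpy()

session = ort.InferenceSession("imagegpt.onnx", providers=["CPUExecutionProvider"])
onnx_logits = session.run(None, {"input_ids": inputs["input_ids"].numpy()})[0]

# Loose tolerance; the actual test relies on the library's validation helper.
assert np.allclose(reference, onnx_logits, atol=1e-4)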