Add Onnx Config for ImageGPT (#19868)

* add Onnx Config for ImageGPT

* add generate_dummy_inputs for onnx config

* add TYPE_CHECKING clause

* Update doc for generate_dummy_inputs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Raghav Prabhakar
2022-10-28 19:09:53 +05:30
committed by GitHub
parent 9b1dcba94a
commit 0d4c45c585
5 changed files with 69 additions and 2 deletions
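
As a rough usage sketch (not part of this commit), the new ImageGPTOnnxConfig plugs into the existing transformers.onnx export flow. The export() keyword names, the Auto and feature-extractor classes, and the overall wiring below are assumptions based on how other ready-made ONNX configurations are driven:

from pathlib import Path

from transformers import AutoModelForImageClassification, ImageGPTFeatureExtractor
from transformers.models.imagegpt import ImageGPTOnnxConfig
from transformers.onnx import export

checkpoint = "openai/imagegpt-small"  # same checkpoint the new test entry below uses
feature_extractor = ImageGPTFeatureExtractor.from_pretrained(checkpoint)
# The classification head is freshly initialized for this checkpoint; fine for tracing the export path.
model = AutoModelForImageClassification.from_pretrained(checkpoint)

# "image-classification" is one of the two features registered for imagegpt in FeaturesManager below
onnx_config = ImageGPTOnnxConfig(model.config, task="image-classification")

onnx_inputs, onnx_outputs = export(
    preprocessor=feature_extractor,
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,
    output=Path("imagegpt.onnx"),
)

The equivalent CLI call would presumably be python -m transformers.onnx --model=openai/imagegpt-small --feature=image-classification onnx/, once the FeaturesManager entry below is in place.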

View File

@@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
- GPT-J
- GroupViT
- I-BERT
- ImageGPT
- LayoutLM
- LayoutLMv3
- LeViT

View File

@@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

-_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
+_import_structure = {
+    "configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"]
+}

try:
    if not is_vision_available():

@@ -48,7 +50,7 @@ else:
if TYPE_CHECKING:
-    from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
+    from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig

    try:
        if not is_vision_available():

View File

@@ -14,10 +14,17 @@
# limitations under the License.
""" OpenAI ImageGPT configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
if TYPE_CHECKING:
    from ... import FeatureExtractionMixin, TensorType
logger = logging.get_logger(__name__)
IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -140,3 +147,56 @@ class ImageGPTConfig(PretrainedConfig):
        self.tie_word_embeddings = tie_word_embeddings
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

class ImageGPTOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        preprocessor: "FeatureExtractionMixin",
        batch_size: int = 1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        num_channels: int = 3,
        image_width: int = 32,
        image_height: int = 32,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            preprocessor ([`FeatureExtractionMixin`]):
                The preprocessor associated with this model configuration.
            batch_size (`int`, *optional*, defaults to 1):
                The batch size to export the model for (-1 means dynamic axis).
            seq_length (`int`, *optional*, defaults to -1):
                The sequence length to export the model for (-1 means dynamic axis).
            is_pair (`bool`, *optional*, defaults to `False`):
                Indicate if the input is a pair (sentence 1, sentence 2).
            framework (`TensorType`, *optional*, defaults to `None`):
                The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
            num_channels (`int`, *optional*, defaults to 3):
                The number of channels of the generated images.
            image_width (`int`, *optional*, defaults to 32):
                The width of the generated images.
            image_height (`int`, *optional*, defaults to 32):
                The height of the generated images.

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """
        input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
        inputs = dict(preprocessor(images=input_image, return_tensors=framework))

        return inputs
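
A small sketch of what generate_dummy_inputs produces, assuming the feature extractor is loaded from the openai/imagegpt-small checkpoint referenced in the test below and that TensorType is importable from the package root (as the TYPE_CHECKING import above suggests); this is illustrative, not code from the commit:

from transformers import ImageGPTFeatureExtractor, TensorType
from transformers.models.imagegpt import ImageGPTConfig, ImageGPTOnnxConfig

feature_extractor = ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-small")
onnx_config = ImageGPTOnnxConfig(ImageGPTConfig())

# Random 32x32 RGB images are generated, then the feature extractor maps each pixel
# to a color-cluster token id, so the result exposes the "input_ids" key declared
# in the inputs property above.
dummy_inputs = onnx_config.generate_dummy_inputs(
    preprocessor=feature_extractor,
    framework=TensorType.PYTORCH,
)
print(dummy_inputs["input_ids"].shape)  # expected (1, 1024) for the small checkpoint's 32x32 resolution

The only model input is input_ids because ImageGPT consumes color-cluster token ids rather than pixel values, which is why the inputs property declares a single tensor with dynamic batch and sequence axes.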

View File

@@ -341,6 +341,9 @@ class FeaturesManager:
            "question-answering",
            onnx_config_cls="models.ibert.IBertOnnxConfig",
        ),
        "imagegpt": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
        ),
        "layoutlm": supported_features_mapping(
            "default",
            "masked-lm",

View File

@@ -193,6 +193,7 @@ PYTORCH_EXPORT_MODELS = {
("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("imagegpt", "openai/imagegpt-small"),
("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),