Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: trigger_up ... v4.9.2 (12 commits)
Commits (SHA1): 41981a25cd, ec784223ea, bfd53549b0, 226763a262, f595ea33d9, a12fa50693,
94b7db97bf, 2c255a2e0c, ca272fc523, bff1c71e84, 8ee16d84ce, 6cab8b32e3
@ -67,4 +67,5 @@ deploy_doc "25dee4a" v4.6.0
deploy_doc "7a6c9fa" v4.7.0
deploy_doc "9252a51" v4.8.0
deploy_doc "1366172" v4.8.1
deploy_doc "96d1cfb" # v4.8.2 Latest stable release
deploy_doc "96d1cfb" v4.8.2
deploy_doc "72aee83" # v4.9.0 Latest stable release
@ -1,10 +1,11 @@
// These two things need to be updated at each release for the version selector.
// Last stable version
const stableVersion = "v4.8.2"
const stableVersion = "v4.9.0"
// Dictionary doc folder to label. The last stable version should have an empty key.
const versionMapping = {
"master": "master",
"": "v4.8.0/v4.8.1/v4.8.2 (stable)",
"": "v4.9.0 (stable)",
"v4.8.2": "v4.8.0/v4.8.1/v4.8.2",
"v4.7.0": "v4.7.0",
"v4.6.0": "v4.6.0",
"v4.5.1": "v4.5.0/v4.5.1",
@ -99,6 +99,30 @@ It will be exported under ``onnx/bert-base-cased``. You should see similar logs:

        -[✓] all values close (atol: 0.0001)
    All good, model saved at: onnx/bert-base-cased/model.onnx

This export can now be used in the ONNX inference runtime:

.. code-block::

    import onnxruntime as ort

    from transformers import BertTokenizerFast
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

    ort_session = ort.InferenceSession("onnx/bert-base-cased/model.onnx")

    inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
    outputs = ort_session.run(["last_hidden_state", "pooler_output"], dict(inputs))

The outputs used (:obj:`["last_hidden_state", "pooler_output"]`) can be obtained by taking a look at the ONNX
configuration of each model. For example, for BERT:

.. code-block::

    from transformers.models.bert import BertOnnxConfig, BertConfig

    config = BertConfig()
    onnx_config = BertOnnxConfig(config)
    output_keys = list(onnx_config.outputs.keys())

Implementing a custom configuration for an unsupported architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -142,6 +166,12 @@ An important fact to notice is the use of `OrderedDict` in both inputs and outputs,
as inputs are matched against their relative position within the `PreTrainedModel.forward()` prototype and outputs are
matched against their position in the returned `BaseModelOutputX` instance.

An example of such an addition is visible here, for the MBart model: `Making MBART ONNX-convertible
<https://github.com/huggingface/transformers/pull/13049/commits/d097adcebd89a520f04352eb215a85916934204f>`__

If you would like to contribute your addition to the library, we recommend you implement tests. An example of such
tests is visible here: `Adding tests to the MBART ONNX conversion
<https://github.com/huggingface/transformers/pull/13049/commits/5d642f65abf45ceeb72bd855ca7bfe2506a58e6a>`__
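As an editorial illustration only (not part of the referenced pull requests), a minimal custom configuration for a
BERT-like encoder could look like the sketch below. The class name ``MyEncoderOnnxConfig`` is a placeholder; the axis
names mirror the MBart and BERT configurations shown elsewhere in this comparison:

.. code-block::

    from collections import OrderedDict
    from typing import Mapping

    from transformers.onnx import OnnxConfig


    class MyEncoderOnnxConfig(OnnxConfig):
        # Inputs are declared in the order they appear in `PreTrainedModel.forward()`.
        @property
        def inputs(self) -> Mapping[str, Mapping[int, str]]:
            return OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "sequence"}),
                    ("attention_mask", {0: "batch", 1: "sequence"}),
                ]
            )

        # Outputs are declared in the order they appear in the returned `BaseModelOutputX`.
        @property
        def outputs(self) -> Mapping[str, Mapping[int, str]]:
            return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])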

Graph conversion
-----------------------------------------------------------------------------------------------------------------------
@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@ -403,6 +403,10 @@ def main():
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@ -448,6 +448,10 @@ def main():
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
setup.py
@ -337,7 +337,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.9.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
|
||||
author_email="thomas@huggingface.co",
|
||||
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
|
||||
|
@ -22,7 +22,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.9.0"
|
||||
__version__ = "4.9.2"
|
||||
|
||||
# Work around to update TensorFlow's absl.logging threshold which alters the
|
||||
# default Python logging output behavior when present.
|
||||
|
@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
|
||||
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
||||
}
|
||||
|
||||
# This is the version of torch required to run torch.fx features.
|
||||
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
|
||||
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
||||
|
||||
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
||||
|
||||
@ -297,7 +298,7 @@ def is_torch_cuda_available():
|
||||
return False
|
||||
|
||||
|
||||
_torch_fx_available = False
|
||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||
if _torch_available:
|
||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
||||
@ -305,11 +306,17 @@ if _torch_available:
|
||||
TORCH_FX_REQUIRED_VERSION.minor,
|
||||
)
|
||||
|
||||
_torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION
|
||||
|
||||
|
||||
def is_torch_fx_available():
|
||||
return _torch_fx_available
|
||||
|
||||
|
||||
def is_torch_onnx_dict_inputs_support_available():
|
||||
return _torch_onnx_dict_inputs_support_available
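Editorial note: downstream, the ONNX converter guards on this helper before exporting with dictionary inputs (see
onnx/convert.py later in this comparison). A minimal caller-side sketch:

from transformers.file_utils import is_torch_onnx_dict_inputs_support_available

# torch >= 1.8 is required to feed dictionary inputs to torch.onnx.export.
if not is_torch_onnx_dict_inputs_support_available():
    raise RuntimeError("Please upgrade to torch >= 1.8 before exporting with dict inputs.")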
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
self = getattr(self, self.base_model_prefix)
|
||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
||||
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
@staticmethod
|
||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
||||
uninitialized_encoder_weights: List[str] = []
|
||||
|
@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
|
||||
|
||||
return prediction_scores
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class AlbertSOPHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
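Editorial aside: the `_tie_weights` pattern added to `AlbertMLMHead` above recurs in the BertGeneration, I-BERT,
Longformer, Reformer and RoBERTa heads below. A self-contained sketch of the idea (the `ToyLMHead` module is
hypothetical, not part of the library):

import torch
from torch import nn


class ToyLMHead(nn.Module):
    """Minimal sketch of the decoder/bias tie used by the LM heads in this changeset."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Link the two variables so the bias is resized together with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Re-tie the two tensors if they get disconnected (e.g. on TPU or after a bias resize).
        self.bias = self.decoder.bias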
@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
class BertGenerationOnlyLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = self.decoder(hidden_states)
|
||||
return logits
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
|
||||
|
@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
|
||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
@ -43,7 +43,7 @@ if is_flax_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_gpt_neo import (
|
||||
|
@ -14,7 +14,12 @@
|
||||
# limitations under the License.
|
||||
""" GPT Neo model configuration """
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -173,3 +178,162 @@ class GPTNeoConfig(PretrainedConfig):
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.num_layers
|
||||
|
||||
|
||||
def custom_unfold(input, dimension, size, step):
|
||||
"""Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
|
||||
import torch
|
||||
|
||||
shape = input.size()
|
||||
rank = len(shape)
|
||||
sizedim = shape[dimension]
|
||||
|
||||
low_indices = torch.arange(0, sizedim, step)
|
||||
min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
|
||||
indices = torch.arange(size) + low_indices[:min_length][:, None]
|
||||
|
||||
s = [slice(None)] * rank
|
||||
s[dimension] = indices
|
||||
sliced = input[s]
|
||||
|
||||
perm = list(range(0, rank + 1))
|
||||
perm.append(perm.pop(dimension + 1))
|
||||
|
||||
return sliced.permute(perm)
|
||||
|
||||
|
||||
def custom_get_block_length_and_num_blocks(seq_length, window_size):
|
||||
"""
|
||||
Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
|
||||
original implementation uses Python variables and control flow.
|
||||
"""
|
||||
import torch
|
||||
|
||||
candidates = torch.arange(1, window_size)
|
||||
remainders = torch.remainder(seq_length, candidates)
|
||||
divisor_indices = remainders == 0
|
||||
divisors = candidates[divisor_indices]
|
||||
largest_divisor = torch.max(divisors)
|
||||
return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
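Editorial sanity check (not part of the commit): the custom unfold above should agree with the native op. This assumes
`custom_unfold` is importable from the GPT Neo configuration module as defined here:

import torch

from transformers.models.gpt_neo.configuration_gpt_neo import custom_unfold

x = torch.arange(10).view(1, 10)
# Sliding windows of size 3 with stride 2 along dim 1 -> shape (1, 4, 3) for both implementations.
assert torch.equal(custom_unfold(x, dimension=1, size=3, step=2), x.unfold(1, 3, 2))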
|
||||
|
||||
|
||||
class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from .modeling_gpt_neo import GPTNeoAttentionMixin
|
||||
|
||||
patching_specs = [
|
||||
PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
|
||||
PatchingSpec(
|
||||
GPTNeoAttentionMixin,
|
||||
name="_get_block_length_and_num_blocks",
|
||||
custom_op=custom_get_block_length_and_num_blocks,
|
||||
op_wrapper=staticmethod,
|
||||
),
|
||||
]
|
||||
|
||||
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
|
||||
|
||||
self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
|
||||
self._key_values_dynamic_axis = []
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
|
||||
else:
|
||||
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
|
||||
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
|
||||
|
||||
@property
|
||||
def _number_key_values(self):
|
||||
return (self._config.num_layers * 2) - self._num_local_attention
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||
else:
|
||||
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||
else:
|
||||
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
|
||||
# We need to order the inputs in the way they appear in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
|
||||
batch = common_inputs["input_ids"].shape[0]
|
||||
past_shapes = {
|
||||
"global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
|
||||
"local": (batch, 1, self._config.hidden_size),
|
||||
}
|
||||
|
||||
# Need to add the past_keys
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
ordered_inputs["past_key_values"] = []
|
||||
for i in range(self._config.num_layers):
|
||||
attention_type = self._config.attention_layers[i]
|
||||
if attention_type == "global":
|
||||
ordered_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(past_shapes[attention_type]),
|
||||
torch.zeros(past_shapes[attention_type]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
if len(t) == 1:
|
||||
flatten_output[f"{name}.{idx}.key_value"] = t[0]
|
||||
else:
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
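Editorial usage sketch for the new configuration (requires PyTorch, since the constructor above builds patching specs
when torch is available):

from transformers import GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig

onnx_config = GPTNeoOnnxConfig.with_past(GPTNeoConfig(), task="default")
# Local attention layers expose a single "key_value" tensor, global layers a separate "key" and "value".
print(list(onnx_config.inputs.keys()))
print(list(onnx_config.outputs.keys()))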
|
||||
|
@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class LongformerPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
@ -28,7 +28,7 @@ from ...file_utils import (
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"],
|
||||
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
|
||||
}
|
||||
|
||||
if is_sentencepiece_available():
|
||||
@ -68,7 +68,7 @@ if is_flax_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
|
||||
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .tokenization_mbart import MBartTokenizer
|
||||
|
@ -13,6 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MBART model configuration """
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping
|
||||
|
||||
from transformers.onnx import OnnxConfigWithPast
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
|
||||
class MBartOnnxConfig(OnnxConfigWithPast):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
return OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
if self.use_past:
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("past_keys", {0: "batch", 2: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
else:
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
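Editorial usage sketch for the new MBart configuration (no PyTorch needed just to inspect the axes):

from transformers import MBartConfig
from transformers.models.mbart import MBartOnnxConfig

onnx_config = MBartOnnxConfig.from_model_config(MBartConfig())
print(onnx_config.inputs)   # input_ids / attention_mask, dynamic over batch and sequence
print(onnx_config.outputs)  # last_hidden_state / encoder_last_hidden_state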
|
||||
|
@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
|
||||
self.chunk_size_lm_head = config.chunk_size_lm_head
|
||||
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class ReformerPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
@ -14,10 +14,11 @@
|
||||
# limitations under the License.
|
||||
""" T5 model configuration """
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
|
||||
from ... import is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast
|
||||
from ...utils import logging
|
||||
@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):
|
||||
|
||||
|
||||
class T5OnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||
super().__init__(config, use_past)
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict(
|
||||
@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
||||
for i in range(0, self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
common_outputs = super().outputs
|
||||
|
||||
if "last_hidden_state" in common_outputs:
|
||||
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
|
||||
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
if self.task == "default":
|
||||
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
return common_outputs
|
||||
|
||||
@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
if self.use_past:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Generate encoder inputs
|
||||
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
|
||||
return dict(**encoder_inputs, **decoder_inputs)
|
||||
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch = encoder_inputs["input_ids"].shape[0]
|
||||
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
|
||||
encoder_shape = (
|
||||
batch,
|
||||
self._config.num_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // self._config.num_heads,
|
||||
)
|
||||
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
|
||||
|
||||
ordered_inputs["past_key_values"] = []
|
||||
for _ in range(self._config.num_layers):
|
||||
ordered_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
|
||||
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
|
||||
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
@ -426,8 +426,6 @@ class T5Attention(nn.Module):
|
||||
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
int_seq_length = int(seq_length)
|
||||
|
||||
real_seq_length = seq_length
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -496,7 +494,7 @@ class T5Attention(nn.Module):
|
||||
# if key and values are already calculated
|
||||
# we want only the last query position bias
|
||||
if past_key_value is not None:
|
||||
position_bias = position_bias[:, :, -int_seq_length:, :]
|
||||
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
|
||||
|
||||
if mask is not None:
|
||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
||||
@ -626,7 +624,7 @@ class T5Block(nn.Module):
|
||||
if len(past_key_value) != expected_num_past_key_values:
|
||||
raise ValueError(
|
||||
f"There should be {expected_num_past_key_values} past states. "
|
||||
f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
|
||||
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
|
||||
f"Got {len(past_key_value)} past key / value states"
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
|
||||
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
|
||||
from .convert import export, validate_model_outputs
|
||||
from .utils import ParameterFormat, compute_serialized_parameters_size
|
||||
|
@ -14,101 +14,22 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from transformers.models.albert import AlbertOnnxConfig
|
||||
from transformers.models.auto import AutoTokenizer
|
||||
from transformers.models.bart import BartOnnxConfig
|
||||
from transformers.models.bert import BertOnnxConfig
|
||||
from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..utils import logging
|
||||
from .convert import export, validate_model_outputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import AutoModel, PreTrainedModel
|
||||
|
||||
FEATURES_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
}
|
||||
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
SUPPORTED_MODEL_KIND = {
|
||||
"albert": {"default": AlbertOnnxConfig.default},
|
||||
"bart": {"default": BartOnnxConfig.default},
|
||||
"bert": {"default": BertOnnxConfig.default},
|
||||
"distilbert": {"default": DistilBertOnnxConfig.default},
|
||||
"gpt2": {"default": GPT2OnnxConfig.default},
|
||||
"longformer": {"default": LongformerOnnxConfig.default},
|
||||
"roberta": {"default": RobertaOnnxConfig},
|
||||
"t5": {"default": T5OnnxConfig.default},
|
||||
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
|
||||
}
|
||||
|
||||
|
||||
def get_model_from_features(features: str, model: str):
|
||||
"""
|
||||
Attempt to retrieve a model from a model's name and the features to be enabled.
|
||||
|
||||
Args:
|
||||
features: The features required
|
||||
model: The name of the model to export
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if features not in FEATURES_TO_AUTOMODELS:
|
||||
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
|
||||
|
||||
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
|
||||
|
||||
|
||||
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
|
||||
"""
|
||||
Check whether or not the model has the requested features
|
||||
|
||||
Args:
|
||||
model: The model to export
|
||||
features: The name of the features to check if they are avaiable
|
||||
|
||||
Returns:
|
||||
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
|
||||
|
||||
"""
|
||||
if model.config.model_type not in SUPPORTED_MODEL_KIND:
|
||||
raise KeyError(
|
||||
f"{model.config.model_type} ({model.name}) is not supported yet. "
|
||||
f"Only {SUPPORTED_MODEL_KIND} are supported. "
|
||||
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
||||
)
|
||||
|
||||
# Look for the features
|
||||
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
|
||||
if features not in model_features:
|
||||
raise ValueError(
|
||||
f"{model.config.model_type} doesn't support features {features}. "
|
||||
f"Supported values are: {list(model_features.keys())}"
|
||||
)
|
||||
|
||||
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
|
||||
from .features import FeaturesManager
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
|
||||
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
|
||||
parser.add_argument(
|
||||
"--features",
|
||||
choices=["default"],
|
||||
"--feature",
|
||||
choices=list(FeaturesManager.AVAILABLE_FEATURES),
|
||||
default="default",
|
||||
help="Export the model with some additional features.",
|
||||
help="Export the model with some additional feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
|
||||
@ -127,8 +48,8 @@ def main():
|
||||
|
||||
# Allocate the model
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
model = get_model_from_features(args.features, args.model)
|
||||
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
|
||||
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
|
||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
||||
onnx_config = model_onnx_config(model.config)
|
||||
|
||||
# Ensure the requested opset is sufficient
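Editorial note: with these changes a typical export invocation looks roughly like the following; the trailing output
path is an assumption, since the corresponding parser argument falls outside this hunk:

python -m transformers.onnx --model=bert-base-cased --feature=default --opset=12 onnx/bert-base-cased/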
|
||||
|
@ -11,9 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
|
||||
@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PatchingSpec:
|
||||
"""
|
||||
Data class that holds patching specifications.
|
||||
|
||||
Args:
|
||||
o: Module / object where the op to patch is located
|
||||
name: Name of the op to monkey patch
|
||||
custom_op: Custom op that patches the original op
|
||||
orig_op: Original op that is being patched
|
||||
op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
|
||||
It is useful for ops that are class or static methods for instance.
|
||||
"""
|
||||
|
||||
o: Any
|
||||
name: str
|
||||
custom_op: Callable
|
||||
orig_op: Optional[Callable] = None
|
||||
op_wrapper: Optional[Callable] = None
|
||||
|
||||
|
||||
class OnnxConfig(ABC):
|
||||
"""
|
||||
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
|
||||
@ -34,11 +56,39 @@ class OnnxConfig(ABC):
|
||||
DEFAULT_FIXED_BATCH = 2
|
||||
DEFAULT_FIXED_SEQUENCE = 8
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
_TASKS_TO_COMMON_OUTPUTS = {
|
||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||
"question-answering": OrderedDict(
|
||||
{
|
||||
"start_logits": {0: "batch", 1: "sequence"},
|
||||
"end_logits": {0: "batch", 1: "sequence"},
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
|
||||
self._config = config
|
||||
|
||||
if task not in self._TASKS_TO_COMMON_OUTPUTS:
|
||||
raise ValueError(
|
||||
f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
|
||||
)
|
||||
self.task = task
|
||||
|
||||
self._patching_specs = []
|
||||
for spec in patching_specs if patching_specs is not None else []:
|
||||
final_spec = spec
|
||||
if spec.orig_op is None:
|
||||
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
|
||||
self._patching_specs.append(final_spec)
|
||||
|
||||
@classmethod
|
||||
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
|
||||
def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
|
||||
"""
|
||||
Instantiate a OnnxConfig for a specific model
|
||||
|
||||
@ -48,7 +98,7 @@ class OnnxConfig(ABC):
|
||||
Returns:
|
||||
OnnxConfig for this model
|
||||
"""
|
||||
return cls(config)
|
||||
return cls(config, task=task)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -62,7 +112,6 @@ class OnnxConfig(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
"""
|
||||
Mapping containing the axis definition of the output tensors to provide to the model
|
||||
@ -70,7 +119,7 @@ class OnnxConfig(ABC):
|
||||
Returns:
|
||||
For each output: its name associated to the axes symbolic name and the axis position within the tensor
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return self._TASKS_TO_COMMON_OUTPUTS[self.task]
|
||||
|
||||
@property
|
||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||
@ -170,14 +219,48 @@ class OnnxConfig(ABC):
|
||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||
return dict(tokenizer(dummy_input, return_tensors=framework))
|
||||
|
||||
def patch_ops(self):
|
||||
for spec in self._patching_specs:
|
||||
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
|
||||
setattr(spec.o, spec.name, custom_op)
|
||||
|
||||
def restore_ops(self):
|
||||
for spec in self._patching_specs:
|
||||
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||
setattr(spec.o, spec.name, orig_op)
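Editorial sketch of what `patch_ops` / `restore_ops` do, using a hypothetical target (`math.sqrt`) instead of a real
model op:

import math

from transformers.onnx import PatchingSpec

spec = PatchingSpec(o=math, name="sqrt", custom_op=lambda x: x ** 0.5, orig_op=math.sqrt)

setattr(spec.o, spec.name, spec.custom_op)  # what patch_ops() does for each spec
assert math.sqrt(4.0) == 2.0
setattr(spec.o, spec.name, spec.orig_op)    # what restore_ops() does afterwards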
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||
structure.
|
||||
|
||||
Args:
|
||||
name: The name of the nested structure
|
||||
field: The structure to, potentially, be flattened
|
||||
|
||||
Returns:
|
||||
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
||||
|
||||
"""
|
||||
from itertools import chain
|
||||
|
||||
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
||||
|
||||
|
||||
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: str = "default",
|
||||
patching_specs: List[PatchingSpec] = None,
|
||||
use_past: bool = False,
|
||||
):
|
||||
super().__init__(config, task=task, patching_specs=patching_specs)
|
||||
self.use_past = use_past
|
||||
|
||||
@classmethod
|
||||
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
|
||||
def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
|
||||
"""
|
||||
Instantiate a OnnxConfig with `use_past` attribute set to True
|
||||
|
||||
@ -187,7 +270,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
Returns:
|
||||
OnnxConfig with `.use_past = True`
|
||||
"""
|
||||
return cls(config, use_past=True)
|
||||
return cls(config, task=task, use_past=True)
|
||||
|
||||
@property
|
||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||
@ -221,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
# Generate dummy inputs according to compute batch and sequence
|
||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
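Editorial example of the flattening behaviour (plain strings stand in for the per-layer key/value tensors):

from transformers.onnx.config import OnnxConfigWithPast

past = [("k0", "v0"), ("k1", "v1")]
print(OnnxConfigWithPast.flatten_output_collection_property("present", past))
# {'present.0.key': 'k0', 'present.0.value': 'v0', 'present.1.key': 'k1', 'present.1.value': 'v1'}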
|
||||
|
@ -21,9 +21,9 @@ import numpy as np
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
||||
from ..utils import logging
|
||||
from .config import OnnxConfig
|
||||
from .utils import flatten_output_collection_property
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@ -79,11 +79,16 @@ def export(
|
||||
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ..file_utils import torch_version
|
||||
|
||||
if not is_torch_onnx_dict_inputs_support_available():
|
||||
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||
|
||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||
torch.set_grad_enabled(False)
|
||||
model.config.return_dict = True
|
||||
@ -105,6 +110,8 @@ def export(
|
||||
if not inputs_match:
|
||||
raise ValueError("Model and config inputs doesn't match")
|
||||
|
||||
config.patch_ops()
|
||||
|
||||
# export can work with named args, but the dict containing named args has to be the last element of the args tuple
|
||||
export(
|
||||
model,
|
||||
@ -119,6 +126,8 @@ def export(
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
config.restore_ops()
|
||||
|
||||
return matched_inputs, onnx_outputs
|
||||
|
||||
|
||||
@ -134,6 +143,8 @@ def validate_model_outputs(
|
||||
|
||||
logger.info("Validating ONNX model...")
|
||||
|
||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||
# dynamic input shapes.
|
||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
|
||||
# Create ONNX Runtime session
|
||||
@ -146,8 +157,12 @@ def validate_model_outputs(
|
||||
|
||||
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
|
||||
for name, value in ref_outputs.items():
|
||||
# Overwriting the output name as "present" since it is the name used for the ONNX outputs
|
||||
# ("past_key_values" being taken for the ONNX inputs)
|
||||
if name == "past_key_values":
|
||||
name = "present"
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = flatten_output_collection_property(name, value)
|
||||
value = config.flatten_output_collection_property(name, value)
|
||||
ref_outputs_dict.update(value)
|
||||
else:
|
||||
ref_outputs_dict[name] = value
|
||||
@ -156,7 +171,7 @@ def validate_model_outputs(
|
||||
onnx_inputs = {}
|
||||
for name, value in reference_model_inputs.items():
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = flatten_output_collection_property(name, value)
|
||||
value = config.flatten_output_collection_property(name, value)
|
||||
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
|
||||
else:
|
||||
onnx_inputs[name] = value.numpy()
|
||||
@ -180,7 +195,7 @@ def validate_model_outputs(
|
||||
|
||||
# Check the shape and values match
|
||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||
ref_value = ref_outputs_dict[name].numpy()
|
||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
||||
|
||||
# Shape
|
||||
@ -191,7 +206,7 @@ def validate_model_outputs(
|
||||
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
|
||||
logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
|
||||
|
||||
# Values
|
||||
if not np.allclose(ref_value, ort_value, atol=atol):
|
||||
|
src/transformers/onnx/features.py (new file, 141 lines)
@ -0,0 +1,141 @@
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..models.albert import AlbertOnnxConfig
|
||||
from ..models.bart import BartOnnxConfig
|
||||
from ..models.bert import BertOnnxConfig
|
||||
from ..models.distilbert import DistilBertOnnxConfig
|
||||
from ..models.gpt2 import GPT2OnnxConfig
|
||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||
from ..models.longformer import LongformerOnnxConfig
|
||||
from ..models.mbart import MBartOnnxConfig
|
||||
from ..models.roberta import RobertaOnnxConfig
|
||||
from ..models.t5 import T5OnnxConfig
|
||||
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.models.auto import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
)
|
||||
|
||||
|
||||
def supported_features_mapping(*supported_features, onnx_config_cls=None):
|
||||
"""Generates the mapping between supported features and their corresponding OnnxConfig."""
|
||||
if onnx_config_cls is None:
|
||||
raise ValueError("A OnnxConfig class must be provided")
|
||||
|
||||
mapping = {}
|
||||
for feature in supported_features:
|
||||
if "-with-past" in feature:
|
||||
task = feature.replace("-with-past", "")
|
||||
mapping[feature] = partial(onnx_config_cls.with_past, task=task)
|
||||
else:
|
||||
mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
class FeaturesManager:
|
||||
_TASKS_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
"causal-lm": AutoModelForCausalLM,
|
||||
"seq2seq-lm": AutoModelForSeq2SeqLM,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"multiple-choice": AutoModelForMultipleChoice,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
}
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
_SUPPORTED_MODEL_KIND = {
|
||||
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
|
||||
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
||||
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
|
||||
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
||||
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
||||
"t5": supported_features_mapping(
|
||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
||||
),
|
||||
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
||||
"gpt-neo": supported_features_mapping(
|
||||
"default",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
"default-with-past",
|
||||
"causal-lm-with-past",
|
||||
"sequence-classification-with-past",
|
||||
onnx_config_cls=GPTNeoOnnxConfig,
|
||||
),
|
||||
}
|
||||
|
||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
|
||||
|
||||
@staticmethod
|
||||
def feature_to_task(feature: str) -> str:
|
||||
return feature.replace("-with-past", "")
|
||||
|
||||
@staticmethod
|
||||
def get_model_from_feature(feature: str, model: str):
|
||||
"""
|
||||
Attempt to retrieve a model from a model's name and the feature to be enabled.
|
||||
|
||||
Args:
|
||||
feature: The feature required
|
||||
model: The name of the model to export
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
task = FeaturesManager.feature_to_task(feature)
|
||||
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
|
||||
raise KeyError(
|
||||
f"Unknown task: {feature}."
|
||||
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
|
||||
)
|
||||
|
||||
return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
|
||||
|
||||
@staticmethod
|
||||
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
|
||||
"""
|
||||
Check whether or not the model has the requested features
|
||||
|
||||
Args:
|
||||
model: The model to export
|
||||
feature: The name of the feature to check if it is available
|
||||
|
||||
Returns:
|
||||
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
|
||||
|
||||
"""
|
||||
model_type = model.config.model_type.replace("_", "-")
|
||||
model_name = getattr(model, "name", "")
|
||||
model_name = f"({model_name})" if model_name else ""
|
||||
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
|
||||
raise KeyError(
|
||||
f"{model.config.model_type} ({model_name}) is not supported yet. "
|
||||
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
|
||||
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
||||
)
|
||||
|
||||
# Look for the features
|
||||
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
|
||||
if feature not in model_features:
|
||||
raise ValueError(
|
||||
f"{model.config.model_type} doesn't support feature {feature}. "
|
||||
f"Supported values are: {list(model_features.keys())}"
|
||||
)
|
||||
|
||||
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
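Editorial usage sketch mirroring the new CLI code path (requires PyTorch and downloads the checkpoint from the Hub):

from transformers.onnx.features import FeaturesManager

print(FeaturesManager.AVAILABLE_FEATURES)

model = FeaturesManager.get_model_from_feature("default", "bert-base-cased")
model_kind, config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = config_factory(model.config)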
|
@ -14,7 +14,6 @@
|
||||
|
||||
from ctypes import c_float, sizeof
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
|
||||
class ParameterFormat(Enum):
|
||||
@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
|
||||
Size (in byte) taken to save all the parameters
|
||||
"""
|
||||
return num_parameters * dtype.size
|
||||
|
||||
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||
structure.
|
||||
|
||||
Args:
|
||||
name: The name of the nested structure
|
||||
field: The structure to, potentially, be flattened
|
||||
|
||||
Returns:
|
||||
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
||||
|
||||
"""
|
||||
from itertools import chain
|
||||
|
||||
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
||||
|
@ -25,6 +25,7 @@ from distutils.util import strtobool
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Union
|
||||
from unittest import mock
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
|
||||
use_tf = os.getenv("USE_TF", False)
|
||||
|
||||
"""
|
||||
return unittest.mock.patch.dict(os.environ, kwargs)
|
||||
return mock.patch.dict(os.environ, kwargs)
|
||||
|
||||
|
||||
# from https://stackoverflow.com/a/34333710/9201239
|
||||
|
@ -364,7 +364,7 @@ class Trainer:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if self.place_model_on_device:
|
||||
model = model.to(args.device)
|
||||
self._move_model_to_device(model, args.device)
|
||||
|
||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||
if self.is_model_parallel:
|
||||
@ -505,6 +505,12 @@ class Trainer:
|
||||
"""
|
||||
self.callback_handler.remove_callback(callback)
|
||||
|
||||
def _move_model_to_device(self, model, device):
|
||||
model = model.to(device)
|
||||
# Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
|
||||
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
return dataset
|
||||
@ -1016,7 +1022,7 @@ class Trainer:
        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if args.fp16_full_eval and not args.do_train:
            self.model = self.model.to(args.device)
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
@ -1077,7 +1083,7 @@ class Trainer:
        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self.model = self.model.to(args.device)
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        # Keeping track whether we can can len() on the dataset or not
@ -2515,10 +2521,11 @@ class Trainer:
        Returns:
            The url of the commit of your model in the given repository.
        """
        if not self.args.should_save:
            return

        self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
        if self.args.should_save:
            self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
        # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
        # self.args.should_save.
        self.save_model()

        # Only push from one node.
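The last hunk changes an early return into a narrower guard: only the designated saving process writes the model card, but every process still enters the save path, because checkpointing on TPU is a collective operation. A short sketch of that control flow (names are illustrative, not the Trainer API):

.. code-block::

    def push_to_hub_sketch(process_index: int, should_save: bool):
        """Illustrative control flow only."""
        if should_save:
            print(f"[rank {process_index}] create_model_card()")  # only the designated saver
        # save_model() is entered by every process: the should_save check happens inside
        # the save path, so non-saving ranks join the rendezvous but write nothing.
        print(f"[rank {process_index}] save_model() (writes files only if should_save)")


    for rank in range(2):
        push_to_hub_sketch(rank, should_save=(rank == 0))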
@ -1052,6 +1052,8 @@ class TrainingArguments:
                    logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
                    if is_torch_tpu_available():
                        xm.rendezvous(desc)
                    elif is_sagemaker_dp_enabled():
                        sm_dist.Barrier()
                    else:
                        torch.distributed.barrier()
                yield
@ -1061,6 +1063,8 @@ class TrainingArguments:
                    logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
                    if is_torch_tpu_available():
                        xm.rendezvous(desc)
                    elif is_sagemaker_dp_enabled():
                        sm_dist.Barrier()
                    else:
                        torch.distributed.barrier()
        else:
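Both hunks extend the same barrier pattern with a SageMaker data-parallel branch: replicas that are not the main process block on a barrier until the main process has finished the protected work (for example, downloading or preprocessing a dataset), and the main process then joins the same barrier to release them. A minimal ``torch.distributed`` sketch of the idea, assuming an already initialised process group (the helper name is illustrative):

.. code-block::

    from contextlib import contextmanager

    import torch.distributed as dist


    @contextmanager
    def main_process_first():
        """Run the body on rank 0 first, then let every other rank proceed."""
        rank = dist.get_rank()
        try:
            if rank != 0:
                dist.barrier()  # non-main ranks wait here until rank 0 is done
            yield
        finally:
            if rank == 0:
                dist.barrier()  # rank 0 joins the barrier last, releasing the waiting ranks


    if __name__ == "__main__":
        # Launch with e.g.: torchrun --nproc_per_node=2 this_file.py
        dist.init_process_group("gloo")
        with main_process_first():
            if dist.get_rank() == 0:
                print("rank 0: doing the expensive work (e.g. preprocessing / caching)")
            else:
                print(f"rank {dist.get_rank()}: reading the cached result")
        dist.destroy_process_group()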
@ -9,6 +9,8 @@ from transformers import ( # LongformerConfig,; T5Config,
    BartConfig,
    DistilBertConfig,
    GPT2Config,
    GPTNeoConfig,
    MBartConfig,
    RobertaConfig,
    XLMRobertaConfig,
    is_torch_available,
@ -20,17 +22,21 @@ from transformers.models.distilbert import DistilBertOnnxConfig

# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig

# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
    compute_effective_axis_dimension,
    compute_serialized_parameters_size,
    flatten_output_collection_property,
from transformers.onnx import (
    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
    OnnxConfig,
    ParameterFormat,
    export,
    validate_model_outputs,
)
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow

@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
    Cover all the utilities involved to export ONNX models
    """

    @require_torch
    @patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
    def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
        """
        Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
        """
        self.assertRaises(AssertionError, export, None, None, None, None, None)
        mock_is_torch_onnx_dict_inputs_support_available.assert_called()

    def test_compute_effective_axis_dimension(self):
        """
        When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
@ -78,7 +93,7 @@ class OnnxUtilsTestCaseV2(TestCase):
        ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
        """
        self.assertEqual(
            flatten_output_collection_property("past_key", [[0], [1], [2]]),
            OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
            {
                "past_key.0": 0,
                "past_key.1": 1,
@ -136,11 +151,13 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
                    OnnxConfigWithPast.from_model_config(config()).use_past,
                    "OnnxConfigWithPast.from_model_config() should not use_past",
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
                    OnnxConfigWithPast.with_past(config()).use_past,
                    "OnnxConfigWithPast.from_model_config() should use_past",
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@ -152,7 +169,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
            with self.subTest(name):

                # without past
                onnx_config_default = OnnxConfigWithPast.default(config())
                onnx_config_default = OnnxConfigWithPast.from_model_config(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
@ -175,19 +192,23 @@ if is_torch_available():
        BertModel,
        DistilBertModel,
        GPT2Model,
        GPTNeoModel,
        MBartModel,
        RobertaModel,
        XLMRobertaModel,
    )

    PYTORCH_EXPORT_DEFAULT_MODELS = {
        ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        ("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
        ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
        ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
        ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
        ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
    }

@ -210,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):

        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "default"))
                self.assertTrue(hasattr(onnx_config_class, "from_model_config"))

                tokenizer = AutoTokenizer.from_pretrained(model)
                model = model_class(config_class())
                onnx_config = onnx_config_class.default(model.config)
                model = model_class(config_class.from_pretrained(model))
                onnx_config = onnx_config_class.from_model_config(model.config)

                with NamedTemporaryFile("w") as output:
                    onnx_inputs, onnx_outputs = export(
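The test loop above exercises the renamed ``from_model_config`` constructor and the ``export`` entry point end to end. Below is a hedged sketch of the same flow outside the test harness; the argument order of ``export`` is inferred from this test and may differ between versions.

.. code-block::

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from transformers import BertConfig, BertModel, BertTokenizerFast
    from transformers.models.bert import BertOnnxConfig
    from transformers.onnx import export
    from transformers.onnx.config import DEFAULT_ONNX_OPSET

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
    model = BertModel(BertConfig())  # randomly initialised; enough to exercise the export path
    onnx_config = BertOnnxConfig.from_model_config(model.config)

    with TemporaryDirectory() as tmp:
        output = Path(tmp) / "model.onnx"
        # Assumed signature, mirroring the call in the test above:
        # export(tokenizer, model, onnx_config, opset, output) -> (onnx_inputs, onnx_outputs)
        onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
        print(onnx_inputs, onnx_outputs)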