Compare commits

...

12 Commits

SHA1 Message Date
41981a25cd Patch release: v4.9.2 2021-08-09 16:01:36 +02:00
ec784223ea Tpu tie weights (#13030)
* Fix tied weights on TPU

* Manually tie weights in no trainer examples

* Fix for test

* One last missing

* Getting owned by my scripts

* Address review comments

* Fix test

* Fix tests

* Fix reformer tests
2021-08-09 15:53:05 +02:00
bfd53549b0 Add to ONNX docs (#13048)
* Add to ONNX docs

* Add MBART example

* Update docs/source/serialization.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2021-08-09 15:52:16 +02:00
226763a262 Add MBART to models exportable with ONNX (#13049)
* Add MBART to models exportable with ONNX

* unittest mock

* Add tests

* Misc fixes
2021-08-09 15:52:07 +02:00
f595ea33d9 Put smaller ALBERT model (#13028) 2021-08-09 15:51:04 +02:00
a12fa50693 T5 with past ONNX export (#13014)
T5 with past ONNX export, and more explicit past_key_values input and output names for the ONNX model

Authored-by: Michael Benayoun <michael@huggingface.co>
2021-08-09 15:50:58 +02:00
94b7db97bf GPT-Neo ONNX export (#12911)
GPT-Neo ONNX export and task / feature refactoring

Authored-by: Michael Benayoun <michael@huggingface.co>
2021-08-09 15:50:44 +02:00
2c255a2e0c Fix push_to_hub for TPUs (#12895) 2021-08-09 15:47:29 +02:00
ca272fc523 ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)
* Raise an issue if the pytorch version is < 1.8.0

* Attempt to add a test to ensure it correctly raises.

* Missing docstring.

* Second attempt, patch with string absolute import.

* Let's do the call before checking it was called ...

* use the correct function ... 🤦

* Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient.

* Correct path mock patching

* relax constraint for torch_onnx_dict_inputs to ge instead of eq.

* Style.

* Split each version requirements for torch.

* Let's compare version directly.

* Import torch_version after checking pytorch is installed.

* @require_torch
2021-08-09 15:43:31 +02:00
bff1c71e84 Release: v4.9.1 2021-07-26 10:21:55 -04:00
8ee16d84ce Fix barrier for SM distributed (#12853) 2021-07-26 10:21:07 -04:00
6cab8b32e3 Add doc for v4.9.0 2021-07-26 10:20:51 -04:00
32 changed files with 680 additions and 198 deletions

View File

@ -67,4 +67,5 @@ deploy_doc "25dee4a" v4.6.0
deploy_doc "7a6c9fa" v4.7.0
deploy_doc "9252a51" v4.8.0
deploy_doc "1366172" v4.8.1
deploy_doc "96d1cfb" # v4.8.2 Latest stable release
deploy_doc "96d1cfb" v4.8.2
deploy_doc "72aee83" # v4.9.0 Latest stable release

View File

@ -1,10 +1,11 @@
// These two things need to be updated at each release for the version selector.
// Last stable version
const stableVersion = "v4.8.2"
const stableVersion = "v4.9.0"
// Dictionary doc folder to label. The last stable version should have an empty key.
const versionMapping = {
"master": "master",
"": "v4.8.0/v4.8.1/v4.8.2 (stable)",
"": "v4.9.0 (stable)",
"v4.8.2": "v4.8.0/v4.8.1/v4.8.2",
"v4.7.0": "v4.7.0",
"v4.6.0": "v4.6.0",
"v4.5.1": "v4.5.0/v4.5.1",

View File

@ -99,6 +99,30 @@ It will be exported under ``onnx/bert-base-cased``. You should see similar logs:
-[✓] all values close (atol: 0.0001)
All good, model saved at: onnx/bert-base-cased/model.onnx
This export can now be used in the ONNX inference runtime:
.. code-block::
import onnxruntime as ort
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
ort_session = ort.InferenceSession("onnx/bert-base-cased/model.onnx")
inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
outputs = ort_session.run(["last_hidden_state", "pooler_output"], dict(inputs))
The outputs used (:obj:`["last_hidden_state", "pooler_output"]`) can be obtained by taking a look at the ONNX
configuration of each model. For example, for BERT:
.. code-block::
from transformers.models.bert import BertOnnxConfig, BertConfig
config = BertConfig()
onnx_config = BertOnnxConfig(config)
output_keys = list(onnx_config.outputs.keys())
Implementing a custom configuration for an unsupported architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -142,6 +166,12 @@ An important fact to notice is the use of `OrderedDict` in both inputs and outpu
as inputs are matched against their relative position within the `PreTrainedModel.forward()` prototype and outputs are
matched against their position in the returned `BaseModelOutputX` instance.
An example of such an addition is visible here, for the MBart model: `Making MBART ONNX-convertible
<https://github.com/huggingface/transformers/pull/13049/commits/d097adcebd89a520f04352eb215a85916934204f>`__
If you would like to contribute your addition to the library, we recommend you implement tests. An example of such
tests is visible here: `Adding tests to the MBART ONNX conversion
<https://github.com/huggingface/transformers/pull/13049/commits/5d642f65abf45ceeb72bd855ca7bfe2506a58e6a>`__
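As a minimal sketch of what such a configuration looks like (the class and model here are hypothetical, purely for illustration): only the ordered ``inputs`` mapping must be provided, since the base ``OnnxConfig`` now derives ``outputs`` from the task.
.. code-block::
from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig

class MyModelOnnxConfig(OnnxConfig):
    # Inputs are matched by position against the model's forward() signature,
    # hence the OrderedDict rather than a plain dict.
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )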
Graph conversion
-----------------------------------------------------------------------------------------------------------------------

View File

@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -403,6 +403,10 @@ def main():
model, optimizer, train_dataloader, eval_dataloader
)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
# shorter in multiprocess)

View File

@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -448,6 +448,10 @@ def main():
model, optimizer, train_dataloader, eval_dataloader
)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
# shorter in multiprocess)

View File

@ -337,7 +337,7 @@ install_requires = [
setup(
name="transformers",
version="4.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.9.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="thomas@huggingface.co",
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@ -22,7 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.9.0"
__version__ = "4.9.2"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.

View File

@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# This is the version of torch required to run torch.fx features.
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
@ -297,7 +298,7 @@ def is_torch_cuda_available():
return False
_torch_fx_available = False
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
_torch_fx_available = (torch_version.major, torch_version.minor) == (
@ -305,11 +306,17 @@ if _torch_available:
TORCH_FX_REQUIRED_VERSION.minor,
)
_torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION
def is_torch_fx_available():
return _torch_fx_available
def is_torch_onnx_dict_inputs_support_available():
return _torch_onnx_dict_inputs_support_available
def is_tf_available():
return _tf_available

View File

@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self = getattr(self, self.base_model_prefix)
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
for module in self.modules():
if hasattr(module, "_tie_weights"):
module._tie_weights()
@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
uninitialized_encoder_weights: List[str] = []

View File

@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class AlbertSOPHead(nn.Module):
def __init__(self, config):

View File

@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
class BertGenerationOnlyLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
logits = self.decoder(hidden_states)
return logits
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@add_start_docstrings(
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,

View File

@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
}
if is_torch_available():
@ -43,7 +43,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
if is_torch_available():
from .modeling_gpt_neo import (

View File

@ -14,7 +14,12 @@
# limitations under the License.
""" GPT Neo model configuration """
from collections import OrderedDict
from typing import Any, Dict, Iterable, Mapping, Optional
from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging
@ -173,3 +178,162 @@ class GPTNeoConfig(PretrainedConfig):
@property
def num_hidden_layers(self):
return self.num_layers
def custom_unfold(input, dimension, size, step):
"""Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
import torch
shape = input.size()
rank = len(shape)
sizedim = shape[dimension]
low_indices = torch.arange(0, sizedim, step)
min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
indices = torch.arange(size) + low_indices[:min_length][:, None]
s = [slice(None)] * rank
s[dimension] = indices
sliced = input[s]
perm = list(range(0, rank + 1))
perm.append(perm.pop(dimension + 1))
return sliced.permute(perm)
def custom_get_block_length_and_num_blocks(seq_length, window_size):
"""
Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
the original implementation uses Python variables and control flow.
"""
import torch
candidates = torch.arange(1, window_size)
remainders = torch.remainder(seq_length, candidates)
divisor_indices = remainders == 0
divisors = candidates[divisor_indices]
largest_divisor = torch.max(divisors)
return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
class GPTNeoOnnxConfig(OnnxConfigWithPast):
def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
if is_torch_available():
import torch
from .modeling_gpt_neo import GPTNeoAttentionMixin
patching_specs = [
PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
PatchingSpec(
GPTNeoAttentionMixin,
name="_get_block_length_and_num_blocks",
custom_op=custom_get_block_length_and_num_blocks,
op_wrapper=staticmethod,
),
]
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
self._key_values_dynamic_axis = []
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
else:
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
@property
def _number_key_values(self):
return (self._config.num_layers * 2) - self._num_local_attention
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
else:
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
if self.use_past:
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
else:
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
return common_outputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
# We need to order the inputs in the way they appear in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
batch = common_inputs["input_ids"].shape[0]
past_shapes = {
"global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
"local": (batch, 1, self._config.hidden_size),
}
# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
ordered_inputs["past_key_values"] = []
for i in range(self._config.num_layers):
attention_type = self._config.attention_layers[i]
if attention_type == "global":
ordered_inputs["past_key_values"].append(
(
torch.zeros(past_shapes[attention_type]),
torch.zeros(past_shapes[attention_type]),
)
)
else:
ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
)
return ordered_inputs
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
if len(t) == 1:
flatten_output[f"{name}.{idx}.key_value"] = t[0]
else:
flatten_output[f"{name}.{idx}.key"] = t[0]
flatten_output[f"{name}.{idx}.value"] = t[1]
return flatten_output
return super().flatten_output_collection_property(name, field)
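For a sense of the input names this configuration produces, here is a small sketch on a hypothetical two-layer config mixing attention types (assumes PyTorch is installed, since the constructor builds its patching specs with torch):

from transformers import GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig

# One "global" followed by one "local" attention layer.
config = GPTNeoConfig(num_layers=2, attention_types=[[["global", "local"], 1]])
onnx_config = GPTNeoOnnxConfig.with_past(config)

# Global layers expose separate key/value tensors; local layers a fused key_value.
print(list(onnx_config.inputs))
# ['input_ids', 'past_key_values.0.key', 'past_key_values.0.value',
#  'past_key_values.1.key_value', 'attention_mask']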

View File

@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
loss = None
if labels is not None:

View File

@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
return x
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@add_start_docstrings(
"""

View File

@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
return x
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class LongformerPreTrainedModel(PreTrainedModel):
"""

View File

@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = {
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"],
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
}
if is_sentencepiece_available():
@ -68,7 +68,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer

View File

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MBART model configuration """
from collections import OrderedDict
from typing import Mapping
from transformers.onnx import OnnxConfigWithPast
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
@property
def hidden_size(self) -> int:
return self.d_model
class MBartOnnxConfig(OnnxConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.use_past:
return OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "sequence"}),
("past_keys", {0: "batch", 2: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)
else:
return OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)

View File

@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
self.chunk_size_lm_head = config.chunk_size_lm_head
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
hidden_states = self.decoder(hidden_states)
return hidden_states
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class ReformerPreTrainedModel(PreTrainedModel):
"""

View File

@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
return x
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@add_start_docstrings(
"""

View File

@ -14,10 +14,11 @@
# limitations under the License.
""" T5 model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType
from ... import is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):
class T5OnnxConfig(OnnxConfigWithPast):
def __init__(self, config: PretrainedConfig, use_past: bool = False):
super().__init__(config, use_past)
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
)
if self.use_past:
for i in range(self._config.num_layers):
common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
for i in range(0, self._config.num_layers):
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
]
)
common_outputs = super().outputs
if "last_hidden_state" in common_outputs:
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
for i in range(self._config.num_layers):
common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
if self.task == "default":
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
return common_outputs
@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.use_past:
raise NotImplementedError()
# Generate encoder inputs
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
return dict(**encoder_inputs, **decoder_inputs)
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch = encoder_inputs["input_ids"].shape[0]
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
encoder_shape = (
batch,
self._config.num_heads,
encoder_seq_length,
self._config.hidden_size // self._config.num_heads,
)
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
ordered_inputs["past_key_values"] = []
for _ in range(self._config.num_layers):
ordered_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
return ordered_inputs
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
return flatten_output
return super().flatten_output_collection_property(name, field)
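To illustrate the flattened naming scheme, a small sketch with dummy tensors standing in for real past key/values (the shapes are irrelevant to the flattening itself):

import torch

from transformers.models.t5 import T5OnnxConfig

# Two layers, each a (decoder.key, decoder.value, encoder.key, encoder.value) tuple.
dummy_past = [tuple(torch.zeros(1, 8, 1, 64) for _ in range(4)) for _ in range(2)]
flat = T5OnnxConfig.flatten_output_collection_property("present", dummy_past)
print(sorted(flat))
# ['present.0.decoder.key', 'present.0.decoder.value', 'present.0.encoder.key',
#  'present.0.encoder.value', 'present.1.decoder.key', 'present.1.decoder.value',
#  'present.1.encoder.key', 'present.1.encoder.value']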

View File

@ -426,8 +426,6 @@ class T5Attention(nn.Module):
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
int_seq_length = int(seq_length)
real_seq_length = seq_length
if past_key_value is not None:
@ -496,7 +494,7 @@ class T5Attention(nn.Module):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -int_seq_length:, :]
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
@ -626,7 +624,7 @@ class T5Block(nn.Module):
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)

View File

@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
from .convert import export, validate_model_outputs
from .utils import ParameterFormat, compute_serialized_parameters_size

View File

@ -14,101 +14,22 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs
if is_torch_available():
from transformers import AutoModel, PreTrainedModel
FEATURES_TO_AUTOMODELS = {
"default": AutoModel,
}
# Set of model topologies we support associated to the features supported by each topology and the factory
SUPPORTED_MODEL_KIND = {
"albert": {"default": AlbertOnnxConfig.default},
"bart": {"default": BartOnnxConfig.default},
"bert": {"default": BertOnnxConfig.default},
"distilbert": {"default": DistilBertOnnxConfig.default},
"gpt2": {"default": GPT2OnnxConfig.default},
"longformer": {"default": LongformerOnnxConfig.default},
"roberta": {"default": RobertaOnnxConfig},
"t5": {"default": T5OnnxConfig.default},
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
def get_model_from_features(features: str, model: str):
"""
Attempt to retrieve a model from a model's name and the features to be enabled.
Args:
features: The features required
model: The name of the model to export
Returns:
"""
if features not in FEATURES_TO_AUTOMODELS:
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features
Args:
model: The model to export
features: The name of the features to check if they are available
Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
"""
if model.config.model_type not in SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model.name}) is not supported yet. "
f"Only {SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)
# Look for the features
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
if features not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support features {features}. "
f"Supported values are: {list(model_features.keys())}"
)
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
from .features import FeaturesManager
def main():
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
parser.add_argument(
"--features",
choices=["default"],
"--feature",
choices=list(FeaturesManager.AVAILABLE_FEATURES),
default="default",
help="Export the model with some additional features.",
help="Export the model with some additional feature.",
)
parser.add_argument(
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
@ -127,8 +48,8 @@ def main():
# Allocate the model
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = get_model_from_features(args.features, args.model)
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)
# Ensure the requested opset is sufficient

View File

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
@dataclasses.dataclass
class PatchingSpec:
"""
Data class that holds patching specifications.
Args:
o: Module / object where the op to patch is located
name: Name of the op to monkey patch
custom_op: Custom op that patches the original op
orig_op: Original op that is being patched
op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
It is useful for ops that are class or static methods for instance.
"""
o: Any
name: str
custom_op: Callable
orig_op: Optional[Callable] = None
op_wrapper: Optional[Callable] = None
class OnnxConfig(ABC):
"""
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
@ -34,11 +56,39 @@ class OnnxConfig(ABC):
DEFAULT_FIXED_BATCH = 2
DEFAULT_FIXED_SEQUENCE = 8
def __init__(self, config: PretrainedConfig):
_TASKS_TO_COMMON_OUTPUTS = {
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"question-answering": OrderedDict(
{
"start_logits": {0: "batch", 1: "sequence"},
"end_logits": {0: "batch", 1: "sequence"},
}
),
}
def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
self._config = config
if task not in self._TASKS_TO_COMMON_OUTPUTS:
raise ValueError(
f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
)
self.task = task
self._patching_specs = []
for spec in patching_specs if patching_specs is not None else []:
final_spec = spec
if spec.orig_op is None:
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
self._patching_specs.append(final_spec)
@classmethod
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
"""
Instantiate a OnnxConfig for a specific model
@ -48,7 +98,7 @@ class OnnxConfig(ABC):
Returns:
OnnxConfig for this model
"""
return cls(config)
return cls(config, task=task)
@property
@abstractmethod
@ -62,7 +112,6 @@ class OnnxConfig(ABC):
raise NotImplementedError()
@property
@abstractmethod
def outputs(self) -> Mapping[str, Mapping[int, str]]:
"""
Mapping containing the axis definition of the output tensors to provide to the model
@ -70,7 +119,7 @@ class OnnxConfig(ABC):
Returns:
For each output: its name associated to the axes symbolic name and the axis position within the tensor
"""
raise NotImplementedError()
return self._TASKS_TO_COMMON_OUTPUTS[self.task]
@property
def values_override(self) -> Optional[Mapping[str, Any]]:
@ -170,14 +219,48 @@ class OnnxConfig(ABC):
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
return dict(tokenizer(dummy_input, return_tensors=framework))
def patch_ops(self):
for spec in self._patching_specs:
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
setattr(spec.o, spec.name, custom_op)
def restore_ops(self):
for spec in self._patching_specs:
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
setattr(spec.o, spec.name, orig_op)
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
"""
Flatten any potential nested structure expanding the name of the field with the index of the element within the
structure.
Args:
name: The name of the nested structure
field: The structure to, potentially, be flattened
Returns:
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
"""
from itertools import chain
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
class OnnxConfigWithPast(OnnxConfig, ABC):
def __init__(self, config: PretrainedConfig, use_past: bool = False):
super().__init__(config)
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs)
self.use_past = use_past
@classmethod
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
"""
Instantiate a OnnxConfig with `use_past` attribute set to True
@ -187,7 +270,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
Returns:
OnnxConfig with `.use_past = True`
"""
return cls(config, use_past=True)
return cls(config, task=task, use_past=True)
@property
def values_override(self) -> Optional[Mapping[str, Any]]:
@ -221,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
flatten_output[f"{name}.{idx}.key"] = t[0]
flatten_output[f"{name}.{idx}.value"] = t[1]
return flatten_output
return super().flatten_output_collection_property(name, field)
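As a hedged illustration of the patching machinery above, patching a single op by hand does exactly what patch_ops()/restore_ops() do (the choice of op is hypothetical):

import dataclasses

import torch

from transformers.onnx import PatchingSpec

def custom_relu(x):
    # Hypothetical stand-in op, e.g. to make tracing export-friendly.
    return torch.clamp(x, min=0.0)

spec = PatchingSpec(o=torch, name="relu", custom_op=custom_relu)
# OnnxConfig.__init__ records the original op when orig_op is left unset:
spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))

setattr(spec.o, spec.name, spec.custom_op)  # what patch_ops() does (no op_wrapper here)
assert torch.relu is custom_relu
setattr(spec.o, spec.name, spec.orig_op)    # what restore_ops() does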

View File

@ -21,9 +21,9 @@ import numpy as np
from packaging.version import Version, parse
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from ..file_utils import is_torch_onnx_dict_inputs_support_available
from ..utils import logging
from .config import OnnxConfig
from .utils import flatten_output_collection_property
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -79,11 +79,16 @@ def export(
"""
if not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch
from torch.onnx import export
from ..file_utils import torch_version
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False)
model.config.return_dict = True
@ -105,6 +110,8 @@ def export(
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
config.patch_ops()
# export can work with named args, but the dict containing named args has to be the last element of the args tuple
export(
model,
@ -119,6 +126,8 @@ def export(
opset_version=opset,
)
config.restore_ops()
return matched_inputs, onnx_outputs
@ -134,6 +143,8 @@ def validate_model_outputs(
logger.info("Validating ONNX model...")
# TODO: generate inputs with a different batch_size and seq_len than was used for conversion to properly test
# dynamic input shapes.
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
# Create ONNX Runtime session
@ -146,8 +157,12 @@ def validate_model_outputs(
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
for name, value in ref_outputs.items():
# Overwriting the output name as "present" since it is the name used for the ONNX outputs
# ("past_key_values" being taken for the ONNX inputs)
if name == "past_key_values":
name = "present"
if isinstance(value, (list, tuple)):
value = flatten_output_collection_property(name, value)
value = config.flatten_output_collection_property(name, value)
ref_outputs_dict.update(value)
else:
ref_outputs_dict[name] = value
@ -156,7 +171,7 @@ def validate_model_outputs(
onnx_inputs = {}
for name, value in reference_model_inputs.items():
if isinstance(value, (list, tuple)):
value = flatten_output_collection_property(name, value)
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.numpy()
@ -180,7 +195,7 @@ def validate_model_outputs(
# Check the shape and values match
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
ref_value = ref_outputs_dict[name].numpy()
ref_value = ref_outputs_dict[name].detach().numpy()
logger.info(f'\t- Validating ONNX Model output "{name}":')
# Shape
@ -191,7 +206,7 @@ def validate_model_outputs(
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
)
else:
logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
# Values
if not np.allclose(ref_value, ort_value, atol=atol):

View File

@ -0,0 +1,141 @@
from functools import partial, reduce
from typing import Callable, Tuple
from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig
if is_torch_available():
from transformers import PreTrainedModel
from transformers.models.auto import (
AutoModel,
AutoModelForCausalLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
def supported_features_mapping(*supported_features, onnx_config_cls=None):
"""Generates the mapping between supported features and their corresponding OnnxConfig."""
if onnx_config_cls is None:
raise ValueError("A OnnxConfig class must be provided")
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
task = feature.replace("-with-past", "")
mapping[feature] = partial(onnx_config_cls.with_past, task=task)
else:
mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
return mapping
class FeaturesManager:
_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
"causal-lm": AutoModelForCausalLM,
"seq2seq-lm": AutoModelForSeq2SeqLM,
"sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
}
# Set of model topologies we support associated to the features supported by each topology and the factory
_SUPPORTED_MODEL_KIND = {
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
"t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
),
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
"gpt-neo": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"default-with-past",
"causal-lm-with-past",
"sequence-classification-with-past",
onnx_config_cls=GPTNeoOnnxConfig,
),
}
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
@staticmethod
def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "")
@staticmethod
def get_model_from_feature(feature: str, model: str):
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Args:
feature: The feature required
model: The name of the model to export
Returns:
"""
task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
raise KeyError(
f"Unknown task: {feature}."
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
@staticmethod
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features
Args:
model: The model to export
feature: The name of the feature to check if it is available
Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
"""
model_type = model.config.model_type.replace("_", "-")
model_name = getattr(model, "name", "")
model_name = f"({model_name})" if model_name else ""
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model_name}) is not supported yet. "
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)
# Look for the features
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
if feature not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support feature {feature}. "
f"Supported values are: {list(model_features.keys())}"
)
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
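End to end, the rewritten CLI above boils down to these calls (the model name and feature are just examples):

from transformers import AutoTokenizer
from transformers.onnx.features import FeaturesManager

model_name = "bert-base-cased"  # example checkpoint
feature = "default"             # any entry of FeaturesManager.AVAILABLE_FEATURES

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FeaturesManager.get_model_from_feature(feature, model_name)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = model_onnx_config(model.config)  # factory returned by supported_features_mapping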

View File

@ -14,7 +14,6 @@
from ctypes import c_float, sizeof
from enum import Enum
from typing import Any, Dict, Iterable
class ParameterFormat(Enum):
@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
Size (in byte) taken to save all the parameters
"""
return num_parameters * dtype.size
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
"""
Flatten any potential nested structure expanding the name of the field with the index of the element within the
structure.
Args:
name: The name of the nested structure
field: The structure to, potentially, be flattened
Returns:
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
"""
from itertools import chain
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}

View File

@ -25,6 +25,7 @@ from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Iterator, Union
from unittest import mock
from transformers import logging as transformers_logging
@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
use_tf = os.getenv("USE_TF", False)
"""
return unittest.mock.patch.dict(os.environ, kwargs)
return mock.patch.dict(os.environ, kwargs)
# from https://stackoverflow.com/a/34333710/9201239

View File

@ -364,7 +364,7 @@ class Trainer:
self.tokenizer = tokenizer
if self.place_model_on_device:
model = model.to(args.device)
self._move_model_to_device(model, args.device)
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
if self.is_model_parallel:
@ -505,6 +505,12 @@ class Trainer:
"""
self.callback_handler.remove_callback(callback)
def _move_model_to_device(self, model, device):
model = model.to(device)
# Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
model.tie_weights()
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return dataset
@ -1016,7 +1022,7 @@ class Trainer:
# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if args.fp16_full_eval and not args.do_train:
self.model = self.model.to(args.device)
self._move_model_to_device(self.model, args.device)
if "model_path" in kwargs:
resume_from_checkpoint = kwargs.pop("model_path")
@ -1077,7 +1083,7 @@ class Trainer:
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if self.place_model_on_device:
self.model = self.model.to(args.device)
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model
# Keeping track whether we can can len() on the dataset or not
@ -2515,10 +2521,11 @@ class Trainer:
Returns:
The url of the commit of your model in the given repository.
"""
if not self.args.should_save:
return
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
if self.args.should_save:
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
# Needs to be executed on all processes for TPU training, but will only save on the process determined by
# self.args.should_save.
self.save_model()
# Only push from one node.

View File

@ -1052,6 +1052,8 @@ class TrainingArguments:
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
if is_torch_tpu_available():
xm.rendezvous(desc)
elif is_sagemaker_dp_enabled():
sm_dist.Barrier()
else:
torch.distributed.barrier()
yield
@ -1061,6 +1063,8 @@ class TrainingArguments:
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
if is_torch_tpu_available():
xm.rendezvous(desc)
elif is_sagemaker_dp_enabled():
sm_dist.Barrier()
else:
torch.distributed.barrier()
else:

View File

@ -9,6 +9,8 @@ from transformers import ( # LongformerConfig,; T5Config,
BartConfig,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
is_torch_available,
@ -20,17 +22,21 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
ParameterFormat,
export,
validate_model_outputs,
)
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow
@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
Cover all the utilities involved to export ONNX models
"""
@require_torch
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
"""
Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
"""
self.assertRaises(AssertionError, export, None, None, None, None, None)
mock_is_torch_onnx_dict_inputs_support_available.assert_called()
def test_compute_effective_axis_dimension(self):
"""
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
@ -78,7 +93,7 @@ class OnnxUtilsTestCaseV2(TestCase):
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
"""
self.assertEqual(
flatten_output_collection_property("past_key", [[0], [1], [2]]),
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
{
"past_key.0": 0,
"past_key.1": 1,
@ -136,11 +151,13 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
with self.subTest(name):
self.assertFalse(
OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
OnnxConfigWithPast.from_model_config(config()).use_past,
"OnnxConfigWithPast.from_model_config() should not use_past",
)
self.assertTrue(
OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
OnnxConfigWithPast.with_past(config()).use_past,
"OnnxConfigWithPast.from_model_config() should use_past",
)
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@ -152,7 +169,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
with self.subTest(name):
# without past
onnx_config_default = OnnxConfigWithPast.default(config())
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
self.assertFalse(
@ -175,19 +192,23 @@ if is_torch_available():
BertModel,
DistilBertModel,
GPT2Model,
GPTNeoModel,
MBartModel,
RobertaModel,
XLMRobertaModel,
)
PYTORCH_EXPORT_DEFAULT_MODELS = {
("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig),
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}
@ -210,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "default"))
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class())
onnx_config = onnx_config_class.default(model.config)
model = model_class(config_class.from_pretrained(model))
onnx_config = onnx_config_class.from_model_config(model.config)
with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export(