Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00

Compare commits: v4.44.0 ... onnx_gpt2_ (5 commits)

Commits: db4a4b75f0, f7d63691bb, d909501555, e0063d30ca, b93d5d4531
```diff
@@ -249,25 +249,27 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
 
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]
 
-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()
 
         return ordered_inputs
 
     @property
     def default_onnx_opset(self) -> int:
-        return 13
+        return 11
```
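The same past_key_values layout change is applied to GPT-Neo and GPT-J in the hunks below. As a point of reference, here is a minimal standalone sketch (not part of the diff, with invented sizes) contrasting the old per-layer (key, value) tuples with the fused tensors the patched config generates:

```python
# Sketch only: contrast the old per-layer (key, value) dummy past with the fused
# (2, batch, num_heads, past_len, head_dim) layout the patched config produces.
# All sizes below are illustrative, not taken from any real checkpoint.
import torch

batch, num_heads, head_dim, num_layers, seqlen = 2, 12, 64, 12, 8

# Old dummy past: a (key, value) tuple per layer, past length seqlen + 2.
old_past = [
    (
        torch.zeros(batch, num_heads, seqlen + 2, head_dim),
        torch.zeros(batch, num_heads, seqlen + 2, head_dim),
    )
    for _ in range(num_layers)
]

# New dummy past: one tensor per layer, key and value stacked on a leading axis of
# size 2, and a dummy past length of 1.
new_past = [torch.zeros(2, batch, num_heads, 1, head_dim) for _ in range(num_layers)]

print(old_past[0][0].shape)  # torch.Size([2, 12, 10, 64])
print(new_past[0].shape)     # torch.Size([2, 2, 12, 1, 64])
```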
```diff
@@ -231,12 +231,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
 
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
             tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )
 
         # We need to order the input in the way they appears in the forward()
         ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
 
         # Need to add the past_keys
```
```diff
@@ -248,22 +246,24 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
 
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]
 
-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()
 
         return ordered_inputs
 
```
```diff
@@ -198,24 +198,24 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
 
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]
 
-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
-
-        return ordered_inputs
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()
 
     @property
     def default_onnx_opset(self) -> int:
```
```diff
@@ -101,10 +101,7 @@ class OnnxConfig(ABC):
 
         self._patching_specs = []
         for spec in patching_specs if patching_specs is not None else []:
-            final_spec = spec
-            if spec.orig_op is None:
-                final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
-            self._patching_specs.append(final_spec)
+            self.install_patching_spec(spec)
 
     @classmethod
     def from_model_config(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfig":
```
```diff
@@ -246,6 +243,17 @@ class OnnxConfig(ABC):
             images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
         return images
 
+    def install_patching_spec(self, spec: PatchingSpec):
+        """
+        Append a patching spec after initialization of the configuration
+        :param spec:
+        :return:
+        """
+        if spec.orig_op is None:
+            spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
+
+        self._patching_specs.append(spec)
+
     def generate_dummy_inputs(
         self,
         preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
```
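A hedged usage sketch for the new helper; the GPT-2 config and the patched target (math.sqrt) are arbitrary choices for illustration, not something the diff requires:

```python
# Sketch only: install a PatchingSpec after the config object already exists.
# The patched target (math.sqrt) is arbitrary and chosen just for illustration.
import math

from transformers import GPT2Config
from transformers.models.gpt2.configuration_gpt2 import GPT2OnnxConfig
from transformers.onnx.config import PatchingSpec

onnx_config = GPT2OnnxConfig(GPT2Config(), task="causal-lm", use_past=True)

# orig_op is left as None; install_patching_spec fills it in with getattr(spec.o, spec.name).
spec = PatchingSpec(o=math, name="sqrt", custom_op=lambda x: x ** 0.5)
onnx_config.install_patching_spec(spec)

onnx_config.patch_ops()    # math.sqrt now points at the custom op
onnx_config.restore_ops()  # the original math.sqrt is restored
```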
```diff
@@ -447,8 +455,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
 
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 shape = (
+                    2,
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
```
```diff
@@ -457,12 +466,16 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
 
                 if "attention_mask" in common_inputs:
                     common_inputs["attention_mask"] = torch.cat(
-                        [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                        [
+                            common_inputs["attention_mask"],
+                            torch.ones(batch, seqlen + past_key_values_length, dtype=torch.int64),
+                        ],
+                        dim=1,
                     )
 
                 common_inputs["past_key_values"] = []
                 for _ in range(self.num_layers):
-                    common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+                    common_inputs["past_key_values"].append(torch.zeros(shape, dtype=torch.float32))
 
         return common_inputs
 
```
```diff
@@ -481,12 +494,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
 
         name = "past_key_values" if direction == "inputs" else "present"
         for i in range(self.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+            inputs_or_outputs[f"{name}.{i}"] = {1: "batch", 3: "past_sequence + sequence"}
 
     def _flatten_past_key_values_(self, flattened_output, name, idx, t):
-        flattened_output[f"{name}.{idx}.key"] = t[0]
-        flattened_output[f"{name}.{idx}.value"] = t[1]
+        flattened_output[f"{name}.{idx}"] = t
 
     def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
         flattened_output = {}
```
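To make the renaming concrete, here is a standalone sketch (not taken from the diff) of the dynamic-axes mapping that the patched loop builds for a hypothetical 2-layer model:

```python
# Sketch only: the dynamic-axes entries produced per layer after the change, with one
# fused "past_key_values.{i}" entry instead of separate ".key"/".value" entries.
from collections import OrderedDict

num_layers = 2  # hypothetical model depth
inputs_or_outputs = OrderedDict()

name = "past_key_values"  # direction == "inputs"; "present" would be used for outputs
for i in range(num_layers):
    # Axis 1 is the batch axis and axis 3 the past + current sequence length of the
    # fused (2, batch, num_heads, seq, head_dim) tensor.
    inputs_or_outputs[f"{name}.{i}"] = {1: "batch", 3: "past_sequence + sequence"}

for key, axes in inputs_or_outputs.items():
    print(key, axes)
# past_key_values.0 {1: 'batch', 3: 'past_sequence + sequence'}
# past_key_values.1 {1: 'batch', 3: 'past_sequence + sequence'}
```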
```diff
@@ -29,7 +29,7 @@ from ..utils import (
     is_torch_onnx_dict_inputs_support_available,
     logging,
 )
-from .config import OnnxConfig
+from .config import OnnxConfig, OnnxConfigWithPast
 
 
 if is_torch_available():
```
```diff
@@ -126,6 +126,11 @@ def export_pytorch(
             model.config.return_dict = True
             model.eval()
 
+            if isinstance(config, OnnxConfigWithPast):
+                from transformers.onnx.utils import ort_compatible_forward_with_past_key_values_output
+
+                model.forward = ort_compatible_forward_with_past_key_values_output(model.forward, config.num_layers)
+
             # Check if we need to override certain configuration item
             if config.values_override is not None:
                 logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
```
```diff
@@ -370,7 +375,14 @@ def validate_model_outputs(
     for name, value in reference_model_inputs.items():
         if isinstance(value, (list, tuple)):
             value = config.flatten_output_collection_property(name, value)
-            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
+            onnx_inputs.update(
+                {
+                    tensor_name: np.stack([t.numpy() for t in pt_tensor])
+                    if isinstance(pt_tensor, tuple)
+                    else pt_tensor.numpy()
+                    for tensor_name, pt_tensor in value.items()
+                }
+            )
         else:
             onnx_inputs[name] = value.numpy()
 
```
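As a reference for the new branch above, a small sketch (with invented shapes) of how a flattened (key, value) pair is stacked into the single fused array that the exported graph expects:

```python
# Sketch only: stacking a (key, value) pair into the single fused array expected by
# the exported graph. Shapes are illustrative.
import numpy as np
import torch

batch, num_heads, seq, head_dim = 2, 12, 5, 64
key = torch.rand(batch, num_heads, seq, head_dim)
value = torch.rand(batch, num_heads, seq, head_dim)

pt_tensor = (key, value)  # one entry of the flattened past_key_values collection

# np.stack over the pair adds the leading size-2 axis of the fused layout.
fused = np.stack([t.numpy() for t in pt_tensor])
print(fused.shape)  # (2, 2, 12, 5, 64)
```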
```diff
@@ -14,6 +14,9 @@
 
 from ctypes import c_float, sizeof
 from enum import Enum
+from typing import Iterable
+
+from transformers import is_torch_available
 
 
 class ParameterFormat(Enum):
```
```diff
@@ -61,3 +64,39 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
         Size (in byte) taken to save all the parameters
     """
     return num_parameters * dtype.size
+
+
+if is_torch_available():
+    import torch
+
+    def ort_compatible_forward_with_past_key_values_output(forward, num_layers):
+        import functools
+
+        if isinstance(num_layers, Iterable):
+            num_layers = sum(num_layers)
+
+        @functools.wraps(forward)
+        def compatible_forward(*args, **kwargs):
+            result = forward(*args, **kwargs)
+
+            if "past_key_values" in result:
+                if isinstance(result["past_key_values"][0], tuple) or isinstance(result["past_key_values"][0], list):
+                    assert len(result["past_key_values"]) == num_layers and len(result["past_key_values"][0]) == 2
+                    present = []
+                    for i in range(num_layers):
+                        # Since transformers v4.*, past key and values are separated outputs.
+                        # Here we concatenate them into one tensor to be compatible with Attention operator.
+                        present.append(
+                            torch.cat(
+                                (
+                                    result["past_key_values"][i][0].unsqueeze(0),
+                                    result["past_key_values"][i][1].unsqueeze(0),
+                                ),
+                                dim=0,
+                            )
+                        )
+                    return {"logits": result["logits"], "past_key_values": tuple(present)}
+                else:
+                    return result
+
+        return compatible_forward
```
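A possible usage sketch for this wrapper; the checkpoint name, the Auto classes, and the printed shape are illustrative assumptions rather than anything the diff prescribes:

```python
# Sketch only: wrap a GPT-2 forward so each returned "past_key_values" entry is a single
# fused tensor of shape (2, batch, num_heads, seq_len, head_dim) instead of a (key, value) pair.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.onnx.utils import ort_compatible_forward_with_past_key_values_output

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.return_dict = True
model.eval()

# The same wiring export_pytorch now performs when the config is an OnnxConfigWithPast.
model.forward = ort_compatible_forward_with_past_key_values_output(
    model.forward, model.config.n_layer
)

inputs = tokenizer("Hello there", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

# Each layer's past is now fused: (2, batch, num_heads, seq_len, head_dim).
print(outputs["past_key_values"][0].shape)
```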
```diff
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from unittest import TestCase
```
```diff
@@ -267,7 +268,11 @@ class OnnxExportTestCaseV2(TestCase):
         else:
             raise ValueError(f"Unsupported model input name: {model.main_input_name}")
 
-        with NamedTemporaryFile("w") as output:
+        # "w" mode on Windows opens the file, so torch.onnx.export will crash because the file is already open while
+        # attempting to open the file, resulting in a PermissionDenied
+        # This workaround makes sure the file is not deleted on close and handle the file's lifetime by hand.
+        with NamedTemporaryFile("w", delete=False) as output:
+            output.close()
             try:
                 onnx_inputs, onnx_outputs = export(
                     preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
```
```diff
@@ -280,6 +285,7 @@
                     onnx_outputs,
                     onnx_config.atol_for_validation,
                 )
+                os.unlink(output.name)
             except (RuntimeError, ValueError) as e:
                 self.fail(f"{name}, {feature} -> {e}")
 
```
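Shown in isolation, the temporary-file pattern the test now relies on; the written bytes stand in for the real ONNX export:

```python
# Sketch only: Windows-friendly temporary-file handling. Create the file with
# delete=False, close the handle right away so another writer may reopen the path,
# and remove the file by hand once done.
import os
from tempfile import NamedTemporaryFile

with NamedTemporaryFile("w", delete=False) as output:
    output.close()
    try:
        # Stand-in for torch.onnx.export(..., f=output.name) used by the real test.
        with open(output.name, "wb") as f:
            f.write(b"fake onnx bytes")
        os.unlink(output.name)
    except (RuntimeError, ValueError) as e:
        raise AssertionError(f"export failed -> {e}")
```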