Compare commits

...

5 Commits

7 changed files with 110 additions and 40 deletions

View File

@@ -249,25 +249,27 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]

-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()

         return ordered_inputs

     @property
     def default_onnx_opset(self) -> int:
-        return 13
+        return 11
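
To make the fused layout concrete, here is a minimal sketch (not part of the diff; batch size, head count, and hidden size are made-up example values) of the dummy past tensors the new code produces. Keys and values are packed into a single 5-D tensor per layer, with the leading dimension of size 2 holding the key/value pair, which matches the past format ONNX Runtime's fused Attention operator expects.

import torch

# Example values only; the real ones come from the model config.
batch, num_attention_heads, hidden_size, num_layers = 2, 12, 768, 12
past_key_values_length = 1  # fixed dummy past length used by the new code

past_shape = (
    2,  # key AND value, fused along the leading dimension
    batch,
    num_attention_heads,
    past_key_values_length,
    hidden_size // num_attention_heads,
)
# One 5-D tensor per layer instead of a (key, value) tuple of two 4-D tensors.
past_key_values = [torch.zeros(past_shape, dtype=torch.float32) for _ in range(num_layers)]
assert past_key_values[0].shape == (2, 2, 12, 1, 64)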

View File

@@ -231,12 +231,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
             tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )

         # We need to order the inputs in the way they appear in forward()
         ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

         # Need to add the past_keys
@@ -248,22 +246,24 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]

-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()

         return ordered_inputs

View File

@@ -198,24 +198,24 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 past_shape = (
+                    2,  # Key AND Values
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
                     self._config.hidden_size // self.num_attention_heads,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    torch.zeros(past_shape, dtype=torch.float32) for _ in range(self.num_layers)
                 ]

-        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.ones(
+                batch, seqlen + past_key_values_length, dtype=torch.int64
+            )
+        else:
+            ordered_inputs["attention_mask"] = common_inputs["attention_mask"].long()

         return ordered_inputs

     @property
     def default_onnx_opset(self) -> int:

View File

@@ -101,10 +101,7 @@ class OnnxConfig(ABC):
         self._patching_specs = []
         for spec in patching_specs if patching_specs is not None else []:
-            final_spec = spec
-            if spec.orig_op is None:
-                final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
-            self._patching_specs.append(final_spec)
+            self.install_patching_spec(spec)

     @classmethod
     def from_model_config(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfig":
@@ -246,6 +243,17 @@ class OnnxConfig(ABC):
             images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
         return images

+    def install_patching_spec(self, spec: PatchingSpec):
+        """
+        Append a patching spec after the configuration has been initialized.
+
+        :param spec: the patching spec to register. If `spec.orig_op` is unset, it is resolved
+            from `spec.o` and `spec.name` before being stored.
+        """
+        if spec.orig_op is None:
+            spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
+        self._patching_specs.append(spec)
+
     def generate_dummy_inputs(
         self,
         preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
@@ -447,8 +455,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
+                past_key_values_length = 1
                 shape = (
+                    2,
                     batch,
                     self.num_attention_heads,
                     past_key_values_length,
@@ -457,12 +466,16 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
                 if "attention_mask" in common_inputs:
                     common_inputs["attention_mask"] = torch.cat(
-                        [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                        [
+                            common_inputs["attention_mask"],
+                            torch.ones(batch, past_key_values_length, dtype=torch.int64),
+                        ],
+                        dim=1,
                     )

                 common_inputs["past_key_values"] = []
                 for _ in range(self.num_layers):
-                    common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+                    common_inputs["past_key_values"].append(torch.zeros(shape, dtype=torch.float32))
         return common_inputs
@@ -481,12 +494,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         name = "past_key_values" if direction == "inputs" else "present"
         for i in range(self.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+            inputs_or_outputs[f"{name}.{i}"] = {1: "batch", 3: "past_sequence + sequence"}

     def _flatten_past_key_values_(self, flattened_output, name, idx, t):
-        flattened_output[f"{name}.{idx}.key"] = t[0]
-        flattened_output[f"{name}.{idx}.value"] = t[1]
+        flattened_output[f"{name}.{idx}"] = t

     def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
         flattened_output = {}
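
To illustrate the new bookkeeping (a sketch, not part of the diff, for a hypothetical 2-layer model): each layer now contributes a single ONNX input `past_key_values.{i}` rather than a `.key`/`.value` pair, and because axis 0 of the fused tensor is the key/value dimension, the dynamic batch and sequence axes move to positions 1 and 3.

from collections import OrderedDict

inputs_or_outputs = OrderedDict()
name = "past_key_values"  # direction == "inputs"
for i in range(2):  # hypothetical num_layers = 2
    # Axis 0 is the fused key/value pair, so batch sits at axis 1 and the
    # accumulated (past + current) sequence length at axis 3.
    inputs_or_outputs[f"{name}.{i}"] = {1: "batch", 3: "past_sequence + sequence"}

print(dict(inputs_or_outputs))
# {'past_key_values.0': {1: 'batch', 3: 'past_sequence + sequence'},
#  'past_key_values.1': {1: 'batch', 3: 'past_sequence + sequence'}}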

View File

@@ -29,7 +29,7 @@ from ..utils import (
     is_torch_onnx_dict_inputs_support_available,
     logging,
 )
-from .config import OnnxConfig
+from .config import OnnxConfig, OnnxConfigWithPast


 if is_torch_available():
@@ -126,6 +126,11 @@ def export_pytorch(
             model.config.return_dict = True
             model.eval()

+            if isinstance(config, OnnxConfigWithPast):
+                from transformers.onnx.utils import ort_compatible_forward_with_past_key_values_output
+
+                model.forward = ort_compatible_forward_with_past_key_values_output(model.forward, config.num_layers)
+
             # Check if we need to override certain configuration item
             if config.values_override is not None:
                 logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
@@ -370,7 +375,14 @@ def validate_model_outputs(
     for name, value in reference_model_inputs.items():
         if isinstance(value, (list, tuple)):
             value = config.flatten_output_collection_property(name, value)
-            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
+            onnx_inputs.update(
+                {
+                    tensor_name: np.stack([t.numpy() for t in pt_tensor])
+                    if isinstance(pt_tensor, tuple)
+                    else pt_tensor.numpy()
+                    for tensor_name, pt_tensor in value.items()
+                }
+            )
         else:
             onnx_inputs[name] = value.numpy()
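
The `np.stack` branch above handles reference inputs that still carry one `(key, value)` tuple per layer: stacking along a new leading axis reproduces the fused 5-D layout the exported graph now expects. A standalone sketch with made-up shapes:

import numpy as np
import torch

# Hypothetical per-layer (key, value) pair, each (batch, num_heads, past_len, head_dim).
key = torch.zeros(2, 12, 1, 64)
value = torch.zeros(2, 12, 1, 64)

fused = np.stack([t.numpy() for t in (key, value)])
assert fused.shape == (2, 2, 12, 1, 64)  # (2, batch, num_heads, past_len, head_dim)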

View File

@@ -14,6 +14,9 @@
 from ctypes import c_float, sizeof
 from enum import Enum
+from typing import Iterable
+
+from transformers import is_torch_available


 class ParameterFormat(Enum):
@@ -61,3 +64,39 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int:
         Size (in bytes) taken to save all the parameters
     """
     return num_parameters * dtype.size
+
+
+if is_torch_available():
+    import torch
+
+    def ort_compatible_forward_with_past_key_values_output(forward, num_layers):
+        import functools
+
+        if isinstance(num_layers, Iterable):
+            num_layers = sum(num_layers)
+
+        @functools.wraps(forward)
+        def compatible_forward(*args, **kwargs):
+            result = forward(*args, **kwargs)
+            if "past_key_values" in result and isinstance(result["past_key_values"][0], (tuple, list)):
+                assert len(result["past_key_values"]) == num_layers and len(result["past_key_values"][0]) == 2
+                present = []
+                for i in range(num_layers):
+                    # Since transformers v4.*, past keys and values are separate outputs.
+                    # Here we concatenate them into one tensor to be compatible with the Attention operator.
+                    present.append(
+                        torch.cat(
+                            (
+                                result["past_key_values"][i][0].unsqueeze(0),
+                                result["past_key_values"][i][1].unsqueeze(0),
+                            ),
+                            dim=0,
+                        )
+                    )
+                return {"logits": result["logits"], "past_key_values": tuple(present)}
+            return result
+
+        return compatible_forward
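
A sketch of how the wrapper is meant to be applied, mirroring the `export_pytorch` change above (the `gpt2` checkpoint and the toy input are placeholders):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.onnx.utils import ort_compatible_forward_with_past_key_values_output

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# Wrap forward so each "present" output becomes one fused tensor instead of a
# (key, value) tuple; GPT2Config exposes the layer count as n_layer.
model.forward = ort_compatible_forward_with_past_key_values_output(model.forward, model.config.n_layer)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True, return_dict=True)

# Each layer's past is now a single (2, batch, num_heads, seq_len, head_dim) tensor.
print(outputs["past_key_values"][0].shape)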

View File

@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from unittest import TestCase
@@ -267,7 +268,11 @@ class OnnxExportTestCaseV2(TestCase):
         else:
             raise ValueError(f"Unsupported model input name: {model.main_input_name}")

-        with NamedTemporaryFile("w") as output:
+        # On Windows, NamedTemporaryFile("w") keeps the file open, so torch.onnx.export crashes with a
+        # PermissionError when it tries to open the file a second time.
+        # This workaround makes sure the file is not deleted on close and handles the file's lifetime by hand.
+        with NamedTemporaryFile("w", delete=False) as output:
+            output.close()
             try:
                 onnx_inputs, onnx_outputs = export(
                     preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
@@ -280,6 +285,7 @@ class OnnxExportTestCaseV2(TestCase):
                     onnx_outputs,
                     onnx_config.atol_for_validation,
                 )
+                os.unlink(output.name)
             except (RuntimeError, ValueError) as e:
                 self.fail(f"{name}, {feature} -> {e}")
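
For reference, the same Windows-safe temporary-file pattern in isolation (a sketch; `run_export` is a hypothetical stand-in for the real export call):

import os
from tempfile import NamedTemporaryFile

def run_export(path: str) -> None:
    # Hypothetical stand-in for torch.onnx.export writing the model file.
    with open(path, "wb") as f:
        f.write(b"placeholder")

# delete=False keeps the file on disk after close, so the handle can be released
# before the exporter reopens the path (Windows forbids the second open otherwise).
with NamedTemporaryFile("w", delete=False) as output:
    output.close()
    try:
        run_export(output.name)
    finally:
        os.unlink(output.name)  # manual cleanup, since delete=False disabled it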