Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: 19 commits, a4a46e62a5 ... gptneo_gpt
Commit SHA1s: 188c8cba91, 1c085e68fd, 1af22f9cdc, c126cb0734, fc6463deaf, 6b8975d233, 1f45e5478b, ba78a24cea, a59d28448a, cf4fbad413, 659ac7f3af, 846c9d14c1, 300f9c1fb5, c99b81163a, 90cf3c9a11, 311dbc86e2, f56d8d97a6, f6cfcfcfde, 5d2263e788
@@ -325,7 +325,7 @@ Flax), PyTorch, and/or TensorFlow.
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
| GIT | ❌ | ❌ | ✅ | ❌ | ❌ |
| GLPN | ❌ | ❌ | ✅ | ❌ | ❌ |
- | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
+ | GPT Neo | ❌ | ❌ | ✅ | ✅ | ✅ |
| GPT NeoX | ❌ | ✅ | ✅ | ❌ | ❌ |
| GPT NeoX Japanese | ✅ | ❌ | ✅ | ❌ | ❌ |
| GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ |
@@ -3181,6 +3181,17 @@ else:
            "TFGPT2PreTrainedModel",
        ]
    )
    _import_structure["models.gpt_neo"].extend(
        [
            "TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
            "TFGPTNeoForCausalLM",
            "TFGPTNeoForQuestionAnswering",
            "TFGPTNeoForSequenceClassification",
            "TFGPTNeoForTokenClassification",
            "TFGPTNeoModel",
            "TFGPTNeoPreTrainedModel",
        ]
    )
    _import_structure["models.gptj"].extend(
        [
            "TFGPTJForCausalLM",

@@ -6466,6 +6477,15 @@ if TYPE_CHECKING:
            TFGPT2Model,
            TFGPT2PreTrainedModel,
        )
        from .models.gpt_neo import (
            TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFGPTNeoForCausalLM,
            TFGPTNeoForQuestionAnswering,
            TFGPTNeoForSequenceClassification,
            TFGPTNeoForTokenClassification,
            TFGPTNeoModel,
            TFGPTNeoPreTrainedModel,
        )
        from .models.gptj import (
            TFGPTJForCausalLM,
            TFGPTJForQuestionAnswering,
@@ -41,6 +41,7 @@ from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
    DUMMY_INPUTS,
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_INDEX_NAME,
@@ -1117,9 +1118,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
-       return {
-           "input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32),
-       }
+       dummy_inputs = {}
+
+       serving_sig = self.get_serving_input_signature()
+       if self.main_input_name == "input_ids" and serving_sig[0]["input_ids"].shape.rank == 2:
+           dummy_inputs["input_ids"] = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
+       elif self.main_input_name == "input_ids" and serving_sig[0]["input_ids"].shape.rank == 3:
+           dummy_inputs["input_ids"] = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32)
+       elif self.main_input_name == "pixel_values":
+           image_shape = serving_sig[0]["pixel_values"].shape.as_list()
+           if image_shape[0] is None:
+               image_shape[0] = 3  # matches DUMMY_INPUTS
+           if None in image_shape[1:]:
+               raise NotImplementedError(
+                   f"Could not fully infer input tensor shape; dummy inputs or serving sig must be defined manually for {self.__class__.__name__}"
+               )
+           rng = np.random.default_rng(42)
+           VISION_DUMMY_INPUTS = rng.random(image_shape).astype(np.float32)
+           dummy_inputs["pixel_values"] = tf.constant(VISION_DUMMY_INPUTS, dtype=tf.float32)
+       else:
+           raise NotImplementedError(
+               f"Could not fully infer input shapes, dummy inputs must be defined manually for {self.__class__.__name__}"
+           )
+       return dummy_inputs

    @property
    def framework(self) -> str:
@@ -1140,6 +1161,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        self.config = config
        self.name_or_path = config.name_or_path
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        if not hasattr(self, "serving"):  # Don't overwrite existing serving signatures
            self.serving = tf.function(self.eager_serving, input_signature=self.get_serving_input_signature())
        # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
        self._set_save_spec(self.serving.input_signature[0])
@@ -1182,12 +1205,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
-       if head_mask.shape.rank == 1:
+       if len(head_mask.shape) == 1:
            head_mask = head_mask[None, None, :, None, None]
            head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
-       elif head_mask.shape.rank == 2:
+       elif len(head_mask.shape) == 2:
            head_mask = head_mask[:, None, :, None, None]
-       assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+       assert len(head_mask.shape) == 5, f"head_mask.dim != 5, instead {len(head_mask.shape)}"
        head_mask = tf.cast(head_mask, tf.float32)  # switch to float if need + fp16 compatibility
        return head_mask
@@ -1204,36 +1227,50 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

        return self.serving_output(output)

-   @tf.function(
-       input_signature=[
-           {
-               "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
-               "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
-               "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
-           }
-       ]
-   )
-   def serving(self, inputs):
-       """
-       Method used for serving the model.
-
-       Args:
-           inputs (`Dict[str, tf.Tensor]`):
-               The input of the saved model as a dictionary of tensors.
-       """
-       output = self.call(inputs)
-
-       return self.serving_output(output)
+   def get_serving_input_signature(self) -> List[Dict[str, tf.TensorSpec]]:
+       model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
+       sig = {}
+       if self.__class__.__name__.endswith("ForMultipleChoice"):
+           text_dims = 3
+       else:
+           text_dims = 2
+       if "input_ids" in model_inputs:
+           for input_name in ("input_ids", "attention_mask", "token_type_ids"):
+               if input_name in model_inputs:
+                   sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
+       if "pixel_values" in model_inputs:
+           pixel_values_shape = [None, None, None, None]
+           if hasattr(self.config, "vision_config"):
+               vision_config = self.config.vision_config
+           else:
+               vision_config = self.config
+           if hasattr(vision_config, "num_channels"):
+               pixel_values_shape[1] = vision_config.num_channels
+           if hasattr(vision_config, "image_size"):
+               pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
+           sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
+       return [sig]

    def serving_output(self, output):
        """
-       Prepare the output of the saved model. Each model must implement this function.
-
-       Args:
-           output ([`TFBaseModelOutput`]):
-               The output returned by the model.
+       Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
        """
-       raise NotImplementedError
+       config_variables = {
+           "hidden_states": "output_hidden_states",
+           "attentions": "output_attentions",
+           "past_key_values": "use_cache",
+       }
+       if isinstance(output, ModelOutput):
+           for key, config_var in config_variables.items():
+               if key in output:
+                   if not getattr(self.config, config_var, False):
+                       output[key] = None
+                   elif output[key] is not None:
+                       try:
+                           output[key] = tf.convert_to_tensor(output[key])
+                       except ValueError:
+                           pass  # Layers may not have the same dimensions
+       return output

    def can_generate(self) -> bool:
        """
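Side note on the refactor above: `dummy_inputs` now derives its shapes from `get_serving_input_signature()` instead of hard-coding them. The following standalone sketch mirrors that idea on a toy Keras model; the `ToyTextModel` class and the dummy token values are illustrative, not part of the diff:

```python
import inspect

import tensorflow as tf

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]  # illustrative token ids


class ToyTextModel(tf.keras.Model):
    main_input_name = "input_ids"

    def call(self, input_ids, attention_mask=None):
        return tf.reduce_sum(input_ids, axis=-1)

    def get_serving_input_signature(self):
        # One 2-D int32 TensorSpec per tensor argument that `call` actually accepts.
        model_inputs = list(inspect.signature(self.call).parameters)
        return [
            {
                name: tf.TensorSpec([None, None], tf.int32, name=name)
                for name in ("input_ids", "attention_mask")
                if name in model_inputs
            }
        ]

    @property
    def dummy_inputs(self):
        # Rank-2 `input_ids` signature -> plain text dummy batch, as in the diff above.
        sig = self.get_serving_input_signature()[0]
        if sig[self.main_input_name].shape.rank == 2:
            return {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32)}
        raise NotImplementedError


print(ToyTextModel().dummy_inputs["input_ids"].shape)  # (3, 5)
```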
@@ -54,6 +54,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("gpt_neo", "TFGPTNeoModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),

@@ -173,6 +174,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("gpt_neo", "TFGPTNeoForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),

@@ -306,6 +308,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("gpt_neo", "TFGPTNeoForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),

@@ -338,6 +341,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("gpt_neo", "TFGPTNeoForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),

@@ -381,6 +385,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("gpt_neo", "TFGPTNeoForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
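With the five auto-mapping entries above in place, the TF classes become reachable through the auto API. A hedged usage sketch; it assumes this branch is installed and converts the published PyTorch checkpoint on the fly, since no TF weights exist for GPT Neo yet:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
# from_pt=True loads the PyTorch weights and converts them to TF variables.
model = TFAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m", from_pt=True)

inputs = tokenizer("Hello, my name is", return_tensors="tf")
outputs = model(**inputs)
print(outputs.logits.shape)  # (1, sequence_length, vocab_size)
```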
@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+)


_import_structure = {

@@ -37,6 +43,23 @@ else:
        "load_tf_weights_in_gpt_neo",
    ]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_gpt_neo"] = [
        "TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFGPTNeoForCausalLM",
        "TFGPTNeoForQuestionAnswering",
        "TFGPTNeoForSequenceClassification",
        "TFGPTNeoForTokenClassification",
        "TFGPTNeoModel",
        "TFGPTNeoPreTrainedModel",
    ]


try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()

@@ -70,6 +93,22 @@ if TYPE_CHECKING:
            load_tf_weights_in_gpt_neo,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_gpt_neo import (
            TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFGPTNeoForCausalLM,
            TFGPTNeoForQuestionAnswering,
            TFGPTNeoForSequenceClassification,
            TFGPTNeoForTokenClassification,
            TFGPTNeoModel,
            TFGPTNeoPreTrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
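The `is_tf_available()` guards above mean the TF classes are only registered when TensorFlow can actually be imported; downstream code can rely on the same check. A minimal sketch (the fallback logic is hypothetical, not part of the diff):

```python
from transformers.utils import is_tf_available

if is_tf_available():
    from transformers import TFGPTNeoModel as GPTNeoBackbone  # resolved lazily via _import_structure
else:
    from transformers import GPTNeoModel as GPTNeoBackbone  # fall back to the PyTorch implementation
```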
@@ -892,7 +892,6 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
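The PyTorch line above pools the last non-padding position with fancy indexing; the TF port further down does the same with `tf.gather_nd`. A quick standalone check of that equivalence (toy values, not from the diff):

```python
import tensorflow as tf

logits = tf.reshape(tf.range(24, dtype=tf.float32), (2, 3, 4))  # (batch_size, seq_len, num_labels)
sequence_lengths = tf.constant([2, 0])  # index of the last non-padding token per row

# TF equivalent of `logits[torch.arange(batch_size), sequence_lengths]`
pooled_logits = tf.gather_nd(logits, tf.stack([tf.range(2), sequence_lengths], axis=1))
print(pooled_logits.shape)  # (2, 4): row 0 -> logits[0, 2], row 1 -> logits[1, 0]
```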
src/transformers/models/gpt_neo/modeling_tf_gpt_neo.py (new file, 959 lines)
@@ -0,0 +1,959 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch GPT Neo model."""
|
||||
|
||||
# TODO Implement weight tying, and see if there's a generic solution (because this will come up in user PRs too)
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import ACT2FN
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutputWithPast,
|
||||
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||
TFCausalLMOutputWithCrossAttentions,
|
||||
TFCausalLMOutputWithPast,
|
||||
TFQuestionAnsweringModelOutput,
|
||||
TFSequenceClassifierOutputWithPast,
|
||||
TFTokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    unpack_inputs,
)
|
||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from ...tf_utils import shape_list
|
||||
from .configuration_gpt_neo import GPTNeoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "GPTNeoConfig"
|
||||
|
||||
TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"EleutherAI/gpt-neo-1.3B",
|
||||
# See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
|
||||
]
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
|
||||
|
||||
class TFGPTNeoSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, attention_type, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
bias = tf.linalg.band_part(tf.ones((max_positions, max_positions), dtype=tf.bool), -1, 0)
|
||||
bias = tf.reshape(bias, (1, 1, max_positions, max_positions))
|
||||
|
||||
if attention_type == "local":
|
||||
bias = tf.math.logical_xor(bias, tf.linalg.band_part(bias, -config.window_size, 0))
|
||||
|
||||
self.bias = tf.constant(bias)
|
||||
self.masked_bias = tf.constant(-1e9)
|
||||
|
||||
self.attn_dropout = tf.keras.layers.Dropout(float(config.attention_dropout))
|
||||
self.resid_dropout = tf.keras.layers.Dropout(float(config.resid_dropout))
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
self.k_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="k_proj")
|
||||
self.v_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="v_proj")
|
||||
self.q_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="q_proj")
|
||||
self.out_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=True, name="out_proj")
|
||||
|
||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||
new_shape = tensor.shape[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tf.reshape(tensor, new_shape)
|
||||
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
||||
|
||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||
tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
|
||||
new_shape = tensor.shape[:-2] + (num_heads * attn_head_size,)
|
||||
return tf.reshape(tensor, new_shape)
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
query = tf.cast(query, tf.float32)
|
||||
key = tf.cast(key, tf.float32)
|
||||
|
||||
attn_weights = tf.matmul(query, key, transpose_b=True)
|
||||
|
||||
query_length, key_length = query.shape[-2], key.shape[-2]
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
mask_value = tf.float32.min
|
||||
mask_value = tf.constant(mask_value, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
attn_weights = tf.cast(attn_weights, value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = tf.matmul(attn_weights, value)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_past=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
training=None,
|
||||
):
|
||||
query = self.q_proj(hidden_states)
|
||||
key = self.k_proj(hidden_states)
|
||||
value = self.v_proj(hidden_states)
|
||||
|
||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = tf.concat((past_key, key), axis=-2)
|
||||
value = tf.concat((past_value, value), axis=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key, value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output, training=training)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
|
||||
class TFGPTNeoAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id=0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.layer_id = layer_id
|
||||
self.attention_layers = config.attention_layers
|
||||
self.attention_type = self.attention_layers[layer_id]
|
||||
|
||||
if self.attention_type in ["global", "local"]:
|
||||
self.attention = TFGPTNeoSelfAttention(config, self.attention_type, name="attention")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
|
||||
f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
return self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_past=layer_past,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
||||
class TFGPTNeoMLP(tf.keras.layers.Layer):
|
||||
def __init__(self, intermediate_size, config, **kwargs): # in MLP: intermediate_size= 4 * hidden_size
|
||||
super().__init__(**kwargs)
|
||||
embed_dim = config.hidden_size
|
||||
self.c_fc = tf.keras.layers.Dense(intermediate_size, name="c_fc")
|
||||
self.c_proj = tf.keras.layers.Dense(embed_dim, name="c_proj")
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.dropout = tf.keras.layers.Dropout(float(config.resid_dropout))
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFGPTNeoBlock(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
||||
self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
|
||||
self.attn = TFGPTNeoAttention(config, layer_id, name="attn")
|
||||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
||||
self.mlp = TFGPTNeoMLP(inner_dim, config, name="mlp")
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
training=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_outputs = self.attn(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states, training=training)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class TFGPTNeoPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = GPTNeoConfig
|
||||
base_model_prefix = "transformer"
|
||||
_no_split_modules = ["TFGPTNeoBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
|
||||
GPT_NEO_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
|
||||
subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
GPT_NEO_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`tf.Tensor` of shape `(batch_size, input_ids_length)`):
|
||||
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
||||
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
||||
sequence tokens in the vocabulary.
|
||||
|
||||
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
||||
`input_ids`.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.num_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
||||
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
token_type_ids (`tf.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
||||
1]`:
|
||||
|
||||
- 0 corresponds to a *sentence A* token,
|
||||
- 1 corresponds to a *sentence B* token.
|
||||
|
||||
[What are token type IDs?](../glossary#token-type-ids)
|
||||
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
||||
`past_key_values`).
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class TFGPTNeoModel(TFGPTNeoPreTrainedModel):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.wte = tf.keras.layers.Embedding(config.vocab_size, self.embed_dim, name="wte")
|
||||
self.wpe = tf.keras.layers.Embedding(config.max_position_embeddings, self.embed_dim, name="wpe")
|
||||
self.drop = tf.keras.layers.Dropout(float(config.embed_dropout))
|
||||
self.h = [TFGPTNeoBlock(config, layer_id=i, name=f"h_._{i}") for i in range(config.num_layers)]
|
||||
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.wte = new_embeddings
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[tf.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[tf.Tensor]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
position_ids: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: Optional[bool] = None,
|
||||
) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
|
||||
batch_size = input_shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
batch_size = input_shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, (-1, input_shape[-1]))
|
||||
if position_ids is not None:
|
||||
position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = [None] * len(self.h)
|
||||
else:
|
||||
past_length = past_key_values[0][0].shape[-2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)
|
||||
position_ids = tf.expand_dims(position_ids, 0)
|
||||
position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = tf.reshape(attention_mask, (batch_size, -1))
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = tf.cast(attention_mask, dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * tf.float32.min
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.config.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states, training=training)
|
||||
|
||||
output_shape = input_shape + [hidden_states.shape[-1]]
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class TFGPTNeoForCausalLM(TFGPTNeoPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"h\.\d+\.attn\.masked_bias",
|
||||
r"lm_head.weight",
|
||||
r"h\.\d+\.attn\.attention\.bias",
|
||||
]
|
||||
_keys_to_ignore_on_save = [r"lm_head.weight"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
|
||||
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFGPTNeoModel(config, name="transformer")
|
||||
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = tf.math.cumsum(tf.cast(attention_mask, tf.int32), axis=-1) - 1
|
||||
position_ids = tf.where(attention_mask == 0, 1, position_ids)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1:]
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[tf.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[tf.Tensor]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
position_ids: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
labels: Optional[tf.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[tf.Tensor], TFCausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
loss = self.hf_compute_loss(shift_labels, shift_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(
|
||||
past_key_values: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor
|
||||
) -> Tuple[Tuple[tf.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
|
||||
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||
beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(tf.gather(past_state, beam_idx) for past_state in layer_past)
|
||||
for layer_past in past_key_values
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPTNeo Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`TFGPTNeoForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-1) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class TFGPTNeoForSequenceClassification(TFGPTNeoPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
|
||||
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = TFGPTNeoModel(config, name="transformer")
|
||||
self.score = tf.keras.layers.Dense(self.num_labels, use_bias=False, name="score")
|
||||
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFSequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[tf.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[tf.Tensor]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
position_ids: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
labels: Optional[tf.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple[tf.Tensor], TFSequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(input_ids)[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs_embeds)[:2]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = tf.fill(dims=(batch_size,), value=-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = tf.math.count_nonzero(input_ids != self.config.pad_token_id, axis=1, dtype=tf.int32) - 1
|
||||
else:
|
||||
sequence_lengths = tf.fill(dims=(batch_size,), value=-1)
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = tf.gather_nd(logits, tf.stack([tf.range(batch_size), sequence_lengths], axis=1))
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = tf.keras.losses.MeanSquaredError()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
loss = loss_fct(labels, pooled_logits)
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
||||
loss = loss_fct(labels, pooled_logits)
|
||||
if loss.shape.rank == 0:
|
||||
loss = tf.expand_dims(loss, 0)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFSequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
GPT Neo model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
||||
Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class TFGPTNeoForTokenClassification(TFGPTNeoPreTrainedModel, TFTokenClassificationLoss):
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFGPTNeoModel(config, name="transformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout, name="dropout")
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
|
||||
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint="EleutherAI/gpt-neo-125m",
|
||||
output_type=TFTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_loss=0.25,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[tf.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
position_ids: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
labels: Optional[tf.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: Optional[bool] = None,
|
||||
) -> Union[Tuple, TFTokenClassifierOutput]:
|
||||
r"""
|
||||
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFTokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPT-Neo Model transformer with a span classification head on top for extractive question-answering tasks like
|
||||
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class TFGPTNeoForQuestionAnswering(TFGPTNeoPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
|
||||
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = TFGPTNeoModel(config, name="transformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(2, name="qa_outputs")
|
||||
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
real_checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[tf.Tensor] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
position_ids: Optional[tf.Tensor] = None,
|
||||
head_mask: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
start_positions: Optional[tf.Tensor] = None,
|
||||
end_positions: Optional[tf.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, TFQuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
loss = self.hf_compute_loss({"start_position": start_positions, "end_position": end_positions}, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFQuestionAnsweringModelOutput(
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
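Once this file is importable, the causal-LM head can be exercised end to end. A hedged sketch: it assumes the branch is installed and, as above, uses `from_pt=True` because only PyTorch weights are published for the GPT Neo checkpoints:

```python
from transformers import GPT2Tokenizer, TFGPTNeoForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = TFGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m", from_pt=True)

inputs = tokenizer("GPT Neo in TensorFlow:", return_tensors="tf")
generated = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```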
@@ -1366,6 +1366,51 @@ class TFGPT2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFGPTNeoForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTNeoForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTNeoForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTNeoForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTNeoModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTNeoPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGPTJForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
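These dummy objects keep `from transformers import TFGPTNeo...` from failing at import time in a TensorFlow-free environment; the backend error only surfaces when a class is actually used. A generic sketch of the pattern (a simplified stand-in for the real `DummyObject`/`requires_backends` pair, with a hypothetical class name):

```python
class DummyObject(type):
    """Metaclass: any public attribute access reports the missing backend instead of working."""

    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        raise ImportError(f"{cls.__name__} requires TensorFlow, which is not installed.")


class TFGPTNeoModelPlaceholder(metaclass=DummyObject):  # hypothetical name for illustration
    _backends = ["tf"]


try:
    TFGPTNeoModelPlaceholder.from_pretrained("EleutherAI/gpt-neo-125m")
except ImportError as err:
    print(err)  # TFGPTNeoModelPlaceholder requires TensorFlow, which is not installed.
```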
tests/models/gpt_neo/test_modeling_tf_gpt_neo.py (new file, 602 lines)
@@ -0,0 +1,602 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow GPT Neo model. """


import unittest

from transformers import GPTNeoConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers import (
        TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPT2Tokenizer,
        TFGPTNeoForCausalLM,
        TFGPTNeoForQuestionAnswering,
        TFGPTNeoForSequenceClassification,
        TFGPTNeoForTokenClassification,
        TFGPTNeoModel,
    )

    import numpy as np


class GPTNeoModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=True,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=4,
        attention_types=[[["global", "local"], 2]],
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        window_size=7,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.window_size = window_size
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1
        self.attention_types = attention_types

    def get_large_model_config(self):
        return GPTNeoConfig.from_pretrained("gpt-neo-125M")

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        return GPTNeoConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            max_position_embeddings=self.max_position_embeddings,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            window_size=self.window_size,
            attention_types=self.attention_types,
        )

    def get_pipeline_config(self):
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_gpt_neo_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = TFGPTNeoModel(config=config)

        result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        # past_key_values is not implemented
        # self.parent.assertEqual(len(result.past_key_values), config.n_layer)

    def create_and_check_gpt_neo_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = TFGPTNeoModel(config=config)

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)

        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_neo_model_attention_mask_past(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = TFGPTNeoModel(config=config)

        # create attention mask
        attn_mask = np.ones(input_ids.shape, dtype=np.int32)
        half_seq_length = self.seq_length // 2
        attn_mask[:, half_seq_length:] = 0
        attn_mask = tf.convert_to_tensor(attn_mask)

        # first forward pass
        output, past = model(input_ids, attention_mask=attn_mask).to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy().item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).numpy().squeeze(-1)
        input_ids = input_ids.numpy()
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
        input_ids = tf.convert_to_tensor(input_ids)

        # append to next input_ids and attn_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        attn_mask = tf.concat(
            [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
            axis=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_neo_model_past_large_inputs(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = TFGPTNeoModel(config=config)

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True)

        output, past = outputs.to_tuple()

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and token_type_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)
        next_attention_mask = tf.concat([input_mask, next_mask], axis=-1)

        output_from_no_past = model(
            next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past
        )["last_hidden_state"]
        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = TFGPTNeoForCausalLM(config)

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_gpt_neo_for_question_answering(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = TFGPTNeoForQuestionAnswering(config)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.start_logits.shape, [self.batch_size, self.seq_length])
        self.parent.assertEqual(result.end_logits.shape, [self.batch_size, self.seq_length])

    def create_and_check_gpt_neo_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = TFGPTNeoForSequenceClassification(config)
        model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_gpt_neo_for_token_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = TFGPTNeoForTokenClassification(config)
        model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.logits.shape, [self.batch_size, self.seq_length, self.num_labels])

    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = TFGPTNeoForCausalLM(config)

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
        }

        return config, inputs_dict


@require_tf
class TFGPTNeoModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            TFGPTNeoModel,
            TFGPTNeoForCausalLM,
            TFGPTNeoForQuestionAnswering,
            TFGPTNeoForSequenceClassification,
            TFGPTNeoForTokenClassification,
        )
        if is_tf_available()
        else ()
    )
    all_generative_model_classes = (TFGPTNeoForCausalLM,) if is_tf_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": TFGPTNeoModel,
            "question-answering": TFGPTNeoForQuestionAnswering,
            "text-classification": TFGPTNeoForSequenceClassification,
            "token-classification": TFGPTNeoForTokenClassification,
            "text-generation": TFGPTNeoForCausalLM,
            "zero-shot": TFGPTNeoForSequenceClassification,
        }
        if is_tf_available()
        else {}
    )
    test_onnx = False
    test_missing_keys = False
    test_pruning = False
    test_model_parallel = False
    test_head_masking = False
    test_resize_token_embeddings = False

    # special case for DoubleHeads model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
        return inputs_dict

    def setUp(self):
        self.model_tester = GPTNeoModelTester(self)
        self.config_tester = ConfigTester(self, config_class=GPTNeoConfig, n_embd=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_gpt_neo_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_model(*config_and_inputs)

    def test_gpt_neo_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_model_past(*config_and_inputs)

    def test_gpt_neo_model_att_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_model_attention_mask_past(*config_and_inputs)

    def test_gpt_neo_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_model_past_large_inputs(*config_and_inputs)

    def test_gpt_neo_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_gpt_neo_question_answering_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_question_answering(*config_and_inputs)

    def test_gpt_neo_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)

    def test_gpt_neo_token_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_token_classification(*config_and_inputs)

    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)

            if model_class in self.all_generative_model_classes:
                x = model.get_output_embeddings()
                assert isinstance(x, tf.keras.layers.Layer)
                name = model.get_bias()
                assert name is None
            else:
                x = model.get_output_embeddings()
                assert x is None
                name = model.get_bias()
                assert name is None

    def _get_hidden_states(self):
        return tf.constant(
            [
                [
                    [0.4983, -0.7584, -1.6944, 0.5440],
                    [2.6918, 0.4206, 0.4176, 0.2055],
                    [-0.0071, -0.0405, -1.4920, -0.3630],
                    [1.0492, 0.1599, -1.7648, 0.2419],
                    [-1.8348, 2.0514, -0.1946, 0.3203],
                    [0.7672, -1.1600, -1.7118, -0.9056],
                    [0.2986, 0.5372, 0.7729, -0.1927],
                    [0.0285, 0.2629, -1.1156, -1.1992],
                ]
            ],
            dtype=tf.float32,
        )

    def test_local_attn_probs(self):
        model = TFGPTNeoModel.from_pretrained("valhalla/gpt-neo-random-tiny")
        layer = model.h[1].attn.attention
        hidden_states = self._get_hidden_states()
        hidden_states = tf.concat([hidden_states, hidden_states - 0.5], axis=2)

        batch_size, seq_length, _ = hidden_states.shape
        mask_tokens = 2
        # don't attend to the last `mask_tokens` positions; TF tensors don't support item
        # assignment, so the mask is built directly instead of zeroing a slice in place
        attention_mask = tf.concat(
            [
                tf.ones((batch_size, seq_length - mask_tokens), dtype=tf.int32),
                tf.zeros((batch_size, mask_tokens), dtype=tf.int32),
            ],
            axis=1,
        )

        attention_mask = tf.reshape(attention_mask, (batch_size, -1))
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)[-1]

        # the last 2 tokens are masked, and should have 0 attn_probs
        self.assertTrue(tf.reduce_all(attn_probs[:, :, -mask_tokens:, -mask_tokens:] == 0))

        # in local attention each token can only attend to the previous window_size tokens (including itself)
        # here window_size is 4, so a token at index 5 can only attend to indices [2, 3, 4, 5]
        # and the attn_probs should be 0 for tokens [0, 1]
        self.assertTrue(tf.reduce_all(attn_probs[:, :, 5, 2:6] != 0))
        self.assertTrue(tf.reduce_all(attn_probs[:, :, 5, :2] == 0))


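Side note (illustrative, not part of the test file): the visibility pattern the comments in test_local_attn_probs describe can be reproduced directly with NumPy; for seq_length=8 and window_size=4, position 5 sees only positions 2 through 5:

# Illustrative sketch of the local-attention window described above
import numpy as np

seq_length, window_size = 8, 4
i = np.arange(seq_length)[:, None]  # query positions
j = np.arange(seq_length)[None, :]  # key positions
# causal, and no further back than window_size tokens (the token itself included)
visible = (j <= i) & (j > i - window_size)
print(visible[5])  # positions 2..5 are True, everything else False
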
@require_tf
class TFGPTNeoModelLanguageGenerationTest(unittest.TestCase):
    @cached_property
    def model(self):
        return TFGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

    @cached_property
    def tokenizer(self):
        return GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

    @slow
    def test_lm_generate_gpt_neo(self):
        model = self.model
        input_ids = tf.constant([[464, 3290]], dtype=tf.int32)  # The dog
        # fmt: off
        # The dog-eared copy of the book, which is a collection of essays by the late author,
        expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11]
        # fmt: on
        output_ids = model.generate(input_ids, do_sample=False)
        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

    @slow
    def test_gpt_neo_sample(self):
        model = self.model
        tokenizer = self.tokenizer

        tf.random.set_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
        input_ids = tokenized.input_ids
        output_ids = model.generate(input_ids, do_sample=True)
        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can"
        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

    @slow
    def test_batch_generation(self):
        model = self.model
        tokenizer = self.tokenizer

        tokenizer.padding_side = "left"

        # Define PAD Token = EOS Token = 50256
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I am",
        ]

        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
        input_ids = inputs["input_ids"]

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        num_paddings = inputs_non_padded.shape[-1] - tf.reduce_sum(inputs["attention_mask"][-1]).numpy().item()
        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a kitty. She is a very sweet and loving",
            "Today, I am going to talk about the best way to get a job in the",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
    def test_model_from_pretrained(self):
        for model_name in TF_GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = TFGPTNeoModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
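
Usage note (illustrative only, not part of the diff): with this branch installed, the new TF classes would be used like their PyTorch counterparts. The checkpoint name is taken from the slow tests above; from_pt=True is an assumption about loading the published PyTorch weights into the TF model.

# Minimal generation sketch mirroring TFGPTNeoModelLanguageGenerationTest (assumptions noted above)
from transformers import GPT2Tokenizer, TFGPTNeoForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
# from_pt=True converts the PyTorch checkpoint on the fly (assumption: no native TF weights published)
model = TFGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", from_pt=True)

inputs = tokenizer("The dog", return_tensors="tf")
output_ids = model.generate(inputs.input_ids, do_sample=False, max_length=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))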