Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: v4.56.2...keras3_com (2 commits)

Commits:
- a8cf3e0371
- 39407389e4
@@ -257,7 +257,6 @@ def load_pytorch_state_dict_in_tf2_model(
     """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
     safetensors archive created with the safe_open() function."""
     import tensorflow as tf
-    from keras import backend as K
 
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs
@@ -310,7 +309,8 @@ def load_pytorch_state_dict_in_tf2_model(
     mismatched_keys = []
     is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
     for symbolic_weight in symbolic_weights:
-        sw_name = symbolic_weight.name
+        # Keras 2 stores the full weight path on the "name" attribute, but Keras 3 uses a new "path" attr
+        sw_name = getattr(symbolic_weight, "path", symbolic_weight.name)
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
             sw_name,
             start_prefix_to_remove=start_prefix_to_remove,
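The `getattr(symbolic_weight, "path", symbolic_weight.name)` fallback above is the core Keras 2/3 bridge: Keras 2 variables carry the full scoped path on `.name` (with a `:0` suffix), while Keras 3 variables expose it on `.path`. A minimal sketch with made-up stand-in weight objects (not real Keras classes):

```python
# Stand-in objects illustrating the two naming conventions; attribute values are made up.
class Keras2StyleWeight:
    name = "tf_bert_model/bert/embeddings/word_embeddings/weight:0"


class Keras3StyleWeight:
    name = "weight"  # Keras 3 keeps only the short name here
    path = "tf_bert_model/bert/embeddings/word_embeddings/weight"


for symbolic_weight in (Keras2StyleWeight(), Keras3StyleWeight()):
    # Prefer the Keras 3 "path" attribute, fall back to the Keras 2 "name" attribute
    sw_name = getattr(symbolic_weight, "path", symbolic_weight.name)
    print(sw_name)
```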
@@ -357,7 +357,7 @@ def load_pytorch_state_dict_in_tf2_model(
 
         tf_loaded_numel += tensor_size(array)
 
-        K.set_value(symbolic_weight, array)
+        symbolic_weight.assign(array)
         del array  # Immediately free memory to keep peak usage as low as possible
         all_pytorch_weights.discard(name)
 
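Replacing `K.set_value` with `Variable.assign` removes the dependency on the Keras 2 backend module; `assign` accepts a NumPy array directly. A small sketch (assumes TensorFlow is installed; the shapes are arbitrary):

```python
import numpy as np
import tensorflow as tf

# assign() copies the NumPy array into the variable in place, which is what
# K.set_value(weight, array) used to do via the Keras 2 backend.
weight = tf.Variable(tf.zeros((2, 3)))
array = np.ones((2, 3), dtype=np.float32)
weight.assign(array)
print(weight.numpy())
```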
@@ -28,6 +28,7 @@ import warnings
 from collections.abc import Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from contextlib import nullcontext
 
 import h5py
 import numpy as np
@@ -625,13 +626,22 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
-def strip_model_name_and_prefix(name, _prefix=None):
+def strip_model_name_and_prefix(name, _prefix=None, strip_suffix=False):
     if _prefix is not None and name.startswith(_prefix):
         name = name[len(_prefix) :]
     if name.startswith("/"):
         name = name[1:]
     if "model." not in name and len(name.split("/")) > 1:
         name = "/".join(name.split("/")[1:])
+    if strip_suffix:
+        name = strip_weight_suffix(name)
     return name
 
+def strip_weight_suffix(name):
+    # TensorFlow weight names all end in ":0", perhaps because they're surprised.
+    # Keras 3 stops doing this, so this function strips the suffix if present to normalize weight names.
+    if name.endswith(":0"):
+        name = name[:-2]
+    return name
+
 
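The new `strip_weight_suffix` helper makes Keras 2 style names (which end in `:0`) comparable with Keras 3 style names (which do not). A quick usage sketch, duplicating the helper so it runs standalone (the weight names are illustrative):

```python
def strip_weight_suffix(name):
    # Same logic as the helper added above: drop a trailing ":0" if present.
    if name.endswith(":0"):
        name = name[:-2]
    return name


print(strip_weight_suffix("bert/embeddings/word_embeddings/weight:0"))  # Keras 2 style name
print(strip_weight_suffix("bert/embeddings/word_embeddings/weight"))    # Keras 3 style name, unchanged
```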
@@ -889,6 +899,7 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
     # Read the H5 file
     with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
         # Retrieve the name of each layer from the H5 file
+        breakpoint()
         saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
 
         # Find the missing layers from the high level list of layers
@@ -898,7 +909,6 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
         unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
         saved_weight_names_set = set()
         symbolic_weights_names = set()
-        weight_value_tuples = []
 
         # Compute missing and unexpected sub layers
         # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
@@ -927,21 +937,28 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
 
                 # Loop over each weights from the instantiated model and compare with the weights from the H5 file
                 for symbolic_weight in symbolic_weights:
+                    if hasattr(symbolic_weight, "path"):
+                        # Keras 3 compatibility
+                        symbolic_weight_name = symbolic_weight.path
+                    else:
+                        symbolic_weight_name = symbolic_weight.name
                     # TF names always start with the model name so we ignore it
                     if _prefix is not None:
-                        delimeter = len(_prefix.split("/"))
+                        delimiter = len(_prefix.split("/"))
                         symbolic_weight_name = "/".join(
-                            symbolic_weight.name.split("/")[:delimeter]
-                            + symbolic_weight.name.split("/")[delimeter + 1 :]
+                            symbolic_weight_name.split("/")[:delimiter]
+                            + symbolic_weight_name.split("/")[delimiter + 1 :]
                         )
                     else:
-                        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+                        symbolic_weight_name = "/".join(symbolic_weight_name.split("/")[1:])
 
                     # here we check if the current weight is among the weights from the H5 file
                     # If yes, get the weight_value of the corresponding weight from the H5 file
                     # If not, make the value to None
                     saved_weight_value = saved_weights.get(symbolic_weight_name, None)
 
+                    # TODO Matt: Continue from here and figure out what's going on
+
                     # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
                     # `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
                     if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
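The renaming logic above now builds `symbolic_weight_name` once (from `.path` or `.name`) and then drops the top-level model-name component, keeping any `_prefix` in front of it. A standalone sketch of that component-dropping step, with hypothetical weight paths:

```python
def drop_model_name(symbolic_weight_name, _prefix=None):
    # Illustrative only: remove the model-name component that immediately
    # follows the optional _prefix, mirroring the split/join logic above.
    if _prefix is not None:
        delimiter = len(_prefix.split("/"))
        parts = symbolic_weight_name.split("/")
        return "/".join(parts[:delimiter] + parts[delimiter + 1 :])
    return "/".join(symbolic_weight_name.split("/")[1:])


print(drop_model_name("tf_bert_model/bert/embeddings/weight:0"))
# -> bert/embeddings/weight:0
print(drop_model_name("decoder/tf_bert_model/bert/embeddings/weight:0", _prefix="decoder"))
# -> decoder/bert/embeddings/weight:0
```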
@@ -954,15 +971,15 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
                     # If the current weight is found
                     if saved_weight_value is not None:
                         # Check if the shape of the current weight and the one from the H5 file are different
-                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                        if symbolic_weight.shape != saved_weight_value.shape:
                             # If yes we reshape the weight from the H5 file accordingly to the current weight
                             # If the two shapes are not compatible we raise an issue
                             try:
-                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                                array = np.reshape(saved_weight_value, symbolic_weight.shape)
                             except ValueError as e:
                                 if ignore_mismatched_sizes:
                                     mismatched_layers.append(
-                                        (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
+                                        (symbolic_weight_name, saved_weight_value.shape, symbolic_weight.shape)
                                     )
                                     continue
                                 else:
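`K.int_shape(symbolic_weight)` was a Keras 2 backend call; a variable's own `.shape` compares directly against a NumPy array's shape and can be passed to `np.reshape`, which is what the hunk above switches to. A small sketch (assumes TensorFlow is installed; shapes are arbitrary):

```python
import numpy as np
import tensorflow as tf

symbolic_weight = tf.Variable(tf.zeros((4, 8)))
saved_weight_value = np.zeros((32,), dtype=np.float32)

# The variable's .shape compares against the saved array's shape and, when the
# element counts match, can be used as the target shape for np.reshape.
if symbolic_weight.shape != saved_weight_value.shape:
    saved_weight_value = np.reshape(saved_weight_value, symbolic_weight.shape)
print(saved_weight_value.shape)  # (4, 8)
```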
@@ -970,11 +987,9 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
                         else:
                             array = saved_weight_value
 
-                        # We create the tuple that will be loaded and add it to the final list
-                        weight_value_tuples.append((symbolic_weight, array))
+                        # Load the weight
+                        symbolic_weight.assign(array)
 
-    # Load all the weights
-    K.batch_set_value(weight_value_tuples)
 
     # Compute the missing and unexpected layers
     missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
@@ -984,34 +999,48 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
 
 
 def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    keras_3_mode = any([hasattr(w, "path") for w in model.weights])  # First figure out what framework we've got
     # Read the safetensors file
     with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
         mismatched_layers = []
-        weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
-        loaded_weight_names = list(safetensors_archive.keys())
+        if keras_3_mode:
+            weight_names = [strip_model_name_and_prefix(w.path, _prefix=_prefix) for w in model.weights]
+        else:
+            weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix, strip_suffix=True) for w in model.weights]
+        archive_has_suffix = any([w.endswith(":0") for w in safetensors_archive.keys()])
+        if archive_has_suffix:
+            loaded_weight_names = [strip_weight_suffix(w) for w in safetensors_archive.keys()]
+        else:
+            loaded_weight_names = list(safetensors_archive.keys())
         # Find the missing layers from the high level list of layers
         missing_layers = list(set(weight_names) - set(loaded_weight_names))
         # Find the unexpected layers from the high level list of layers
         unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
 
         for weight in model.weights:
-            weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
+            if keras_3_mode:
+                weight_name = strip_model_name_and_prefix(weight.path, _prefix=_prefix)
+            else:
+                weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix, strip_suffix=True)
             if weight_name in loaded_weight_names:
-                weight_value = safetensors_archive.get_tensor(weight_name)
-                # Check if the shape of the current weight and the one from the H5 file are different
-                if K.int_shape(weight) != weight_value.shape:
+                if archive_has_suffix:
+                    weight_value = safetensors_archive.get_tensor(weight_name + ":0")
+                else:
+                    weight_value = safetensors_archive.get_tensor(weight_name)
+                # Check if the shape of the current weight and the one from the safetensors file are different
+                if weight.shape != weight_value.shape:
                     # If yes we reshape the weight from the H5 file accordingly to the current weight
                     # If the two shapes are not compatible we raise an issue
                     try:
-                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
+                        weight_value = tf.reshape(weight_value, tuple(weight.shape))
                     except (ValueError, tf.errors.InvalidArgumentError) as e:
                         if ignore_mismatched_sizes:
-                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+                            mismatched_layers.append((weight_name, tuple(weight_value.shape), tuple(weight.shape)))
                             continue
                         else:
                             raise e
 
-                K.set_value(weight, weight_value)  # weight.assign() might break if weight is a DTensor
+                weight.assign(weight_value)  # TODO Find a DTensor-compatible way to do this
     return missing_layers, unexpected_layers, mismatched_layers
 
 
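The new `keras_3_mode` / `archive_has_suffix` flags normalize both sides of the comparison before the missing/unexpected sets are computed: model weight names lose their `:0` suffix (or come from `.path`), and archive keys are stripped only when the checkpoint was written with suffixed names. A sketch of that normalization using plain strings in place of real model weights and a safetensors archive (the names are made up):

```python
def strip_weight_suffix(name):
    return name[:-2] if name.endswith(":0") else name


model_weight_names = ["bert/pooler/dense/kernel:0"]   # e.g. Keras 2 ".name" values
archive_keys = ["bert/pooler/dense/kernel"]           # e.g. a checkpoint written by Keras 3

weight_names = [strip_weight_suffix(n) for n in model_weight_names]
archive_has_suffix = any(k.endswith(":0") for k in archive_keys)
loaded_weight_names = [strip_weight_suffix(k) for k in archive_keys] if archive_has_suffix else list(archive_keys)

missing_layers = set(weight_names) - set(loaded_weight_names)
unexpected_layers = set(loaded_weight_names) - set(weight_names)
print(missing_layers, unexpected_layers)  # set() set() -- the two naming styles now line up
```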
@@ -3469,3 +3498,10 @@ def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.Tr
         `tf.keras.initializers.TruncatedNormal`: The truncated normal initializer.
     """
     return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
+
+
+if hasattr(tf.keras, "name_scope"):
+    # Keras 3 nests name scopes for us, so we don't need to manually enter them
+    name_scope = nullcontext
+else:
+    name_scope = tf.name_scope
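Binding `name_scope` to `nullcontext` keeps every `with name_scope(...)` call site valid while turning it into a no-op, since `nullcontext` accepts (and ignores) a single argument. A minimal sketch:

```python
from contextlib import nullcontext

# nullcontext("anything") is a context manager that does nothing, so code written
# as `with name_scope(layer_name): ...` keeps working when no TF name scope is needed.
name_scope = nullcontext

with name_scope("word_embeddings"):
    result = 1 + 1
print(result)
```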
@@ -51,6 +51,7 @@ from ...modeling_tf_utils import (
     get_initializer,
     keras_serializable,
     unpack_inputs,
+    name_scope,
 )
 from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import (
@@ -157,32 +158,59 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
 
     def build(self, input_shape=None):
-        with tf.name_scope("word_embeddings"):
+        if self.built:
+            return
+        self.built = True
+        keras_3_mode = hasattr(tf.keras, "name_scope")
+        if keras_3_mode:
+            # Matt: self.add_weight() wraps weights in a self.name() scope in Keras 3, so we have to get a little
+            # creative for cases like this where we need specific names. In future we could consider moving
+            # these weights to their own layers, or adding a weight rename function so that we could stop doing
+            # this and bring this layer more in sync with the PyTorch version.
+            original_name = self.name
+            self.name = "word_embeddings"
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.config.vocab_size, self.hidden_size],
                 initializer=get_initializer(self.initializer_range),
             )
-
-        with tf.name_scope("token_type_embeddings"):
+            self.name = "token_type_embeddings"
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.config.type_vocab_size, self.hidden_size],
                 initializer=get_initializer(self.initializer_range),
            )
-
-        with tf.name_scope("position_embeddings"):
+            self.name = "position_embeddings"
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.hidden_size],
                 initializer=get_initializer(self.initializer_range),
             )
+            self.name = original_name
+        else:
+            with name_scope("word_embeddings"):
+                self.weight = self.add_weight(
+                    name="weight",
+                    shape=[self.config.vocab_size, self.hidden_size],
+                    initializer=get_initializer(self.initializer_range),
+                )
+
+            with name_scope("token_type_embeddings"):
+                self.token_type_embeddings = self.add_weight(
+                    name="embeddings",
+                    shape=[self.config.type_vocab_size, self.hidden_size],
+                    initializer=get_initializer(self.initializer_range),
+                )
+
+            with name_scope("position_embeddings"):
+                self.position_embeddings = self.add_weight(
+                    name="embeddings",
+                    shape=[self.max_position_embeddings, self.hidden_size],
+                    initializer=get_initializer(self.initializer_range),
+                )
 
-        if self.built:
-            return
-        self.built = True
         if getattr(self, "LayerNorm", None) is not None:
-            with tf.name_scope(self.LayerNorm.name):
+            with name_scope(self.LayerNorm.name):
                 self.LayerNorm.build([None, None, self.config.hidden_size])
 
     def call(
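The Keras 3 branch above temporarily rewrites `self.name` around each `add_weight` call because, as the in-code comment notes, Keras 3 derives the variable path from the owning layer's name. A framework-free sketch of that save/rename/restore pattern (the `add_weight` here is a stand-in that only records the scope it would use):

```python
class FakeEmbeddingsLayer:
    # Stand-in for a Keras layer: add_weight() records the path it would create,
    # derived from self.name, mimicking the Keras 3 behaviour described above.
    def __init__(self):
        self.name = "tf_bert_embeddings"
        self.created = []

    def add_weight(self, name):
        self.created.append(f"{self.name}/{name}")


layer = FakeEmbeddingsLayer()
original_name = layer.name
layer.name = "word_embeddings"      # force the scope used for this weight
layer.add_weight(name="weight")
layer.name = original_name          # restore the real layer name afterwards
print(layer.created)                # ['word_embeddings/weight']
```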
@@ -348,13 +376,13 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "query", None) is not None:
-            with tf.name_scope(self.query.name):
+            with name_scope(self.query.name):
                 self.query.build([None, None, self.config.hidden_size])
         if getattr(self, "key", None) is not None:
-            with tf.name_scope(self.key.name):
+            with name_scope(self.key.name):
                 self.key.build([None, None, self.config.hidden_size])
         if getattr(self, "value", None) is not None:
-            with tf.name_scope(self.value.name):
+            with name_scope(self.value.name):
                 self.value.build([None, None, self.config.hidden_size])
 
 
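All of the remaining hunks apply the same mechanical change inside `build()`: each sub-layer is still built under `name_scope(sublayer.name)`, which resolves to `tf.name_scope` on Keras 2 and to the no-op `nullcontext` on Keras 3. A condensed, framework-free sketch of that repeated pattern (class and attribute names are illustrative):

```python
from contextlib import nullcontext

name_scope = nullcontext  # stands in for tf.name_scope under Keras 2


class FakeSubLayer:
    name = "dense"

    def build(self, input_shape):
        print(f"built {self.name} with {input_shape}")


class FakeLayer:
    def __init__(self):
        self.built = False
        self.dense = FakeSubLayer()

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with name_scope(self.dense.name):
                self.dense.build([None, None, 768])


FakeLayer().build()
```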
@@ -381,10 +409,10 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "dense", None) is not None:
-            with tf.name_scope(self.dense.name):
+            with name_scope(self.dense.name):
                 self.dense.build([None, None, self.config.hidden_size])
         if getattr(self, "LayerNorm", None) is not None:
-            with tf.name_scope(self.LayerNorm.name):
+            with name_scope(self.LayerNorm.name):
                 self.LayerNorm.build([None, None, self.config.hidden_size])
 
 
@@ -432,10 +460,10 @@ class TFBertAttention(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "self_attention", None) is not None:
-            with tf.name_scope(self.self_attention.name):
+            with name_scope(self.self_attention.name):
                 self.self_attention.build(None)
         if getattr(self, "dense_output", None) is not None:
-            with tf.name_scope(self.dense_output.name):
+            with name_scope(self.dense_output.name):
                 self.dense_output.build(None)
 
 
@@ -464,7 +492,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "dense", None) is not None:
-            with tf.name_scope(self.dense.name):
+            with name_scope(self.dense.name):
                 self.dense.build([None, None, self.config.hidden_size])
 
 
@@ -491,10 +519,10 @@ class TFBertOutput(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "dense", None) is not None:
-            with tf.name_scope(self.dense.name):
+            with name_scope(self.dense.name):
                 self.dense.build([None, None, self.config.intermediate_size])
         if getattr(self, "LayerNorm", None) is not None:
-            with tf.name_scope(self.LayerNorm.name):
+            with name_scope(self.LayerNorm.name):
                 self.LayerNorm.build([None, None, self.config.hidden_size])
 
 
@@ -588,16 +616,16 @@ class TFBertLayer(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "attention", None) is not None:
-            with tf.name_scope(self.attention.name):
+            with name_scope(self.attention.name):
                 self.attention.build(None)
         if getattr(self, "intermediate", None) is not None:
-            with tf.name_scope(self.intermediate.name):
+            with name_scope(self.intermediate.name):
                 self.intermediate.build(None)
         if getattr(self, "bert_output", None) is not None:
-            with tf.name_scope(self.bert_output.name):
+            with name_scope(self.bert_output.name):
                 self.bert_output.build(None)
         if getattr(self, "crossattention", None) is not None:
-            with tf.name_scope(self.crossattention.name):
+            with name_scope(self.crossattention.name):
                 self.crossattention.build(None)
 
 
@@ -675,7 +703,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
         self.built = True
         if getattr(self, "layer", None) is not None:
             for layer in self.layer:
-                with tf.name_scope(layer.name):
+                with name_scope(layer.name):
                     layer.build(None)
 
 
@@ -704,7 +732,7 @@ class TFBertPooler(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "dense", None) is not None:
-            with tf.name_scope(self.dense.name):
+            with name_scope(self.dense.name):
                 self.dense.build([None, None, self.config.hidden_size])
 
 
@@ -738,10 +766,10 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "dense", None) is not None:
-            with tf.name_scope(self.dense.name):
+            with name_scope(self.dense.name):
                 self.dense.build([None, None, self.config.hidden_size])
         if getattr(self, "LayerNorm", None) is not None:
-            with tf.name_scope(self.LayerNorm.name):
+            with name_scope(self.LayerNorm.name):
                 self.LayerNorm.build([None, None, self.config.hidden_size])
 
 
@@ -765,7 +793,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "transform", None) is not None:
-            with tf.name_scope(self.transform.name):
+            with name_scope(self.transform.name):
                 self.transform.build(None)
 
     def get_output_embeddings(self) -> tf.keras.layers.Layer:
@@ -809,7 +837,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "predictions", None) is not None:
-            with tf.name_scope(self.predictions.name):
+            with name_scope(self.predictions.name):
                 self.predictions.build(None)
 
 
@@ -834,7 +862,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "seq_relationship", None) is not None:
-            with tf.name_scope(self.seq_relationship.name):
+            with name_scope(self.seq_relationship.name):
                 self.seq_relationship.build([None, None, self.config.hidden_size])
 
 
@@ -1029,13 +1057,13 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             return
         self.built = True
         if getattr(self, "embeddings", None) is not None:
-            with tf.name_scope(self.embeddings.name):
+            with name_scope(self.embeddings.name):
                 self.embeddings.build(None)
         if getattr(self, "encoder", None) is not None:
-            with tf.name_scope(self.encoder.name):
+            with name_scope(self.encoder.name):
                 self.encoder.build(None)
         if getattr(self, "pooler", None) is not None:
-            with tf.name_scope(self.pooler.name):
+            with name_scope(self.pooler.name):
                 self.pooler.build(None)
 
 
@@ -1255,7 +1283,7 @@ class TFBertModel(TFBertPreTrainedModel):
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
 
 
@@ -1375,13 +1403,13 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "nsp", None) is not None:
-            with tf.name_scope(self.nsp.name):
+            with name_scope(self.nsp.name):
                 self.nsp.build(None)
         if getattr(self, "mlm", None) is not None:
-            with tf.name_scope(self.mlm.name):
+            with name_scope(self.mlm.name):
                 self.mlm.build(None)
 
 
@@ -1475,10 +1503,10 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "mlm", None) is not None:
-            with tf.name_scope(self.mlm.name):
+            with name_scope(self.mlm.name):
                 self.mlm.build(None)
 
 
@@ -1611,10 +1639,10 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "mlm", None) is not None:
-            with tf.name_scope(self.mlm.name):
+            with name_scope(self.mlm.name):
                 self.mlm.build(None)
 
 
@@ -1704,10 +1732,10 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "nsp", None) is not None:
-            with tf.name_scope(self.nsp.name):
+            with name_scope(self.nsp.name):
                 self.nsp.build(None)
 
 
@@ -1802,10 +1830,10 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "classifier", None) is not None:
-            with tf.name_scope(self.classifier.name):
+            with name_scope(self.classifier.name):
                 self.classifier.build([None, None, self.config.hidden_size])
 
 
@@ -1913,10 +1941,10 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
            return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "classifier", None) is not None:
-            with tf.name_scope(self.classifier.name):
+            with name_scope(self.classifier.name):
                 self.classifier.build([None, None, self.config.hidden_size])
 
 
@@ -2015,10 +2043,10 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "classifier", None) is not None:
-            with tf.name_scope(self.classifier.name):
+            with name_scope(self.classifier.name):
                 self.classifier.build([None, None, self.config.hidden_size])
 
 
@@ -2129,8 +2157,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
             return
         self.built = True
         if getattr(self, "bert", None) is not None:
-            with tf.name_scope(self.bert.name):
+            with name_scope(self.bert.name):
                 self.bert.build(None)
         if getattr(self, "qa_outputs", None) is not None:
-            with tf.name_scope(self.qa_outputs.name):
+            with name_scope(self.qa_outputs.name):
                 self.qa_outputs.build([None, None, self.config.hidden_size])