# Mirror of https://github.com/huggingface/transformers.git
# Synced 2025-11-18 00:40:50 +08:00
# 781 lines, 34 KiB, Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from examples/modular-transformers/modular_dummy_bert.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_dummy_bert.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
from collections.abc import Callable
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
|
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available
|
|
from ...utils.generic import check_model_inputs
|
|
from .configuration_dummy_bert import DummyBertConfig
|
|
|
|
|
|
if is_torch_flex_attn_available():
|
|
from ...integrations.flex_attention import make_flex_block_causal_mask
|
|
|
|
|
|
class DummyBertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        # The three tables are created in this order (word -> position -> token type)
        # so that parameter initialisation consumes the RNG in a fixed sequence.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Non-persistent buffers (not written to the state dict): a (1, max_position_embeddings)
        # row of position indices, and an all-zero token-type row of the same shape.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed a batch of tokens.

        Args:
            input_ids: token ids of shape (batch_size, seq_length); takes precedence for the shape.
            token_type_ids: segment ids; when omitted, zeros are read from the registered buffer.
            position_ids: position indices; when omitted, positions start at `past_key_values_length`.
            inputs_embeds: precomputed word embeddings used instead of looking up `input_ids`.
            past_key_values_length: number of already-processed tokens, offsets the default positions.

        Returns:
            LayerNorm + dropout of (word + token-type + position) embeddings.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length]

        if token_type_ids is None:
            if not hasattr(self, "token_type_ids"):
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            else:
                # Reading from the zero buffer lets the model be traced without explicit
                # token_type_ids (issue #5664). position_ids is assumed to have batch size
                # 1 (broadcastable) or batch_size.
                zeros = self.token_type_ids.expand(position_ids.shape[0], -1)
                zeros = torch.gather(zeros, dim=1, index=position_ids)
                token_type_ids = zeros.expand(batch_size, seq_length)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        summed = summed + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(summed))
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (pure PyTorch) scaled dot-product attention.

    Args:
        module: the calling attention module; only its ``training`` flag is read, to gate dropout.
        query, key, value: projected states. Keys are transposed on dims 2 and 3, so these are
            presumably laid out as (batch, heads, seq, head_dim) -- TODO confirm against callers.
        attention_mask: additive mask. It is applied only when it is 4-D, after being cropped
            on its last dim to the key length.
        scaling: multiplier for the raw scores; defaults to ``query.size(-1) ** -0.5``.
        dropout: dropout probability applied to the attention probabilities.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` has dims 1 and 2 swapped back and is
        made contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    # Raw attention scores: query . key^T, scaled.
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None and attention_mask.ndim == 4:
        kv_len = key.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
|
class DummyBertSelfAttention(nn.Module):
    """Multi-head self-attention with optional key/value caching and a pluggable attention kernel."""

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        heads = config.num_attention_heads
        evenly_split = config.hidden_size % heads == 0
        # Configs that define `embedding_size` are allowed a non-divisible hidden size.
        if not evenly_split and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        # Projections are created query -> key -> value to keep initialisation order fixed.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: input states; the last dim is split into heads of ``attention_head_size``.
            attention_mask: mask forwarded unchanged to the attention kernel.
            past_key_value: optional cache. For an ``EncoderDecoderCache`` only its
                ``self_attention_cache`` is updated; any other cache is updated directly.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)``, with ``attn_output`` reshaped back to the input's
            leading dims plus one merged head dim.
        """
        input_shape = hidden_states.shape[:-1]
        per_head = (*input_shape, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(*per_head).transpose(1, 2)
        key_layer = self.key(hidden_states).view(*per_head).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*per_head).transpose(1, 2)

        if past_key_value is not None:
            # A decoder-only model may hand over a plain cache; an encoder-decoder cache
            # keeps the self-attention states in a dedicated sub-cache.
            cache = (
                past_key_value.self_attention_cache
                if isinstance(past_key_value, EncoderDecoderCache)
                else past_key_value
            )
            # Append the new keys/values and get back the full sequence for attention.
            key_layer, value_layer = cache.update(
                key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
            )

        impl = self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward if impl == "eager" else ALL_ATTENTION_FUNCTIONS[impl]

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=self.dropout.p if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights
|
|
|
|
|
|
class DummyBertCrossAttention(nn.Module):
    """Multi-head cross-attention.

    Queries are projected from the decoder-side ``hidden_states``. Keys/values are projected from
    ``encoder_hidden_states``, or read back from the cross-attention part of an
    ``EncoderDecoderCache`` once this layer has stored them there.
    """

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Attend from ``hidden_states`` to the encoder states.

        Args:
            hidden_states: decoder-side states; the leading two dims are unpacked as
                (batch, target length).
            encoder_hidden_states: encoder output; dim 1 is the source length. It is only needed
                while keys/values still have to be projected, i.e. when no cache is used or this
                layer's cross-attention cache has not been filled yet.
            attention_mask: mask forwarded unchanged to the attention kernel.
            past_key_value: optional cache. Its ``cross_attention_cache`` stores the projected
                keys/values, and ``is_updated[layer_idx]`` marks them as reusable.
            **kwargs: forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)``, with ``attn_output`` reshaped to (batch, target length, -1).

        Raises:
            ValueError: if keys/values must be projected but ``encoder_hidden_states`` is None.
        """
        bsz, tgt_len = hidden_states.shape[:-1]

        q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
        query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

        is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
        if past_key_value is not None and is_updated:
            # Reuse the keys/values stored on an earlier call; the encoder states are not
            # touched on this path, so they may legitimately be None here.
            cached = past_key_value.cross_attention_cache.layers[self.layer_idx]
            key_layer = cached.keys
            value_layer = cached.values
        else:
            if encoder_hidden_states is None:
                raise ValueError(
                    "`encoder_hidden_states` must be passed to cross-attention when no cached "
                    "cross-attention keys/values are available for this layer."
                )
            # The source length is only needed when projecting the encoder states.
            src_len = encoder_hidden_states.shape[1]
            kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
            key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
            value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)

            if past_key_value is not None:
                # save all states to the cache
                key_layer, value_layer = past_key_value.cross_attention_cache.update(
                    key_layer, value_layer, self.layer_idx
                )
                # Flag this layer's cross-attention as cached so later calls take the reuse branch.
                past_key_value.is_updated[self.layer_idx] = True

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout.p,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        return attn_output, attn_weights
|
|
|
|
|
|
class DummyBertSelfOutput(nn.Module):
    """Project the attention output, apply dropout, then add the residual and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class DummyBertAttention(nn.Module):
    """Attention sub-layer: a self- or cross-attention core followed by ``DummyBertSelfOutput``."""

    def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        core_cls = DummyBertSelfAttention if not is_cross_attention else DummyBertCrossAttention
        self.self = core_cls(config, is_causal=is_causal, layer_idx=layer_idx)
        self.output = DummyBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the core and the output projection.

        Args:
            heads: head indices to prune; an empty collection is a no-op.
        """
        if len(heads) == 0:
            return

        core = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, core.num_attention_heads, core.attention_head_size, self.pruned_heads
        )

        # q/k/v are pruned along prune_linear_layer's default dim, the output projection along dim=1.
        core.query = prune_linear_layer(core.query, index)
        core.key = prune_linear_layer(core.key, index)
        core.value = prune_linear_layer(core.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping in sync with the pruned layers.
        core.num_attention_heads = core.num_attention_heads - len(heads)
        core.all_head_size = core.attention_head_size * core.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Run the attention core, then the residual output block.

        The core receives ``encoder_attention_mask`` when this is a cross-attention layer and
        ``attention_mask`` otherwise.

        Returns:
            ``(attention_output, attn_weights)``.
        """
        mask = encoder_attention_mask if self.is_cross_attention else attention_mask
        core_output, attn_weights = self.self(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=mask,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        return self.output(core_output, hidden_states), attn_weights
|
|
|
|
|
|
class DummyBertIntermediate(nn.Module):
    """Feed-forward expansion: ``Linear(hidden_size -> intermediate_size)`` then the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # `hidden_act` is either a name looked up in ACT2FN or already a callable.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``activation(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
|
class DummyBertOutput(nn.Module):
    """Feed-forward contraction back to ``hidden_size`` with dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
|
|
|
|
|
class DummyBertLayer(GradientCheckpointingLayer):
    """One transformer block: self-attention, optional cross-attention, then a chunked feed-forward."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which the feed-forward may be chunked.
        self.seq_len_dim = 1
        self.attention = DummyBertAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense in a decoder.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = DummyBertAttention(
                config,
                is_causal=False,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )
        self.intermediate = DummyBertIntermediate(config)
        self.output = DummyBertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """Apply the block to ``hidden_states`` and return the new hidden states.

        Cross-attention runs only for decoders that received ``encoder_hidden_states``.

        Raises:
            ValueError: if encoder states are given to a decoder layer built without cross-attention.
        """
        self_attention_output, _ = self.attention(
            hidden_states,
            attention_mask,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )

        use_cross = self.is_decoder and encoder_hidden_states is not None
        if not use_cross:
            attention_output = self_attention_output
        else:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            attention_output, _ = self.crossattention(
                self_attention_output,
                None,  # attention_mask (the cross-attention core reads encoder_attention_mask instead)
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value=past_key_value,
                **kwargs,
            )

        return apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

    def feed_forward_chunk(self, attention_output):
        """Feed-forward for one chunk: expand, then contract with a residual to ``attention_output``."""
        return self.output(self.intermediate(attention_output), attention_output)
|
|
|
|
|
|
class DummyBertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``DummyBertLayer`` blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        blocks = [DummyBertLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Feed ``hidden_states`` through every layer.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` holding the final hidden states and,
            when ``use_cache`` is truthy, the (in-place updated) ``past_key_values``.
        """
        for block in self.layer:
            hidden_states = block(
                hidden_states,
                attention_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        returned_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=returned_cache,
        )
|
|
|
|
|
|
class DummyBertPooler(nn.Module):
    """Reduce a sequence to one vector: ``tanh(dense(hidden state of the first token))``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool by selecting index 0 along dim 1, then apply dense + tanh."""
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
|
|
|
|
|
class DummyBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm, applied before the LM output projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # `hidden_act` is either a name looked up in ACT2FN or already a callable.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(activation(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
class DummyBertLMPredictionHead(nn.Module):
    """Language-modelling head: transform the hidden states, then project them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = DummyBertPredictionHeadTransform(config)

        # Bias-free projection whose weights are meant to match the input embeddings; the
        # output-only per-token bias lives in its own parameter and is attached below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is resized together with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-attach `self.bias` as the decoder bias (e.g. after the decoder was replaced).
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary logits ``decoder(transform(hidden_states))``."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
|
|
|
|
|
|
@auto_docstring
class DummyBertPreTrainedModel(PreTrainedModel):
    config_class = DummyBertConfig
    base_model_prefix = "dummy_bert"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Module classes whose outputs can be recorded as hidden_states / attentions / cross_attentions.
    _can_record_outputs = {
        "hidden_states": DummyBertLayer,
        "attentions": DummyBertSelfAttention,
        "cross_attentions": DummyBertCrossAttention,
    }

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Linear and Embedding weights are drawn from N(0, ``config.initializer_range``). Linear
        biases and the Embedding padding row are zeroed. LayerNorm is reset to weight 1 / bias 0,
        and the LM head's standalone bias is zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, DummyBertLMPredictionHead):
            module.bias.data.zero_()
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
)
class DummyBertModel(DummyBertPreTrainedModel):
    _no_split_modules = ["DummyBertEmbeddings", "DummyBertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embeddings = DummyBertEmbeddings(config)
        self.encoder = DummyBertEncoder(config)

        self.pooler = DummyBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # `output_attentions`, `output_hidden_states` and `return_dict` are not read in this body;
        # presumably they are handled by the `check_model_inputs` decorator.

        # Only decoders cache key/value states; encoders always run with use_cache=False.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            input_shape = input_ids.shape
        else:
            device = inputs_embeds.device
            input_shape = inputs_embeds.shape[:-1]

        seq_length = input_shape[1]
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            # Positions of the new tokens, continuing right after the cached ones.
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            input_shape=input_shape,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )

    def _create_attention_masks(
        self,
        input_shape,
        attention_mask,
        encoder_attention_mask,
        embedding_output,
        encoder_hidden_states,
        cache_position,
        past_key_values,
    ):
        """Convert the user-facing masks into the form the configured attention backend expects.

        Self-attention mask:
          * decoder with no mask or a 2-D mask -> ``create_causal_mask``. This also runs when no
            mask is given, because ``eager_attention_forward`` never reads ``is_causal`` and
            would otherwise let a decoder attend to future tokens;
          * non-decoder with a 2-D mask -> ``_update_full_mask``;
          * 3-D mask -> ``get_extended_attention_mask`` (rejected for flash / flex attention);
          * anything else (None for non-decoders, or an already 4-D mask) passes through unchanged.

        Cross-attention mask (only when given): 2-D -> ``_update_cross_attn_mask``; any other
        rank -> ``invert_attention_mask`` (rejected for flash / flex attention).

        Returns:
            ``(attention_mask, encoder_attention_mask)``.
        """
        missing_or_2d = attention_mask is None or attention_mask.dim() == 2
        if self.config.is_decoder and missing_or_2d:
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
        elif attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = self._update_full_mask(
                attention_mask,
                embedding_output,
            )
        elif attention_mask is not None and attention_mask.dim() == 3:
            if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
                raise ValueError(
                    "Passing attention mask with a 3D/4D shape does not work with type "
                    f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
                )
            attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        if encoder_attention_mask is not None:
            if encoder_attention_mask.dim() == 2:
                encoder_attention_mask = self._update_cross_attn_mask(
                    encoder_hidden_states,
                    encoder_attention_mask,
                    embedding_output.shape[:2],
                    embedding_output,
                )
            else:
                if "flash" in self.config._attn_implementation or self.config._attn_implementation == "flex_attention":
                    raise ValueError(
                        "Passing attention mask with a 3D/4D shape does not work with type "
                        f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
                    )
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        return attention_mask, encoder_attention_mask

    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        """Expand a 2-D (non-causal) self-attention mask for the configured backend.

        * flash: the 2-D mask is kept only if it contains a 0, otherwise it becomes None;
        * sdpa: ``_prepare_4d_attention_mask_for_sdpa``;
        * flex_attention: a block mask from ``make_flex_block_causal_mask(..., is_causal=False)``;
        * anything else (e.g. eager): ``_prepare_4d_attention_mask``.
        """
        if attention_mask is not None:
            if "flash" in self.config._attn_implementation:
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            elif self.config._attn_implementation == "flex_attention":
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        return attention_mask

    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        """Expand a 2-D cross-attention mask for the configured backend.

        The target length is taken from ``input_shape[-1]``. The mask is returned unchanged when
        either ``encoder_hidden_states`` or the mask itself is None.
        """
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if "flash" in self.config._attn_implementation:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                if isinstance(encoder_attention_mask, torch.Tensor):
                    encoder_attention_mask = make_flex_block_causal_mask(
                        encoder_attention_mask,
                        query_length=input_shape[-1],
                        is_causal=False,
                    )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
|