mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
🚨 [DistilBert
] Refactor Attention (#41163)
* refactor * allow pos ids for flattened sequences
This commit is contained in:
@ -18,8 +18,7 @@ PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://g
|
||||
part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -29,8 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import get_activation
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@ -40,21 +38,25 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import (
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import (
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils.generic import can_return_tuple, check_model_inputs
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
if is_torch_flex_attn_available():
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -94,31 +96,26 @@ class Embeddings(nn.Module):
|
||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||
)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Parameters:
|
||||
input_ids (torch.Tensor):
|
||||
torch.tensor(bs, max_seq_length) The token ids to embed.
|
||||
input_embeds (*optional*, torch.Tensor):
|
||||
The pre-computed word embeddings. Can only be passed if the input ids are `None`.
|
||||
|
||||
|
||||
Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
|
||||
embeddings)
|
||||
"""
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
input_embeds: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
|
||||
|
||||
seq_length = input_embeds.size(1)
|
||||
|
||||
# Setting the position-ids to the registered buffer in constructor, it helps
|
||||
# when tracing the model without passing position-ids, solves
|
||||
# issues similar to issue #5664
|
||||
if hasattr(self, "position_ids"):
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
else:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
|
||||
if position_ids is None:
|
||||
# Setting the position-ids to the registered buffer in constructor, it helps
|
||||
# when tracing the model without passing position-ids, solves
|
||||
# issues similar to issue #5664
|
||||
if hasattr(self, "position_ids"):
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
else:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
|
||||
|
||||
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
|
||||
|
||||
@ -128,15 +125,42 @@ class Embeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class MultiHeadSelfAttention(nn.Module):
|
||||
# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DistilBertSelfAttention(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.n_heads = config.n_heads
|
||||
self.dim = config.dim
|
||||
self.dropout = nn.Dropout(p=config.attention_dropout)
|
||||
self.is_causal = False
|
||||
self.attention_head_size = self.dim // self.n_heads
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
# Have an even number of multi heads that divide the dimensions
|
||||
if self.dim % self.n_heads != 0:
|
||||
@ -148,8 +172,10 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||
|
||||
self.dropout = nn.Dropout(p=config.attention_dropout)
|
||||
self.is_causal = False
|
||||
|
||||
self.pruned_heads: set[int] = set()
|
||||
self.attention_head_size = self.dim // self.n_heads
|
||||
|
||||
def prune_heads(self, heads: list[int]):
|
||||
if len(heads) == 0:
|
||||
@ -169,231 +195,35 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Parameters:
|
||||
query: torch.tensor(bs, seq_length, dim)
|
||||
key: torch.tensor(bs, seq_length, dim)
|
||||
value: torch.tensor(bs, seq_length, dim)
|
||||
mask: torch.tensor(bs, seq_length)
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
Returns:
|
||||
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
|
||||
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
|
||||
"""
|
||||
bs, q_length, dim = query.size()
|
||||
k_length = key.size(1)
|
||||
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
|
||||
# assert key.size() == value.size()
|
||||
# get all proj
|
||||
query_layer = self.q_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
key_layer = self.k_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
value_layer = self.v_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mask_reshp = (bs, 1, 1, k_length)
|
||||
|
||||
def shape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""separate heads"""
|
||||
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
|
||||
|
||||
def unshape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""group heads"""
|
||||
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
|
||||
|
||||
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
|
||||
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
|
||||
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
|
||||
|
||||
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
|
||||
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
|
||||
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
|
||||
scores = scores.masked_fill(
|
||||
mask, torch.tensor(torch.finfo(scores.dtype).min)
|
||||
) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
|
||||
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
|
||||
context = unshape(context) # (bs, q_length, dim)
|
||||
context = self.out_lin(context) # (bs, q_length, dim)
|
||||
|
||||
if output_attentions:
|
||||
return (context, weights)
|
||||
else:
|
||||
return (context,)
|
||||
|
||||
|
||||
class DistilBertFlashAttention2(MultiHeadSelfAttention):
|
||||
"""
|
||||
DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
|
||||
stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
|
||||
API of flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Parameters:
|
||||
query: torch.tensor(bs, seq_length, dim)
|
||||
key: torch.tensor(bs, seq_length, dim)
|
||||
value: torch.tensor(bs, seq_length, dim)
|
||||
mask: torch.tensor(bs, seq_length)
|
||||
|
||||
Returns:
|
||||
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
|
||||
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
|
||||
"""
|
||||
batch_size, q_length, dim = query.size()
|
||||
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
def reshape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""separate heads"""
|
||||
return x.view(batch_size, -1, self.n_heads, dim_per_head)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
query_states = reshape(self.q_lin(query))
|
||||
key_states = reshape(self.k_lin(key))
|
||||
value_states = reshape(self.v_lin(value))
|
||||
|
||||
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
|
||||
if query_states.dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = (
|
||||
torch.get_autocast_dtype(device_type)
|
||||
if hasattr(torch, "get_autocast_dtype")
|
||||
else torch.get_autocast_gpu_dtype()
|
||||
)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_lin.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_weights = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
mask,
|
||||
q_length,
|
||||
dropout=attn_dropout,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
|
||||
attn_output = self.out_lin(attn_weights_reshaped)
|
||||
|
||||
if output_attentions:
|
||||
return (attn_output, attn_weights)
|
||||
else:
|
||||
return (attn_output,)
|
||||
|
||||
|
||||
class DistilBertSdpaAttention(MultiHeadSelfAttention):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__(config=config)
|
||||
self.dropout_prob = config.attention_dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Parameters:
|
||||
query: torch.tensor(bs, seq_length, dim)
|
||||
key: torch.tensor(bs, seq_length, dim)
|
||||
value: torch.tensor(bs, seq_length, dim)
|
||||
mask: torch.tensor(bs, seq_length)
|
||||
|
||||
Returns:
|
||||
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
|
||||
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning_once(
|
||||
"DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
|
||||
" `output_attentions=True`. Falling back to the manual attention implementation, but specifying"
|
||||
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
|
||||
' removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
batch_size, _, _ = query.size()
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
def shape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""separate heads"""
|
||||
return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)
|
||||
|
||||
def unshape(x: torch.Tensor) -> torch.Tensor:
|
||||
"""group heads"""
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)
|
||||
|
||||
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
|
||||
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
|
||||
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=mask,
|
||||
dropout_p=self.dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
attn_output = unshape(attn_output)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_lin(attn_output)
|
||||
|
||||
return (attn_output,)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
@ -417,13 +247,6 @@ class FFN(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
DISTILBERT_ATTENTION_CLASSES = {
|
||||
"eager": MultiHeadSelfAttention,
|
||||
"flash_attention_2": DistilBertFlashAttention2,
|
||||
"sdpa": DistilBertSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class TransformerBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
@ -432,7 +255,7 @@ class TransformerBlock(GradientCheckpointingLayer):
|
||||
if config.dim % config.n_heads != 0:
|
||||
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
|
||||
|
||||
self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = DistilBertSelfAttention(config)
|
||||
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
|
||||
|
||||
self.ffn = FFN(config)
|
||||
@ -440,44 +263,23 @@ class TransformerBlock(GradientCheckpointingLayer):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Parameters:
|
||||
x: torch.tensor(bs, seq_length, dim)
|
||||
attn_mask: torch.tensor(bs, seq_length)
|
||||
|
||||
Returns:
|
||||
sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
|
||||
torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
|
||||
"""
|
||||
# Self-Attention
|
||||
sa_output = self.attention(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
mask=attn_mask,
|
||||
output_attentions=output_attentions,
|
||||
attention_output, _ = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
if output_attentions:
|
||||
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
||||
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
|
||||
if type(sa_output) is not tuple:
|
||||
raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")
|
||||
|
||||
sa_output = sa_output[0]
|
||||
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
|
||||
attention_output = self.sa_layer_norm(attention_output + hidden_states)
|
||||
|
||||
# Feed Forward Network
|
||||
ffn_output = self.ffn(sa_output) # (bs, seq_length, dim)
|
||||
ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
|
||||
ffn_output = self.ffn(attention_output)
|
||||
ffn_output = self.output_layer_norm(ffn_output + attention_output)
|
||||
|
||||
output = (ffn_output,)
|
||||
if output_attentions:
|
||||
output = (sa_weights,) + output
|
||||
return output
|
||||
return ffn_output
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
@ -489,61 +291,18 @@ class Transformer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]: # docstyle-ignore
|
||||
"""
|
||||
Parameters:
|
||||
x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
|
||||
attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.
|
||||
|
||||
Returns:
|
||||
hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)
|
||||
layer all_hidden_states: tuple[torch.tensor(bs, seq_length, dim)]
|
||||
Tuple of length n_layers with the hidden states from each layer.
|
||||
Optional: only if output_hidden_states=True
|
||||
all_attentions: tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
|
||||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if output_attentions=True
|
||||
"""
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_state = x
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_state,
|
||||
attn_mask,
|
||||
output_attentions,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[BaseModelOutput]:
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_state = layer_outputs[-1]
|
||||
|
||||
if output_attentions:
|
||||
if len(layer_outputs) != 2:
|
||||
raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}")
|
||||
|
||||
attentions = layer_outputs[0]
|
||||
all_attentions = all_attentions + (attentions,)
|
||||
else:
|
||||
if len(layer_outputs) != 1:
|
||||
raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}")
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
|
||||
@ -554,6 +313,12 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": TransformerBlock,
|
||||
"attentions": DistilBertSelfAttention,
|
||||
}
|
||||
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights."""
|
||||
@ -647,15 +412,15 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.transformer.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@check_model_inputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
|
||||
@ -670,45 +435,43 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)
|
||||
|
||||
attention_mask = self._update_full_mask(
|
||||
attention_mask,
|
||||
embeddings,
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
|
||||
|
||||
if self.config._attn_implementation == "sdpa" and not output_attentions:
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
attention_mask, embeddings.dtype, tgt_len=input_shape[1]
|
||||
)
|
||||
|
||||
return self.transformer(
|
||||
x=embeddings,
|
||||
attn_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, None],
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
if attention_mask is not None:
|
||||
if "flash" in self.config._attn_implementation:
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -759,6 +522,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings: nn.Module):
|
||||
self.vocab_projector = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -766,9 +530,8 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[MaskedLMOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
|
||||
@ -787,15 +550,13 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
dlbrt_output = self.distilbert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
|
||||
@ -807,10 +568,6 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
if labels is not None:
|
||||
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_logits,) + dlbrt_output[1:]
|
||||
return ((mlm_loss,) + output) if mlm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=mlm_loss,
|
||||
logits=prediction_logits,
|
||||
@ -859,6 +616,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
"""
|
||||
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -866,9 +624,8 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[SequenceClassifierOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@ -876,15 +633,13 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
@ -916,10 +671,6 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + distilbert_output[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -963,6 +714,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
||||
"""
|
||||
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -971,9 +723,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[QuestionAnsweringModelOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
|
||||
@ -988,15 +739,13 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
|
||||
|
||||
@ -1023,10 +772,6 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + distilbert_output[1:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
@ -1069,6 +814,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
|
||||
"""
|
||||
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1076,23 +822,20 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[TokenClassifierOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.distilbert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1105,10 +848,6 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1150,6 +889,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
||||
"""
|
||||
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1157,9 +897,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[MultipleChoiceModelOutput, tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
|
||||
@ -1199,7 +938,6 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
@ -1214,9 +952,9 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
|
||||
@ -1233,10 +971,6 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return MultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
|
@ -223,7 +223,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
fx_compatible = False # won't be maintained
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
test_resize_position_embeddings = True
|
||||
|
Reference in New Issue
Block a user