🚨 [DistilBert] Refactor Attention (#41163)

* refactor

* allow pos ids for flattened sequences
Anton Vlasjuk authored on 2025-10-02 17:50:48 +02:00, committed by GitHub
parent e54defcfc2 · commit da3c7d1d36
2 changed files with 175 additions and 441 deletions


@@ -18,8 +18,7 @@ PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://g
part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""
import math
from typing import Optional, Union
from typing import Callable, Optional, Union
import numpy as np
import torch
@@ -29,8 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import get_activation
from ...configuration_utils import PretrainedConfig
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
@@ -40,21 +38,25 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import (
TransformersKwargs,
auto_docstring,
is_torch_flex_attn_available,
logging,
)
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_distilbert import DistilBertConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -94,31 +96,26 @@ class Embeddings(nn.Module):
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Parameters:
input_ids (torch.Tensor):
torch.tensor(bs, max_seq_length) The token ids to embed.
input_embeds (*optional*, torch.Tensor):
The pre-computed word embeddings. Can only be passed if the input ids are `None`.
Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
embeddings)
"""
def forward(
self,
input_ids: torch.Tensor,
input_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if input_ids is not None:
input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
seq_length = input_embeds.size(1)
# Setting the position-ids to the registered buffer in constructor, it helps
# when tracing the model without passing position-ids, solves
# issues similar to issue #5664
if hasattr(self, "position_ids"):
position_ids = self.position_ids[:, :seq_length]
else:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
if position_ids is None:
# Setting the position-ids to the registered buffer in constructor, it helps
# when tracing the model without passing position-ids, solves
# issues similar to issue #5664
if hasattr(self, "position_ids"):
position_ids = self.position_ids[:, :seq_length]
else:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
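Note: the optional `position_ids` argument is what enables the "flattened sequences" part of the commit message — callers can restart position indices per packed segment instead of relying on the registered `0..seq_len-1` buffer. A minimal sketch, assuming a transformers build that includes this change and the `distilbert-base-uncased` checkpoint (both illustrative choices, not part of the diff):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

segments = ["hello world", "how are you"]
ids = [tok(s, add_special_tokens=False)["input_ids"] for s in segments]

# Flatten both segments into a single batch row ...
flat_ids = torch.tensor([sum(ids, [])])
# ... and restart the position counter at the start of each segment.
position_ids = torch.tensor([[p for seg in ids for p in range(len(seg))]])

out = model(input_ids=flat_ids, position_ids=position_ids)
print(out.last_hidden_state.shape)  # (1, total_len, dim)
```

For genuine sequence packing one would also want a mask (or a varlen-aware backend such as flash attention) that blocks cross-segment attention; the sketch only shows how the positions are threaded through to the embeddings.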
@@ -128,15 +125,42 @@ class Embeddings(nn.Module):
return embeddings
class MultiHeadSelfAttention(nn.Module):
# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
**kwargs,
):
if scaling is None:
scaling = query.size(-1) ** -0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
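For reference, a standalone shape sketch of the new eager path (not part of the diff; the import path assumes the usual `transformers.models.distilbert.modeling_distilbert` module layout):

```python
import torch
import torch.nn as nn

# Assumes the standard layout of the modeling file touched by this commit.
from transformers.models.distilbert.modeling_distilbert import eager_attention_forward

bs, n_heads, seq_len, head_dim = 2, 12, 16, 64
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
v = torch.randn(bs, n_heads, seq_len, head_dim)

# Additive mask: 0.0 where attention is allowed, a large negative bias elsewhere.
mask = torch.zeros(bs, 1, seq_len, seq_len)

module = nn.Module()  # only `.training` is read, to decide whether dropout applies
attn_output, attn_weights = eager_attention_forward(module, q, k, v, mask, dropout=0.1)

print(attn_output.shape)   # (bs, seq_len, n_heads, head_dim) -- heads moved back next to head_dim
print(attn_weights.shape)  # (bs, n_heads, seq_len, seq_len)
```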
class DistilBertSelfAttention(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.n_heads = config.n_heads
self.dim = config.dim
self.dropout = nn.Dropout(p=config.attention_dropout)
self.is_causal = False
self.attention_head_size = self.dim // self.n_heads
self.scaling = self.attention_head_size**-0.5
# Have an even number of multi heads that divide the dimensions
if self.dim % self.n_heads != 0:
@@ -148,8 +172,10 @@ class MultiHeadSelfAttention(nn.Module):
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.dropout = nn.Dropout(p=config.attention_dropout)
self.is_causal = False
self.pruned_heads: set[int] = set()
self.attention_head_size = self.dim // self.n_heads
def prune_heads(self, heads: list[int]):
if len(heads) == 0:
@@ -169,231 +195,35 @@ class MultiHeadSelfAttention(nn.Module):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
output_attentions: bool = False,
) -> tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
# assert key.size() == value.size()
# get all proj
query_layer = self.q_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
key_layer = self.k_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.v_lin(hidden_states).view(*hidden_shape).transpose(1, 2)
dim_per_head = self.dim // self.n_heads
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mask_reshp = (bs, 1, 1, k_length)
def shape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
def unshape(x: torch.Tensor) -> torch.Tensor:
"""group heads"""
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
) # (bs, n_heads, q_length, k_length)
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)
if output_attentions:
return (context, weights)
else:
return (context,)
class DistilBertFlashAttention2(MultiHeadSelfAttention):
"""
DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
API of flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
output_attentions: bool = False,
) -> tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
batch_size, q_length, dim = query.size()
dim_per_head = self.dim // self.n_heads
def reshape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(batch_size, -1, self.n_heads, dim_per_head)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query_states = reshape(self.q_lin(query))
key_states = reshape(self.k_lin(key))
value_states = reshape(self.v_lin(value))
attn_dropout = self.config.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if query_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = (
torch.get_autocast_dtype(device_type)
if hasattr(torch, "get_autocast_dtype")
else torch.get_autocast_gpu_dtype()
)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_lin.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_weights = _flash_attention_forward(
query_states,
key_states,
value_states,
mask,
q_length,
dropout=attn_dropout,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
**kwargs,
)
attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
attn_output = self.out_lin(attn_weights_reshaped)
if output_attentions:
return (attn_output, attn_weights)
else:
return (attn_output,)
class DistilBertSdpaAttention(MultiHeadSelfAttention):
def __init__(self, config: PretrainedConfig):
super().__init__(config=config)
self.dropout_prob = config.attention_dropout
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
output_attentions: bool = False,
) -> tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
if output_attentions:
logger.warning_once(
"DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" `output_attentions=True`. Falling back to the manual attention implementation, but specifying"
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
' removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
query,
key,
value,
mask,
output_attentions,
)
batch_size, _, _ = query.size()
dim_per_head = self.dim // self.n_heads
def shape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)
def unshape(x: torch.Tensor) -> torch.Tensor:
"""group heads"""
return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=False,
)
attn_output = unshape(attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_lin(attn_output)
return (attn_output,)
return attn_output, attn_weights
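With this change the per-backend subclasses (`MultiHeadSelfAttention`, `DistilBertFlashAttention2`, `DistilBertSdpaAttention`) collapse into a single module that resolves its kernel from `ALL_ATTENTION_FUNCTIONS`. A hedged sketch of how the backend chosen at load time reaches that lookup, using `distilbert-base-uncased` purely as an example checkpoint:

```python
from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# `attn_implementation` is the public knob; the module then resolves the kernel
# from the registry exactly as in the forward above.
model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation="sdpa")

impl = model.config._attn_implementation                    # "sdpa"
attention_interface = ALL_ATTENTION_FUNCTIONS[impl] if impl != "eager" else None
print(impl, attention_interface)
```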
class FFN(nn.Module):
@@ -417,13 +247,6 @@ class FFN(nn.Module):
return x
DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2,
"sdpa": DistilBertSdpaAttention,
}
class TransformerBlock(GradientCheckpointingLayer):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -432,7 +255,7 @@ class TransformerBlock(GradientCheckpointingLayer):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = DistilBertSelfAttention(config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
@@ -440,44 +263,23 @@ class TransformerBlock(GradientCheckpointingLayer):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, ...]:
"""
Parameters:
x: torch.tensor(bs, seq_length, dim)
attn_mask: torch.tensor(bs, seq_length)
Returns:
sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
"""
# Self-Attention
sa_output = self.attention(
query=x,
key=x,
value=x,
mask=attn_mask,
output_attentions=output_attentions,
attention_output, _ = self.attention(
hidden_states,
attention_mask=attention_mask,
**kwargs,
)
if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
if type(sa_output) is not tuple:
raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
attention_output = self.sa_layer_norm(attention_output + hidden_states)
# Feed Forward Network
ffn_output = self.ffn(sa_output) # (bs, seq_length, dim)
ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
ffn_output = self.ffn(attention_output)
ffn_output = self.output_layer_norm(ffn_output + attention_output)
output = (ffn_output,)
if output_attentions:
output = (sa_weights,) + output
return output
return ffn_output
class Transformer(nn.Module):
@@ -489,61 +291,18 @@ class Transformer(nn.Module):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]: # docstyle-ignore
"""
Parameters:
x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.
Returns:
hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)
layer all_hidden_states: tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if output_hidden_states=True
all_attentions: tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
"""
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_state = x
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module(
hidden_state,
attn_mask,
output_attentions,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[BaseModelOutput]:
for layer_module in self.layer:
hidden_states = layer_module(
hidden_states,
attention_mask,
**kwargs,
)
hidden_state = layer_outputs[-1]
if output_attentions:
if len(layer_outputs) != 2:
raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}")
attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,)
else:
if len(layer_outputs) != 1:
raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}")
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
)
return BaseModelOutput(last_hidden_state=hidden_states)
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
@@ -554,6 +313,12 @@ class DistilBertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": TransformerBlock,
"attentions": DistilBertSelfAttention,
}
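With the explicit collection loop gone from `Transformer.forward`, per-layer hidden states and attention weights are recorded through `_can_record_outputs` together with the `@check_model_inputs` decorator on `DistilBertModel.forward`. From the caller's side the flags behave as before; a small sketch, assuming the `distilbert-base-uncased` checkpoint and the eager backend (SDPA does not return attention weights):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation="eager")

inputs = tok("DistilBERT attention refactor", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True, output_attentions=True)

print(len(out.hidden_states))    # one entry per recorded layer (bookkeeping handled by the recording hooks)
print(out.attentions[0].shape)   # (bs, n_heads, seq_len, seq_len)
```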
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
@@ -647,15 +412,15 @@ class DistilBertModel(DistilBertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.transformer.layer[layer].attention.prune_heads(heads)
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
@@ -670,45 +435,43 @@ class DistilBertModel(DistilBertPreTrainedModel):
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)
attention_mask = self._update_full_mask(
attention_mask,
embeddings,
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
if self.config._attn_implementation == "sdpa" and not output_attentions:
attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embeddings.dtype, tgt_len=input_shape[1]
)
return self.transformer(
x=embeddings,
attn_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=embeddings,
attention_mask=attention_mask,
**kwargs,
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if "flash" in self.config._attn_implementation:
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
return attention_mask
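A shape sketch for the non-flash branches of `_update_full_mask`: the 2D padding mask (1 = keep, 0 = pad) becomes a 4D additive bias broadcast over heads (illustrative values only):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

attention_mask = torch.tensor([[1, 1, 1, 0]])                    # (bs, seq_len)
bias = _prepare_4d_attention_mask(attention_mask, torch.float32)

print(bias.shape)     # (1, 1, 4, 4) -> [bsz, 1, tgt_seq_len, src_seq_len]
print(bias[0, 0, 0])  # 0.0 for real tokens, dtype-min for the padded position
```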
@auto_docstring(
custom_intro="""
@@ -759,6 +522,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings: nn.Module):
self.vocab_projector = new_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -766,9 +530,8 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[MaskedLMOutput, tuple[torch.Tensor, ...]]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
@@ -787,15 +550,13 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids,
return_dict=True,
**kwargs,
)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
@@ -807,10 +568,6 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
if labels is not None:
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (prediction_logits,) + dlbrt_output[1:]
return ((mlm_loss,) + output) if mlm_loss is not None else output
return MaskedLMOutput(
loss=mlm_loss,
logits=prediction_logits,
@@ -859,6 +616,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -866,9 +624,8 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[SequenceClassifierOutput, tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -876,15 +633,13 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
distilbert_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids,
return_dict=True,
**kwargs,
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
@@ -916,10 +671,6 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
@@ -963,6 +714,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -971,9 +723,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[QuestionAnsweringModelOutput, tuple[torch.Tensor, ...]]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`):
@@ -988,15 +739,13 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
distilbert_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids,
return_dict=True,
**kwargs,
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
@@ -1023,10 +772,6 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + distilbert_output[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
@@ -1069,6 +814,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -1076,23 +822,20 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[TokenClassifierOutput, tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.distilbert(
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids,
return_dict=True,
**kwargs,
)
sequence_output = outputs[0]
@@ -1105,10 +848,6 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
@@ -1150,6 +889,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -1157,9 +897,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[MultipleChoiceModelOutput, tuple[torch.Tensor, ...]]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1199,7 +938,6 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -1214,9 +952,9 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids,
return_dict=True,
**kwargs,
)
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
@@ -1233,10 +971,6 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,


@@ -223,7 +223,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
if is_torch_available()
else {}
)
fx_compatible = True
fx_compatible = False # won't be maintained
test_pruning = True
test_resize_embeddings = True
test_resize_position_embeddings = True
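As a closing sanity check, a hedged sketch comparing two attention backends on the same checkpoint; hidden states should agree within numerical tolerance (checkpoint name and tolerance are illustrative choices, not part of this PR):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tok("attention backends should agree", return_tensors="pt")

outputs = {}
for impl in ("eager", "sdpa"):
    model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation=impl).eval()
    with torch.no_grad():
        outputs[impl] = model(**inputs).last_hidden_state

# Expect True up to small numerical differences between kernels.
print(torch.allclose(outputs["eager"], outputs["sdpa"], atol=1e-5))
```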