Compare commits

...

1 Commit

Author  SHA1        Message          Date
        488c084e86  use config attr  2025-04-28 18:26:34 +02:00


@@ -819,7 +819,6 @@ class GPT2Model(GPT2PreTrainedModel):
         self.model_parallel = False
         self.device_map = None
         self.gradient_checkpointing = False
-        self._attn_implementation = config._attn_implementation

         # Initialize weights and apply final processing
         self.post_init()
@@ -978,7 +977,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
+        _use_sdpa = self.config._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
         if self.config.add_cross_attention and encoder_hidden_states is not None:
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
@@ -988,7 +987,7 @@ class GPT2Model(GPT2PreTrainedModel):
                 encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                     mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                 )
-            elif not self._attn_implementation == "flash_attention_2":
+            elif not self.config._attn_implementation == "flash_attention_2":
                 encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_attention_mask = None
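
The commit drops the `self._attn_implementation` snapshot taken in `__init__` and reads `self.config._attn_implementation` at forward time instead. A minimal, self-contained sketch of the difference, using hypothetical `ToyConfig`/`ToyModel` names rather than the transformers classes:

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the config attribute the diff reads.
    _attn_implementation: str = "eager"


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config
        # Old pattern (removed by this commit): snapshot the value at init time.
        self._attn_implementation = config._attn_implementation

    def forward_old(self) -> str:
        # Returns the stale copy taken in __init__.
        return self._attn_implementation

    def forward_new(self) -> str:
        # New pattern: always consult the live config attribute.
        return self.config._attn_implementation


config = ToyConfig()
model = ToyModel(config)
config._attn_implementation = "sdpa"  # switch implementation after construction

print(model.forward_old())  # "eager" -- cached copy does not see the change
print(model.forward_new())  # "sdpa"  -- config lookup reflects the update
```

With the snapshot removed, changing `config._attn_implementation` after the model is built is picked up on the next forward pass, which is what the `_use_sdpa` and `flash_attention_2` checks in the diff rely on.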