Compare commits


2 Commits

Author SHA1 Message Date
58df5841ec Update modeling_gemma2.py 2024-08-01 16:52:26 +02:00
15fcb7ea58 Fix attention_mask shape indexing 2024-08-01 16:40:13 +02:00
    Here, the attention mask is a 4D tensor, so `.shape[-1]` selects the wrong index.
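
For reference, the 4D attention masks in this code path are broadcastable additive tensors laid out as `[batch, heads, query_len, key_len]` (with `heads` kept at 1). A minimal sketch, not repository code, of which dimension each index mentioned in the commit message refers to:

```python
import torch

# Minimal sketch (not repository code): a broadcastable additive 4D mask
# shaped [batch, heads, query_len, key_len], as consumed by SDPA attention.
batch, q_len, kv_len = 2, 1, 16          # q_len == 1 while decoding one token
mask = torch.zeros(batch, 1, q_len, kv_len)

print(mask.shape[-1])   # 16 -> key/value length (last dimension)
print(mask.shape[1])    # 1  -> broadcast head dimension
print(mask.shape[-2])   # 1  -> query length
```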

modeling_gemma2.py

@@ -579,8 +579,8 @@ class Gemma2DecoderLayer(nn.Module):
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                if attention_mask.shape[1] <= 1:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]

         residual = hidden_states
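
To make the sliding-window logic in this hunk concrete, here is a standalone sketch with assumed shapes (not the repository's code) of how `torch.tril` with a negative `diagonal` blanks out keys that fall outside the window:

```python
import torch

# Standalone sketch of the sliding-window masking above (assumed shapes,
# not repository code). attention_mask is an additive 4D mask
# [batch, 1, q_len, kv_len]: 0 where attention is allowed, min_dtype where not.
batch, q_len, kv_len, sliding_window = 1, 8, 8, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Plain causal mask: 0 on/below the diagonal, min_dtype above it.
causal = torch.triu(torch.full((q_len, kv_len), min_dtype, dtype=dtype), diagonal=1)
attention_mask = causal[None, None, :, :].expand(batch, 1, q_len, kv_len).clone()

# True for keys at least `sliding_window` positions behind the query,
# i.e. positions that should fall out of the window.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Each query now attends to at most `sliding_window` keys.
print((attention_mask[0, 0] == 0).sum(dim=-1))  # tensor([1, 2, 3, 4, 4, 4, 4, 4])
```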
@@ -897,7 +897,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
         if isinstance(past_key_values, HybridCache):
             target_length = past_key_values.get_max_length()
         else:
-            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
+            target_length = attention_mask.shape[1] if attention_mask is not None else input_tensor.shape[1]

         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
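
As a rough illustration of what this step produces, the sketch below expands a 2D padding mask into an additive 4D causal mask whose key dimension is `target_length`. The helper name `build_4d_causal_mask`, its signature, and the shapes are hypothetical and are not the library's `_prepare_4d_causal_attention_mask_with_cache_position` implementation:

```python
import torch

def build_4d_causal_mask(attention_mask_2d, sequence_length, target_length,
                         cache_position, dtype=torch.float32):
    """Hypothetical sketch (not the library's helper): expand a 2D padding mask
    of shape [batch, target_length] into an additive 4D mask of shape
    [batch, 1, sequence_length, target_length]."""
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask_2d.shape[0]

    # Causal part: the query at absolute position cache_position[i] may attend
    # to key positions 0..cache_position[i].
    key_positions = torch.arange(target_length)
    causal = key_positions[None, :] > cache_position[:, None]   # [sequence_length, target_length]
    mask = torch.zeros(sequence_length, target_length, dtype=dtype)
    mask = mask.masked_fill(causal, min_dtype)

    # Padding part: also blank out keys the 2D mask marks as padding (0).
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    padding = attention_mask_2d[:, None, None, :] == 0          # [batch, 1, 1, target_length]
    return mask.masked_fill(padding, min_dtype)

# Decoding a single new token at position 7 against a cache of length 8:
mask_2d = torch.ones(2, 8, dtype=torch.long)
out = build_4d_causal_mask(mask_2d, sequence_length=1, target_length=8,
                           cache_position=torch.tensor([7]))
print(out.shape)  # torch.Size([2, 1, 1, 8])
```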