Compare commits

...

4 Commits

SHA1 Message Date
58c831bd12 Merge commit '10bc8b4e' into dynamic_length_on_b6eb708b 2024-05-20 12:51:23 +02:00
2fc9e12ef5 Merge commit 'cbf98115' into dynamic_length_on_b6eb708b (conflicts: src/transformers/models/gemma/modeling_gemma.py) 2024-05-20 12:50:04 +02:00
cbf98115fa 0ae789e0 + dynamic length in static cache 2024-05-20 12:41:20 +02:00
10bc8b4e02 fix EosTokenCriteria 2024-05-20 10:43:14 +02:00

src/transformers/models/gemma/modeling_gemma.py

@@ -110,7 +110,7 @@ class GemmaRotaryEmbedding(nn.Module):
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
self.inv_freq = self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
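Note on the hunk above: `torch.Tensor.to` is out-of-place, so the old bare call never actually moved `inv_freq`; only the new assignment does. A minimal standalone sketch of the difference (not taken from the model code):

```python
import torch

inv_freq = torch.arange(4, dtype=torch.float32)   # created on the CPU

device = "cuda" if torch.cuda.is_available() else "cpu"

inv_freq.to(device)              # no-op: the moved copy is discarded
print(inv_freq.device)           # still cpu (when a GPU is present)

inv_freq = inv_freq.to(device)   # the fix: keep the returned tensor
print(inv_freq.device)           # now matches `device`
```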
@@ -517,6 +517,7 @@ class GemmaSdpaAttention(GemmaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -570,6 +571,11 @@ class GemmaSdpaAttention(GemmaAttention):
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False
if _length > 0:
key_states = key_states[:, :, :_length, :]
value_states = value_states[:, :, :_length, :]
causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
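The `_length > 0` branch above is the heart of the change: with a static cache the key/value buffers (and the matching attention-mask columns) are pre-allocated to the maximum length, and slicing them down to the number of slots actually filled keeps SDPA from attending over empty positions. A self-contained toy sketch, with shapes and names invented for illustration:

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch=1, heads=2, head_dim=8, static cache pre-allocated to max_len=16.
bsz, n_heads, head_dim, max_len = 1, 2, 8, 16
q_len, valid_len = 1, 5            # one new query token, 5 cache slots actually filled

query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.zeros(bsz, n_heads, max_len, head_dim)   # static cache buffers
value_states = torch.zeros(bsz, n_heads, max_len, head_dim)
key_states[:, :, :valid_len] = torch.randn(bsz, n_heads, valid_len, head_dim)
value_states[:, :, :valid_len] = torch.randn(bsz, n_heads, valid_len, head_dim)

# Mask shaped [bsz, 1, q_len, max_len]; unfilled slots are masked with -inf.
causal_mask = torch.full((bsz, 1, q_len, max_len), float("-inf"))
causal_mask[..., :valid_len] = 0.0

# The `_length > 0` branch: drop the unfilled tail of the static cache
# (and the matching mask columns) so SDPA only sees `valid_len` key/value slots.
key_states = key_states[:, :, :valid_len, :]
value_states = value_states[:, :, :valid_len, :]
causal_mask = causal_mask[:, :, :, :valid_len]

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([1, 2, 1, 8])
```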
@@ -615,6 +621,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -643,6 +650,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
)
hidden_states = residual + hidden_states
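This hunk and the following ones are pure plumbing: each module's forward accepts the new `_length` keyword and hands it one level down until it reaches the attention layer. A toy sketch of the pattern (module names are made up):

```python
import torch
from torch import nn

class Inner(nn.Module):
    def forward(self, hidden_states, _length: int = 0):
        # the only module that actually uses _length in this toy version
        return hidden_states[:, :_length] if _length > 0 else hidden_states

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, hidden_states, _length: int = 0):
        # accept the kwarg and pass it straight through, like the
        # decoder-layer and model-level hunks below
        return self.inner(hidden_states, _length=_length)

out = Outer()(torch.randn(1, 8, 4), _length=5)
print(out.shape)  # torch.Size([1, 5, 4])
```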
@@ -829,6 +837,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -907,6 +916,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
)
hidden_states = layer_outputs[0]
@@ -1062,6 +1072,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1106,6 +1117,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
_length=_length,
)
hidden_states = outputs[0]
@@ -1144,6 +1156,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
inputs_embeds=None,
cache_position=None,
use_cache=True,
_length=None,
**kwargs,
):
past_length = 0
@@ -1210,6 +1223,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"_length": int(cache_position[-1]) + 1,
}
)
return model_inputs
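The `_length` value handed to the model comes from `cache_position`: the last position being written plus one equals the number of cache slots filled after this step. A small sketch of that arithmetic (values chosen for illustration):

```python
import torch

# During prefill, cache_position covers every prompt position; during decoding it
# holds just the index of the single new token. Either way, the number of cache
# slots filled after this step is the last position + 1.
prompt_len = 7
prefill_cache_position = torch.arange(prompt_len)   # tensor([0, 1, ..., 6])
print(int(prefill_cache_position[-1]) + 1)          # 7 -> slice the cache to 7 slots

decode_cache_position = torch.tensor([prompt_len])  # next step writes slot 7
print(int(decode_cache_position[-1]) + 1)           # 8 -> slice the cache to 8 slots
```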