Compare commits

...

3 Commits

Author SHA1 Message Date
ed2f65a9a4 Merge commit '10bc8b4e' into dynamic_length_on_0ae789e0 2024-05-20 12:42:32 +02:00
cbf98115fa 0ae789e0 + dynamic length in static cache 2024-05-20 12:41:20 +02:00
10bc8b4e02 fix EosTokenCriteria 2024-05-20 10:43:14 +02:00
2 changed files with 19 additions and 6 deletions

View File

@ -481,6 +481,7 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
@ -492,7 +493,7 @@ class EosTokenCriteria(StoppingCriteria):
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
return is_done

View File

@ -104,15 +104,13 @@ class GemmaRotaryEmbedding(nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
self.inv_freq = self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
@ -519,6 +517,7 @@ class GemmaSdpaAttention(GemmaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@ -574,6 +573,11 @@ class GemmaSdpaAttention(GemmaAttention):
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False
if _length > 0:
key_states = key_states[:, :, :_length, :]
value_states = value_states[:, :, :_length, :]
causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@ -619,6 +623,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@ -653,6 +658,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
**kwargs,
)
hidden_states = residual + hidden_states
@ -857,6 +863,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -933,6 +940,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
)
hidden_states = layer_outputs[0]
@ -1087,6 +1095,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1131,6 +1140,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
_length=_length,
)
hidden_states = outputs[0]
@ -1169,6 +1179,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
inputs_embeds=None,
cache_position=None,
use_cache=True,
_length=None,
**kwargs,
):
# With static cache, the `past_key_values` is None
@ -1246,6 +1257,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"_length": int(cache_position[-1]) + 1,
}
)
return model_inputs