Compare commits

...

24 Commits

Author SHA1 Message Date
bd4c28b78c incorrectly very short cache length 2024-05-09 23:28:06 +02:00
1133b90a8d full cache length 2024-05-09 23:26:43 +02:00
69f6683b65 optimized cache length 2024-05-09 23:25:24 +02:00
8d0dd2b7c5 update 2024-05-09 15:20:39 +02:00
4280e160f0 update 2024-05-09 10:05:18 +02:00
fdaa41cd4f update 2024-05-09 09:58:11 +02:00
789c6ebd2d update 2024-05-09 09:57:20 +02:00
97b64efbf5 update 2024-05-09 09:19:34 +02:00
c9fe623818 update 2024-05-09 09:12:57 +02:00
aeb40d18f9 update 2024-05-09 09:05:13 +02:00
d67910ca9e update 2024-05-08 14:33:32 +02:00
69fee99a0d update 2024-05-08 13:28:00 +02:00
6ab2da6d80 update 2024-05-08 12:24:56 +02:00
3e7b5c7868 update 2024-05-08 12:06:18 +02:00
687a57f3ba remark 2024-05-08 10:08:52 +02:00
ba933ecf61 exp 2024-05-01 21:05:31 +02:00
74b1b3da5c exp 2024-05-01 19:05:33 +02:00
ccdf7c557c exp 2024-05-01 17:52:34 +02:00
13e7cac02e exp 2024-05-01 13:37:30 +02:00
c814030772 exp 2024-05-01 13:35:50 +02:00
00b61e1694 exp 2024-05-01 13:33:21 +02:00
3b2e997b61 exp 2024-05-01 13:28:05 +02:00
5a69ffedb2 exp 2024-05-01 13:05:22 +02:00
80e94d09ed exp 2024-05-01 09:54:40 +02:00
3 changed files with 109 additions and 30 deletions

View File

@ -2543,8 +2543,11 @@ class GenerationMixin:
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
attentions = tuple(x.clone() if isinstance(x, torch.Tensor) else x for x in outputs.attentions)
# print(attentions)
# print("=" * 80)
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)

View File

@ -104,15 +104,15 @@ class GemmaRotaryEmbedding(nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
# self.register_buffer("inv_freq", None, persistent=False)
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cuda").float() / self.dim)
)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
@ -508,6 +508,9 @@ class GemmaSdpaAttention(GemmaAttention):
`GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._seen_tokens = 0
# Ignore copy
def forward(
@ -519,22 +522,23 @@ class GemmaSdpaAttention(GemmaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
# if output_attentions:
# # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# logger.warning_once(
# "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
# 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
# )
# return super().forward(
# hidden_states=hidden_states,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_value=past_key_value,
# output_attentions=output_attentions,
# use_cache=use_cache,
# cache_position=cache_position,
# )
bsz, q_len, _ = hidden_states.size()
@ -551,9 +555,16 @@ class GemmaSdpaAttention(GemmaAttention):
past_key_value = getattr(self, "past_key_value", past_key_value)
if q_len > 1:
self._seen_tokens = 0
# self._seen_tokens = (64 + 6) - 6 - 1 # compile ok but should fail
# self._seen_tokens = (64 + 6) - 6 # failed with index error 64
self._seen_tokens += key_states.shape[-2]
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# full length
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -574,11 +585,36 @@ class GemmaSdpaAttention(GemmaAttention):
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False
# obviously wrong as size is not what should be looked
# length = cache_position.size()[0]
# can't compile (TODO: add error message
# length = int(cache_position[-1] + 1)
# length = cache_position[-1] + 1
# incorrect results with torch.compile (index stays at the value obtained in the 2nd forward call)
length = self._seen_tokens
# incorrect results without torch.compile (index stays at the value obtained in the 2nd forward call)
# length = 1
# The correct length
# length = _length
# to use the full length of the static cache
# _key_states = key_states
# _value_states = value_states
# _attn_mask = causal_mask if causal_mask is not None else causal_mask
# to use the correct length or the very short length
_key_states = key_states[:, :, :length, :]
_value_states = value_states[:, :, :length, :]
_attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
_key_states,
_value_states,
attn_mask=_attn_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
@ -588,7 +624,10 @@ class GemmaSdpaAttention(GemmaAttention):
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# verify = self._seen_tokens
verify = _length
return attn_output, verify, past_key_value
GEMMA_ATTENTION_CLASSES = {
@ -619,6 +658,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@ -653,6 +693,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
**kwargs,
)
hidden_states = residual + hidden_states
@ -857,6 +898,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -933,6 +975,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
_length=_length,
)
hidden_states = layer_outputs[0]
@ -1087,6 +1130,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1131,6 +1175,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
_length=_length,
)
hidden_states = outputs[0]
@ -1169,6 +1214,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
inputs_embeds=None,
cache_position=None,
use_cache=True,
_length=None,
**kwargs,
):
# With static cache, the `past_key_values` is None
@ -1246,6 +1292,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"_length": int(cache_position[-1]),
}
)
return model_inputs

View File

@ -604,6 +604,10 @@ class LlamaSdpaAttention(LlamaAttention):
SDPA API.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._seen_tokens = 0
# Adapted from LlamaAttention.forward
def forward(
self,
@ -630,7 +634,6 @@ class LlamaSdpaAttention(LlamaAttention):
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@ -647,10 +650,19 @@ class LlamaSdpaAttention(LlamaAttention):
# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)
if q_len > 1:
self._seen_tokens = 0
# self._seen_tokens = (64 + 7) - 7 - 1 # compile ok but should fail
# self._seen_tokens = (64 + 7) - 7 # failed with index error 71
self._seen_tokens += key_states.shape[-2]
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# full length
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# only necessary length (but we still need to update)
# _, _ = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@ -670,11 +682,27 @@ class LlamaSdpaAttention(LlamaAttention):
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False
# length = int(cache_position[-1] + 1)
#length = cache_position.size()[0] # can't compile, failed at `scaled_dot_product_attention` (`(*bias): last dimension must be contiguous`). Also wrong value!
length = self._seen_tokens # incorrect results (index stay at very small values)
# _key_states = key_states
# _value_states = value_states
# _attn_mask = causal_mask if causal_mask is not None else causal_mask
_key_states = key_states[:, :, :length, :]
_value_states = value_states[:, :, :length, :]
_attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
# _key_states = _key_states.contiguous()
# _value_states = _value_states.contiguous()
# _attn_mask = _attn_mask.contiguous() if causal_mask is not None else causal_mask
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
_key_states,
_value_states,
attn_mask=_attn_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
@ -684,7 +712,8 @@ class LlamaSdpaAttention(LlamaAttention):
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
verify = None # key_states[:, :, length - 1, :]
return attn_output, verify, past_key_value
LLAMA_ATTENTION_CLASSES = {