Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)
Compare commits
24 Commits
bd4c28b78c
1133b90a8d
69f6683b65
8d0dd2b7c5
4280e160f0
fdaa41cd4f
789c6ebd2d
97b64efbf5
c9fe623818
aeb40d18f9
d67910ca9e
69fee99a0d
6ab2da6d80
3e7b5c7868
687a57f3ba
ba933ecf61
74b1b3da5c
ccdf7c557c
13e7cac02e
c814030772
00b61e1694
3b2e997b61
5a69ffedb2
80e94d09ed
@@ -2543,8 +2543,11 @@ class GenerationMixin:
                 if output_logits:
                     raw_logits += (next_token_logits,)
                 if output_attentions:
+                    attentions = tuple(x.clone() if isinstance(x, torch.Tensor) else x for x in outputs.attentions)
+                    # print(attentions)
+                    # print("=" * 80)
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
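The `.clone()` on each attention tensor looks like a snapshot taken before the tensors can be overwritten by a later decoding step (for example when buffers are reused under a static cache or `torch.compile`); that reading is an assumption, since the diff only shows the clone plus two commented-out prints. A minimal, self-contained sketch of the aliasing problem a clone guards against:

```python
import torch

# Hypothetical stand-in for a workspace buffer that a model reuses on every decoding step.
reused_buffer = torch.zeros(1, 2, 4, 4)


def fake_forward(step: int) -> tuple:
    # Each call overwrites the same storage, the way a preallocated buffer would.
    reused_buffer.fill_(float(step))
    return (reused_buffer,)


collected_aliased = ()
collected_cloned = ()
for step in range(3):
    attentions = fake_forward(step)
    collected_aliased += (attentions,)                           # keeps references only
    collected_cloned += (tuple(a.clone() for a in attentions),)  # snapshots the values

print(collected_aliased[0][0][0, 0, 0, 0].item())  # 2.0 -- the step-0 entry was overwritten
print(collected_cloned[0][0][0, 0, 0, 0].item())   # 0.0 -- the step-0 value preserved by clone()
```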
@@ -104,15 +104,15 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        # self.register_buffer("inv_freq", None, persistent=False)
+        self.inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cuda").float() / self.dim)
+        )

     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
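The inverse frequencies only depend on `dim` and `base`, so they can be built once in `__init__` instead of lazily on the first `forward`; hard-coding `device="cuda"` is clearly a throwaway experiment rather than a general fix, and the bare `self.inv_freq.to(x.device)` is not reassigned, so on its own it does nothing. A rough standalone sketch of the same RoPE frequency and angle computation, with illustrative values (`dim=8`, `base=10000`) rather than Gemma's real config:

```python
import torch

dim, base = 8, 10000.0  # illustrative; the module uses head_dim and the configured rope_theta

# Same formula as the patch: one inverse frequency per pair of rotary dimensions.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

position_ids = torch.arange(5)[None, :]                      # [bs=1, seq_len=5]
freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                              # what forward() ultimately returns

print(inv_freq.shape, cos.shape)  # torch.Size([4]) torch.Size([1, 5, 8])
```

The real module keeps the expand/matmul arrangement shown in the hunk above so the angles stay in float32; the math is the same.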
@@ -508,6 +508,9 @@ class GemmaSdpaAttention(GemmaAttention):
     `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._seen_tokens = 0

     # Ignore copy
     def forward(
@@ -519,22 +522,23 @@ class GemmaSdpaAttention(GemmaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
+        # if output_attentions:
+        #     # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+        #     logger.warning_once(
+        #         "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+        #         'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+        #     )
+        #     return super().forward(
+        #         hidden_states=hidden_states,
+        #         attention_mask=attention_mask,
+        #         position_ids=position_ids,
+        #         past_key_value=past_key_value,
+        #         output_attentions=output_attentions,
+        #         use_cache=use_cache,
+        #         cache_position=cache_position,
+        #     )

         bsz, q_len, _ = hidden_states.size()

@@ -551,9 +555,16 @@ class GemmaSdpaAttention(GemmaAttention):

         past_key_value = getattr(self, "past_key_value", past_key_value)

+        if q_len > 1:
+            self._seen_tokens = 0
+            # self._seen_tokens = (64 + 6) - 6 - 1  # compile ok but should fail
+            # self._seen_tokens = (64 + 6) - 6  # failed with index error 64
+        self._seen_tokens += key_states.shape[-2]
+
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            # full length
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
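`_seen_tokens` is the experiment's per-layer bookkeeping for how many slots of the static cache are actually filled: it resets on a prefill call (`q_len > 1`) and then grows by the number of key positions written in each step. A toy sketch of that counter on its own, with an illustrative 6-token prompt:

```python
class SeenTokenCounter:
    """Tracks the filled length of a fixed-size (static) KV cache."""

    def __init__(self) -> None:
        self._seen_tokens = 0

    def update(self, q_len: int, new_key_positions: int) -> int:
        if q_len > 1:                 # prefill: the cache is being written from scratch
            self._seen_tokens = 0
        self._seen_tokens += new_key_positions
        return self._seen_tokens


counter = SeenTokenCounter()
print(counter.update(q_len=6, new_key_positions=6))  # 6: prefill of a 6-token prompt
print(counter.update(q_len=1, new_key_positions=1))  # 7: first generated token
print(counter.update(q_len=1, new_key_positions=1))  # 8
```

The later hunks record the catch: as plain Python state this counter misbehaves under `torch.compile` ("index stays at the value obtained in the 2nd forward call"), which is why the patch also threads `_length` in from `prepare_inputs_for_generation`.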
@@ -574,11 +585,36 @@ class GemmaSdpaAttention(GemmaAttention):
         # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
         is_causal = True if causal_mask is None and q_len > 1 else False

+        # obviously wrong, as the size is not what should be looked at
+        # length = cache_position.size()[0]
+
+        # can't compile (TODO: add error message)
+        # length = int(cache_position[-1] + 1)
+        # length = cache_position[-1] + 1
+
+        # incorrect results with torch.compile (index stays at the value obtained in the 2nd forward call)
+        length = self._seen_tokens
+        # incorrect results without torch.compile (index stays at the value obtained in the 2nd forward call)
+        # length = 1
+
+        # The correct length
+        # length = _length
+
+        # to use the full length of the static cache
+        # _key_states = key_states
+        # _value_states = value_states
+        # _attn_mask = causal_mask if causal_mask is not None else causal_mask
+
+        # to use the correct length or the very short length
+        _key_states = key_states[:, :, :length, :]
+        _value_states = value_states[:, :, :length, :]
+        _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
+            _key_states,
+            _value_states,
+            attn_mask=_attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             is_causal=is_causal,
         )
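This is the core of the experiment: with a static cache, `past_key_value.update` returns keys and values at the full preallocated length, so the patch slices the keys, values and causal mask down to `length` before calling SDPA. A self-contained sketch (random tensors, no model) showing that the slice, together with the matching mask slice, reproduces the full-cache result when the unused tail is masked out:

```python
import torch
import torch.nn.functional as F

bsz, heads, head_dim = 1, 2, 8
max_cache, seen = 16, 7                       # preallocated cache size vs. tokens actually written

query = torch.randn(bsz, heads, 1, head_dim)                   # a single decode-step query
key = torch.zeros(bsz, heads, max_cache, head_dim)
value = torch.zeros(bsz, heads, max_cache, head_dim)
key[:, :, :seen] = torch.randn(bsz, heads, seen, head_dim)     # only the first `seen` slots are real
value[:, :, :seen] = torch.randn(bsz, heads, seen, head_dim)

# Additive mask: 0 for valid slots, -inf for the unused tail of the static cache.
mask = torch.full((bsz, 1, 1, max_cache), float("-inf"))
mask[..., :seen] = 0.0

full = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
sliced = F.scaled_dot_product_attention(
    query, key[:, :, :seen, :], value[:, :, :seen, :], attn_mask=mask[..., :seen]
)
print(torch.allclose(full, sliced, atol=1e-6))  # True: the masked tail contributes nothing
```

The payoff of slicing is that SDPA scores `seen` positions instead of the whole `max_cache`; the hard part, as the inline comments record, is computing `length` in a way that is both correct and `torch.compile`-friendly.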
@@ -588,7 +624,10 @@ class GemmaSdpaAttention(GemmaAttention):

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        # verify = self._seen_tokens
+        verify = _length
+
+        return attn_output, verify, past_key_value


 GEMMA_ATTENTION_CLASSES = {
@@ -619,6 +658,7 @@ class GemmaDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -653,6 +693,7 @@ class GemmaDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            _length=_length,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -857,6 +898,7 @@ class GemmaModel(GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -933,6 +975,7 @@ class GemmaModel(GemmaPreTrainedModel):
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    _length=_length,
                 )

             hidden_states = layer_outputs[0]
@@ -1087,6 +1130,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1131,6 +1175,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
+           _length=_length,
        )

        hidden_states = outputs[0]
@@ -1169,6 +1214,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         inputs_embeds=None,
         cache_position=None,
         use_cache=True,
+        _length=None,
         **kwargs,
     ):
         # With static cache, the `past_key_values` is None
@@ -1246,6 +1292,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
+                "_length": int(cache_position[-1]),
             }
         )
         return model_inputs
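`_length` enters the model here, computed from `cache_position` outside the compiled forward, and is then threaded through `GemmaForCausalLM` → `GemmaModel` → `GemmaDecoderLayer` → the attention module (the hunks above). A small sketch of how `cache_position[-1]` evolves over a prefill-then-decode loop; whether the right slice length is `cache_position[-1]` or `cache_position[-1] + 1` is exactly the off-by-one the inline comments in the attention hunk are probing, so no claim is made here about which is intended to be correct:

```python
import torch

prompt_len, new_tokens = 6, 3

# Prefill: one position per prompt token.
cache_position = torch.arange(prompt_len)        # tensor([0, 1, 2, 3, 4, 5])
print(int(cache_position[-1]), int(cache_position[-1]) + 1)      # 5 6

# Decode: one new position per generated token.
for _ in range(new_tokens):
    cache_position = cache_position[-1:] + 1     # [6], then [7], then [8]
    print(int(cache_position[-1]), int(cache_position[-1]) + 1)  # 6 7 / 7 8 / 8 9
```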
@@ -604,6 +604,10 @@ class LlamaSdpaAttention(LlamaAttention):
     SDPA API.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._seen_tokens = 0
+
     # Adapted from LlamaAttention.forward
     def forward(
         self,
@@ -630,7 +634,6 @@ class LlamaSdpaAttention(LlamaAttention):
                 use_cache=use_cache,
                 cache_position=cache_position,
             )
-
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -647,10 +650,19 @@ class LlamaSdpaAttention(LlamaAttention):
         # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)

+        if q_len > 1:
+            self._seen_tokens = 0
+            # self._seen_tokens = (64 + 7) - 7 - 1  # compile ok but should fail
+            # self._seen_tokens = (64 + 7) - 7  # failed with index error 71
+        self._seen_tokens += key_states.shape[-2]
+
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            # full length
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            # only necessary length (but we still need to update)
+            # _, _ = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -670,11 +682,27 @@ class LlamaSdpaAttention(LlamaAttention):
         # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
         is_causal = True if causal_mask is None and q_len > 1 else False

+        # length = int(cache_position[-1] + 1)
+        # length = cache_position.size()[0]  # can't compile, failed at `scaled_dot_product_attention` (`(*bias): last dimension must be contiguous`). Also the wrong value!
+        length = self._seen_tokens  # incorrect results (index stays at very small values)
+
+        # _key_states = key_states
+        # _value_states = value_states
+        # _attn_mask = causal_mask if causal_mask is not None else causal_mask
+
+        _key_states = key_states[:, :, :length, :]
+        _value_states = value_states[:, :, :length, :]
+        _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
+
+        # _key_states = _key_states.contiguous()
+        # _value_states = _value_states.contiguous()
+        # _attn_mask = _attn_mask.contiguous() if causal_mask is not None else causal_mask
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
+            _key_states,
+            _value_states,
+            attn_mask=_attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             is_causal=is_causal,
         )
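The commented-out `.contiguous()` calls look like an attempted workaround for the ``(*bias): last dimension must be contiguous`` failure mentioned in the comment above; whether they are actually needed depends on which SDPA backend gets selected, so the following is only a sketch of what those lines would do. Slicing the last dimension of a broadcast mask generally leaves a non-contiguous tensor, and `.contiguous()` repacks it:

```python
import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim, max_cache, length = 2, 4, 3, 8, 16, 7

q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, max_cache, head_dim)
v = torch.randn(bsz, heads, max_cache, head_dim)
mask = torch.zeros(bsz, 1, q_len, max_cache)       # additive mask, everything allowed

sliced_mask = mask[:, :, :, :length]
print(sliced_mask.is_contiguous())                 # False: the slice keeps the old strides

out = F.scaled_dot_product_attention(
    q, k[:, :, :length, :], v[:, :, :length, :], attn_mask=sliced_mask.contiguous()
)
print(out.shape)                                   # torch.Size([2, 4, 3, 8])
```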
@@ -684,7 +712,8 @@ class LlamaSdpaAttention(LlamaAttention):

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        verify = None  # key_states[:, :, length - 1, :]
+        return attn_output, verify, past_key_value


 LLAMA_ATTENTION_CLASSES = {