Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.56.0...repro-bug- (17 commits)

Commits (SHA1):
a87debaf9e
a8c4e1036a
5dbcef4347
293546296f
a660486ee6
80b9072c4e
5019e81b80
190e0cf2be
28cdee0fa4
0c03b7d45d
b9b627c6f0
b214766730
7472549870
0a00d6bba7
724c694612
4891050949
6d669eea44
src/transformers/cache_utils.py:

@@ -357,7 +357,6 @@ class StaticCache(Cache):
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
         self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.seen_tokens = 0

     def update(
         self,
@@ -391,15 +390,20 @@ class StaticCache(Cache):
         k_out[:, :, new_cache_positions] = key_states
         v_out[:, :, new_cache_positions] = value_states

-        self.seen_tokens += key_states.shape[2]
         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        return self.seen_tokens
+        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
+        raise ValueError(
+            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
+        )

     def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-        return self.seen_tokens
+        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
+        raise ValueError(
+            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
+        )

     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
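For orientation, here is a minimal sketch (not the library code) of what the pre-allocated buffers and the in-place `update` above amount to; the dimensions and the single written position are made up for illustration:

```python
import torch

# Toy dimensions, assumed for illustration only.
max_batch_size, num_key_value_heads, max_cache_len, head_dim = 1, 8, 16, 64
cache_shape = (max_batch_size, num_key_value_heads, max_cache_len, head_dim)

# Pre-allocate full-size key/value buffers once, as StaticCache.__init__ does above.
key_cache = torch.zeros(cache_shape, dtype=torch.float16)
value_cache = torch.zeros(cache_shape, dtype=torch.float16)

# During decoding, new states are written in place at the current positions, so the
# cache tensors keep a constant shape across steps (which is what torch.compile wants).
new_cache_positions = torch.tensor([3])  # position(s) written at this step
key_states = torch.randn(max_batch_size, num_key_value_heads, 1, head_dim, dtype=torch.float16)
value_states = torch.randn_like(key_states)

key_cache[:, :, new_cache_positions] = key_states
value_cache[:, :, new_cache_positions] = value_states
```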
src/transformers/generation/utils.py:

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import time
 import copy
 import inspect
 import warnings
@@ -648,6 +648,7 @@ class GenerationMixin:
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -677,6 +678,8 @@ class GenerationMixin:
                 dim=-1,
             )

+        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+
         return model_kwargs

     def _reorder_cache(self, past_key_values, beam_idx):
@@ -1451,17 +1454,19 @@ class GenerationMixin:
         ):
             generation_config.max_length -= inputs_tensor.shape[1]

-        # if we don't pass `past_key_values` and a cache_implementation is specified
-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
-            "past_key_values", False
-        ):
-            cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
-            if not callable(getattr(self, "_setup_cache", None)):
-                raise ValueError(
-                    "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
-                    " Make sure it has a `_setup_cache` function."
-                )
-            self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
+        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if generation_config.cache_implementation == "static":
+                if model_kwargs.get("past_key_values", False) is not False:
+                    raise ValueError(
+                        "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
+                    )
+                cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
+                if not callable(getattr(self, "_setup_cache", None)):
+                    raise ValueError(
+                        "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
+                        " Make sure it has a `_setup_cache` function."
+                    )
+                self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

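The block above is the path a caller hits when asking `generate()` for a static KV cache: `_setup_cache(...)` runs before decoding and, as a later hunk shows, `_reset_cache()` runs after it. A minimal usage sketch, assuming the `generation_config.cache_implementation` knob behaves as in the 4.38-era code shown here; the checkpoint name is only illustrative and a CUDA device is assumed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Selecting the "static" implementation is what routes generate() through the
# NEED_SETUP_CACHE_CLASSES_MAPPING / _setup_cache branch above.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("The static KV cache", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```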
@@ -1523,7 +1528,7 @@ class GenerationMixin:
             )

             # 12. run assisted generate
-            return self.assisted_decoding(
+            result = self.assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
@@ -1541,7 +1546,7 @@ class GenerationMixin:
             )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
             # 11. run greedy search
-            return self.greedy_search(
+            result = self.greedy_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
@@ -1559,7 +1564,7 @@ class GenerationMixin:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")

-            return self.contrastive_search(
+            result = self.contrastive_search(
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
@@ -1589,7 +1594,7 @@ class GenerationMixin:
             )

             # 13. run sample
-            return self.sample(
+            result = self.sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
@@ -1623,7 +1628,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            return self.beam_search(
+            result = self.beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1662,7 +1667,7 @@ class GenerationMixin:
             )

             # 14. run beam sample
-            return self.beam_sample(
+            result = self.beam_sample(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1697,7 +1702,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            return self.group_beam_search(
+            result = self.group_beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1771,7 +1776,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            return self.constrained_beam_search(
+            result = self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1785,6 +1790,16 @@ class GenerationMixin:
             **model_kwargs,
         )

+        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if not callable(getattr(self, "_reset_cache", None)):
+                raise ValueError(
+                    "A `static_cache` was used to generate but there was a failure when trying to release the cache. "
+                    " Make sure this model implements a `_reset_cache` function."
+                )
+            self._reset_cache()
+
+        return result
+
     @torch.no_grad()
     def contrastive_search(
         self,
@@ -1975,6 +1990,7 @@ class GenerationMixin:
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 standardize_cache_format=True,
+                model_inputs=model_inputs,
             )
             if not sequential:
                 # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2169,7 +2185,7 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -2386,7 +2402,10 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
+        count = 0
         while True:
+            print("----- in forward", count)
+            count += 1
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence
@@ -2400,6 +2419,17 @@ class GenerationMixin:
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            for name, inp in model_inputs.items():
+                if isinstance(inp, torch.Tensor):
+                    print(f"name={name}, shape={inp.shape}, stride={inp.stride()}, dtype={inp.dtype}, device={inp.device}")
+                elif name == "past_key_values" and inp is not None:
+                    print("past_key_values not None")
+                else:
+                    print(f"name={name}, value={inp}")
+
+            torch.cuda.synchronize()
+
+            start = time.time_ns()
             # forward pass to get next token
             outputs = self(
                 **model_inputs,
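The debug loop above prints exactly the tensor properties that torchdynamo guards on (shape, stride, dtype, device), which are what this branch is trying to keep constant across decode steps. A standalone sketch of the same kind of inspection, with a made-up helper name and toy tensors:

```python
import torch

def dump_tensor_metadata(tensors: dict) -> None:
    """Print the properties dynamo guards on for each tensor in a kwargs-style dict."""
    for name, value in tensors.items():
        if isinstance(value, torch.Tensor):
            print(f"name={name}, shape={tuple(value.shape)}, stride={value.stride()}, "
                  f"dtype={value.dtype}, device={value.device}")
        else:
            print(f"name={name}, value={value}")

# The same logical data can carry different strides, which is enough to trigger recompilation.
ids = torch.arange(12).reshape(2, 6)
dump_tensor_metadata({"input_ids": ids, "last_token": ids[:, -1:], "use_cache": True})
```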
@@ -2408,6 +2438,11 @@ class GenerationMixin:
                 output_hidden_states=output_hidden_states,
             )

+            torch.cuda.synchronize()
+            end = time.time_ns()
+
+            print(f"forward call latency: {(end - start) * 1e-6:.3f} ms")
+
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need

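Because CUDA kernels launch asynchronously, the `torch.cuda.synchronize()` calls on both sides of the timed region are what make the printed latency meaningful. A self-contained sketch of the same measurement pattern; the matmul is just a stand-in for the model forward:

```python
import time
import torch

def time_gpu_call(fn, warmup: int = 3, iters: int = 10) -> float:
    """Return the mean wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # drain pending kernels before starting the clock
    start = time.time_ns()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for the timed kernels to finish
    end = time.time_ns()
    return (end - start) * 1e-6 / iters

if torch.cuda.is_available():
    a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
    print(f"matmul latency: {time_gpu_call(lambda: a @ a):.3f} ms")
```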
@@ -2450,7 +2485,10 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                model_inputs=model_inputs,
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -2744,7 +2782,7 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -3137,7 +3175,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3484,7 +3522,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3883,7 +3921,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4235,7 +4273,7 @@ class GenerationMixin:

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4642,7 +4680,7 @@ class GenerationMixin:
             )

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )

             # if eos_token was found in one sentence, set sentence to finished
src/transformers/models/llama/modeling_llama.py:

@@ -648,6 +648,7 @@ class LlamaSdpaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

+        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
@@ -976,9 +977,11 @@ class LlamaModel(LlamaPreTrainedModel):
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
+            past_seen_tokens = past_key_values.get_seq_length()

         if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
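`cache_position` is just the absolute positions being written in the current forward pass, which is what lets a `StaticCache` index into its pre-allocated buffers without keeping a seen-token counter of its own. A small sketch of how it evolves over a prefill step and two decode steps (lengths are arbitrary):

```python
import torch

prompt_len = 5

# Prefill: the whole prompt is written, so cache_position covers [0, prompt_len).
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + prompt_len)
print(cache_position)      # tensor([0, 1, 2, 3, 4])

# Decode: one new token per step, so cache_position is a single advancing index.
for _ in range(2):
    past_seen_tokens = int(cache_position[-1]) + 1
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + 1)
    print(cache_position)  # tensor([5]), then tensor([6])
```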
@@ -1050,6 +1053,10 @@ class LlamaModel(LlamaPreTrainedModel):
             attentions=all_self_attns,
         )

+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -1065,16 +1072,8 @@ class LlamaModel(LlamaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min

         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
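The registered buffer is a full [max_len, max_len] additive mask, zero on and below the diagonal and the dtype minimum above it, broadcast to [batch, 1, q_len, kv_len] before being added to the attention scores. A minimal sketch of the same construction with toy sizes:

```python
import torch

max_len, batch_size, dtype = 6, 2, torch.float16

# 1 strictly above the diagonal, 0 elsewhere; this is the shape-stable buffer.
causal_mask = torch.triu(torch.full((max_len, max_len), fill_value=1), diagonal=1)

# At run time: broadcast to [batch, 1, max_len, max_len] and turn it into an additive mask,
# where masked positions get the most negative value representable in the compute dtype.
additive_mask = causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
print(additive_mask.shape)     # torch.Size([2, 1, 6, 6])
print(additive_mask[0, 0, 0])  # query position 0 may only attend to key position 0
```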
@@ -1260,29 +1259,32 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]

         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}

         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
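The stride problem that the `contiguous()` comment describes is easy to reproduce in isolation: slicing the last token out of a growing `input_ids` keeps the parent tensor's row stride, so the stride changes every step and trips dynamo's guards, whereas `.contiguous()` always yields the same canonical layout. A small sketch (shapes are arbitrary):

```python
import torch

for seq_len in (5, 6, 7):
    input_ids = torch.zeros(2, seq_len, dtype=torch.long)
    last = input_ids[:, -1:]  # a view: its stride follows the parent's seq_len
    print(seq_len, last.shape, last.stride(), last.contiguous().stride())
    # e.g. 5 torch.Size([2, 1]) (5, 1) (1, 1)
    #      6 torch.Size([2, 1]) (6, 1) (1, 1)
```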