Compare commits

...

1 Commit

Author SHA1 Message Date
c9b864ba54 add logs 2024-02-20 20:22:36 +01:00
3 changed files with 5 additions and 0 deletions

src/transformers/cache_utils.py

@@ -395,6 +395,7 @@ class StaticCache(Cache):
         return k_out, v_out
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        print("self.seen_tokens in get_seq_length", self.seen_tokens)
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
         return self.seen_tokens

src/transformers/generation/utils.py

@@ -2386,7 +2386,10 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 
         this_peer_finished = False  # used by synced_gpus only
+        count = 0
         while True:
+            print("------- forward in generate", count)
+            count += 1
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence

src/transformers/models/llama/modeling_llama.py

@@ -1263,6 +1263,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
             # generation with static cache
             past_length = past_key_value.get_seq_length()
+            print("past_length here", past_length)
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]
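
The added prints trace the static-cache bookkeeping (seen_tokens / past_length) once per decoding step. Below is a minimal sketch, not part of the commit, of how one might trigger them, assuming a transformers build from around this commit (~v4.38) where static-cache generation is enabled via generation_config.cache_implementation; the checkpoint name and generation arguments are illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; any Llama model should do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Route generate() through the static cache path so StaticCache.get_seq_length()
# and LlamaForCausalLM.prepare_inputs_for_generation() hit the instrumented branches.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=5, do_sample=False)  # greedy decoding

# Expected console output per decoding step, from the prints added above (values illustrative):
#   ------- forward in generate 0
#   self.seen_tokens in get_seq_length 1
#   past_length here 1
print(tokenizer.decode(output[0], skip_special_tokens=True))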