Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-21 01:23:56 +08:00
Compare commits: v4.51.0...less-const
5 commits:
10f56e53bc
8db50972a9
ec73df8bdd
2f3fa1ad18
b581b70a23

@@ -1953,12 +1953,13 @@ class GenerationMixin:
             if model_kwargs.get("past_key_values") is None:
                 # prepare inputs
                 model_kwargs["use_cache"] = True
+                self.config.output_attentions = output_attentions
                 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
                 # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
                 # the `encoder_outputs`
                 outputs = self(
-                    **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+                    **model_inputs, return_dict=True, output_hidden_states=True
                 )
 
                 # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
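
The forward call above can drop the explicit `output_attentions=output_attentions` argument because Transformers models conventionally fall back to the value stored on the config when the kwarg is left as `None`. A minimal, self-contained sketch of that fallback pattern (the `TinyModel` class and its names are illustrative stand-ins, not the upstream implementation):

from types import SimpleNamespace


class TinyModel:
    """Toy stand-in for the config fallback used in Transformers forward() methods."""

    def __init__(self):
        # stand-in for a PretrainedConfig carrying the two flags this branch now sets
        self.config = SimpleNamespace(output_attentions=False, output_hidden_states=False)

    def forward(self, input_ids, output_attentions=None, output_hidden_states=None):
        # if the caller omits the kwargs, the values stored on the config win
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return output_attentions, output_hidden_states


model = TinyModel()
model.config.output_attentions = True  # what the hunk above does before prepare_inputs_for_generation
print(model.forward(input_ids=[1, 2, 3]))  # (True, False): the flag is picked up without passing the kwarg
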
@@ -2040,17 +2041,18 @@ class GenerationMixin:
                     new_key_values.append(tuple(items))
                 model_kwargs["past_key_values"] = tuple(new_key_values)
 
+            self.config.output_attentions = output_attentions
             if sequential:
                 all_outputs = []
                 for i in range(top_k):
                     # compute the candidate tokens by the language model and collect their hidden_states
 
                     next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
 
                     outputs = self(
                         **next_model_inputs,
                         return_dict=True,
                         output_hidden_states=True,
-                        output_attentions=output_attentions,
                     )
                     all_outputs.append(outputs)
                 outputs = stack_model_outputs(all_outputs)
@@ -2064,7 +2066,6 @@ class GenerationMixin:
                     **next_model_inputs,
                     return_dict=True,
                     output_hidden_states=True,
-                    output_attentions=output_attentions,
                 )
             # name is different for encoder-decoder and decoder-only models
             if self.config.is_encoder_decoder:
@@ -2098,6 +2099,7 @@ class GenerationMixin:
                 next_decoder_hidden_states += (layer,)
 
             # generate past_key_values cache of only the selected token
+            self.config.output_attentions = False
             if sequential:
                 next_model_input = self.prepare_inputs_for_generation(
                     top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
@@ -2107,7 +2109,6 @@ class GenerationMixin:
                     **next_model_input,
                     return_dict=True,
                     output_hidden_states=False,
-                    output_attentions=False,
                 )
                 next_past_key_values = selected_outputs["past_key_values"]
 
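
In the two hunks above, contrastive search re-runs the model on only the selected token to rebuild the cache, and the branch expresses "no attentions needed for this step" by setting the config flag to `False` rather than passing the kwarg. If one wanted to scope such a flag change so it cannot leak into later forward passes, a context manager is one option; this is purely an illustrative sketch, not something the diff introduces:

from contextlib import contextmanager


@contextmanager
def config_flag(config, name, value):
    # temporarily override a config attribute, then restore the previous value
    previous = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        setattr(config, name, previous)


# usage sketch: disable attentions only for the cache-rebuild forward pass
# with config_flag(self.config, "output_attentions", False):
#     selected_outputs = self(**next_model_input, return_dict=True, output_hidden_states=False)
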
@@ -2397,17 +2398,14 @@ class GenerationMixin:
                 # did all peers finish? the reduced sum will be 0.0 then
                 if this_peer_finished_flag.item() == 0.0:
                     break
 
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            self.config.output_attentions = output_attentions
+            self.config.output_hidden_states = output_hidden_states
+            # prepare model inputs. Non attention compatible models can pop whatever they need
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs)
 
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
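
The new comment in this hunk suggests that models which cannot honor the attention flags are expected to handle them on their side, now that the loop only sets config attributes and calls the model with `outputs = self(**model_inputs)`. A hypothetical override illustrating that idea (the class name and behaviour are assumptions for illustration, not code from the repository):

class NoAttentionModelSketch:
    # hypothetical model that cannot return attention weights

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        # drop anything attention-related this architecture cannot honor
        model_kwargs.pop("output_attentions", None)
        # the generation loop set this on the config; disable it again here
        self.config.output_attentions = False
        return {"input_ids": input_ids, **model_kwargs}
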
@@ -2691,13 +2689,13 @@ class GenerationMixin:
                     break
 
             # prepare model inputs
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # forward pass to get next token
             outputs = self(
                 **model_inputs,
                 return_dict=True,
-                output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
 
@@ -3028,6 +3026,7 @@ class GenerationMixin:
                 if this_peer_finished_flag.item() == 0.0:
                     break
 
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3057,7 +3056,6 @@ class GenerationMixin:
                     self(
                         **inputs_per_sub_batch,
                         return_dict=True,
-                        output_attentions=output_attentions,
                         output_hidden_states=output_hidden_states,
                     )
                     for inputs_per_sub_batch in inputs_per_sub_batches
@@ -3069,7 +3067,6 @@ class GenerationMixin:
                 outputs = self(
                     **model_inputs,
                     return_dict=True,
-                    output_attentions=output_attentions,
                     output_hidden_states=output_hidden_states,
                 )
 
@@ -3406,13 +3403,13 @@ class GenerationMixin:
                 # did all peers finish? the reduced sum will be 0.0 then
                 if this_peer_finished_flag.item() == 0.0:
                     break
 
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
                 **model_inputs,
                 return_dict=True,
-                output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
 
@@ -3767,11 +3764,11 @@ class GenerationMixin:
             reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
 
             # do one decoder step on all beams of all sentences in batch
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             outputs = self(
                 **model_inputs,
                 return_dict=True,
-                output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
 
@@ -4159,12 +4156,12 @@ class GenerationMixin:
                 if this_peer_finished_flag.item() == 0.0:
                     break
 
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
                 **model_inputs,
                 return_dict=True,
-                output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
 
@@ -4526,12 +4523,12 @@ class GenerationMixin:
             )
             candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
 
+            self.config.output_attentions = output_attentions
             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
 
             # 2.2. Run a forward pass on the candidate sequence
             outputs = self(
                 **model_inputs,
-                output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
 
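
For callers, the flags are still requested through `generate` in the usual way, and with `return_dict_in_generate=True` the per-step attentions and hidden states come back on the generation output. A short usage sketch (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=5,
    do_sample=False,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
print(type(out).__name__)   # a generate output class, e.g. GreedySearchDecoderOnlyOutput
print(len(out.attentions))  # one tuple of attention tensors per generated step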