Compare commits

...

5 Commits

SHA1 Message Date
10f56e53bc Update src/transformers/generation/utils.py 2024-02-20 11:24:16 +01:00
8db50972a9 simpler? 2024-02-20 11:06:36 +01:00
ec73df8bdd nit 2024-02-20 11:45:29 +09:00
2f3fa1ad18 Merge branch 'main' of github.com:huggingface/transformers into less-constraints 2024-02-20 11:44:56 +09:00
b581b70a23 update utils 2024-02-20 11:43:24 +09:00

src/transformers/generation/utils.py

@@ -1953,12 +1953,13 @@ class GenerationMixin:
if model_kwargs.get("past_key_values") is None:
# prepare inputs
model_kwargs["use_cache"] = True
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
# the `encoder_outputs`
outputs = self(
**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
**model_inputs, return_dict=True, output_hidden_states=True
)
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
@@ -2040,17 +2041,18 @@ class GenerationMixin:
new_key_values.append(tuple(items))
model_kwargs["past_key_values"] = tuple(new_key_values)
self.config.output_attentions = output_attentions
if sequential:
all_outputs = []
for i in range(top_k):
# compute the candidate tokens by the language model and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
all_outputs.append(outputs)
outputs = stack_model_outputs(all_outputs)
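For context on the `sequential` branch touched in this hunk: contrastive search's low-memory path scores each of the `top_k` candidate tokens one at a time and then stacks the per-candidate outputs back into a single batch. A simplified sketch of that stacking idea on plain tensors (a toy embedding stands in for the language model; the real `stack_model_outputs` operates on full `ModelOutput` objects, including past key values):

    import torch

    # Toy illustration of the sequential path: score each candidate column
    # separately, then concatenate the results along the batch dimension so
    # downstream code can treat them like one batched output.
    torch.manual_seed(0)
    batch, top_k, vocab = 2, 4, 10
    top_k_ids = torch.randint(vocab, (batch, top_k))   # candidate token ids
    lm_head = torch.nn.Embedding(vocab, vocab)         # stand-in for the model

    per_candidate = [lm_head(top_k_ids[:, i].view(-1, 1)) for i in range(top_k)]
    stacked = torch.cat(per_candidate, dim=0)          # (batch * top_k, 1, vocab)
    print(stacked.shape)                               # torch.Size([8, 1, 10])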
@@ -2064,7 +2066,6 @@ class GenerationMixin:
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
@@ -2098,6 +2099,7 @@ class GenerationMixin:
next_decoder_hidden_states += (layer,)
# generate past_key_values cache of only the selected token
self.config.output_attentions = False
if sequential:
next_model_input = self.prepare_inputs_for_generation(
top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
@@ -2107,7 +2109,6 @@ class GenerationMixin:
**next_model_input,
return_dict=True,
output_hidden_states=False,
output_attentions=False,
)
next_past_key_values = selected_outputs["past_key_values"]
@@ -2397,17 +2398,14 @@ class GenerationMixin:
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
self.config.output_attentions = output_attentions
self.config.output_hidden_states = output_hidden_states
# prepare model inputs. Non attention compatible models can pop whatever they need
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
outputs = self(**model_inputs)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
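The new comment in this hunk ("Non attention compatible models can pop whatever they need") points at the model side of the change: once the generation loop writes `output_attentions` / `output_hidden_states` to the config instead of passing them to `forward()`, a model whose attention implementation cannot return weights can simply override the flag while preparing its inputs. A hypothetical sketch, assuming a GPT-2 variant with a fused attention kernel (the subclass is illustrative, not code from this compare):

    from transformers import GPT2LMHeadModel

    class FusedAttentionGPT2(GPT2LMHeadModel):
        """Hypothetical variant whose attention kernel cannot return weights."""

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # The generation loop now sets the flag on the config rather than
            # passing it to forward(), so the model can drop what it can't honour.
            self.config.output_attentions = False
            return super().prepare_inputs_for_generation(input_ids, **kwargs)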
@@ -2691,13 +2689,13 @@ class GenerationMixin:
break
# prepare model inputs
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@@ -3028,6 +3026,7 @@ class GenerationMixin:
if this_peer_finished_flag.item() == 0.0:
break
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3057,7 +3056,6 @@ class GenerationMixin:
self(
**inputs_per_sub_batch,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
for inputs_per_sub_batch in inputs_per_sub_batches
@@ -3069,7 +3067,6 @@ class GenerationMixin:
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@@ -3406,13 +3403,13 @@ class GenerationMixin:
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@@ -3767,11 +3764,11 @@ class GenerationMixin:
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@@ -4159,12 +4156,12 @@ class GenerationMixin:
if this_peer_finished_flag.item() == 0.0:
break
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@@ -4526,12 +4523,12 @@ class GenerationMixin:
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
self.config.output_attentions = output_attentions
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
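Taken together, the hunks above replace per-call keyword arguments on the forward pass with flags set on `self.config` just before `prepare_inputs_for_generation`, so the various decoding loops no longer require every model's `forward()` to accept `output_attentions` / `output_hidden_states`. A minimal, runnable illustration of the two calling patterns, using GPT-2 purely as an example model (the checkpoint and variable names here are not from this compare):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    model_kwargs = {"use_cache": True}

    # Old pattern: the decoding loop forwards the flags on every call, so the
    # model's forward() signature has to accept them.
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = model(
        **model_inputs,
        return_dict=True,
        output_attentions=True,
        output_hidden_states=True,
    )

    # New pattern (this compare): the flags are written to the config up front;
    # the model reads them from there, or its prepare_inputs_for_generation can
    # drop whatever it cannot support.
    model.config.output_attentions = True
    model.config.output_hidden_states = True
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = model(**model_inputs, return_dict=True)
    print(outputs.attentions is not None, outputs.hidden_states is not None)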