Mirror of https://github.com/huggingface/transformers.git

Compare commits (9 commits): v4.52.2 ... assistant_

Author | SHA1 | Date
---|---|---
 | 996d948e51 |
 | b6ad3e925c |
 | 51f0e4f1a7 |
 | daa7cc9575 |
 | 7a071d7952 |
 | 04a6581f9b |
 | 60b5aab973 |
 | 202d74bb02 |
 | 3f26e69034 |
@@ -1621,8 +1621,6 @@ class GenerationMixin:
                     "num_return_sequences has to be 1 when doing assisted generate, "
                     f"but is {generation_config.num_return_sequences}."
                 )
-            if batch_size > 1:
-                raise ValueError("assisted generate is only supported for batch_size = 1")
            if not model_kwargs["use_cache"]:
                 raise ValueError("assisted generate requires `use_cache=True`")
 
@@ -4427,6 +4425,7 @@ class GenerationMixin:
 
         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        batch_size = input_ids.shape[0]
 
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
@@ -4441,6 +4440,13 @@ class GenerationMixin:
         )
 
         this_peer_finished = False  # used by synced_gpus only
+
+        # 2 * max_len to give us room to potentially left cut
+        position_ids = torch.arange(2 * max_len, device=input_ids.device, dtype=torch.long)[None, :].broadcast_to(batch_size, 2 * max_len) if batch_size > 1 else None
+        attention_mask = torch.ones_like(position_ids) if position_ids is not None else None
+        n_matches = None
+        eos_tokens_mask = None
+
         while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
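The hunk above pre-allocates `position_ids` and `attention_mask` buffers of length `2 * max_len`, so the later left shifts and left cuts never need to grow the tensors. A minimal standalone sketch of what that buffer looks like (illustrative `batch_size` / `max_len` values, not part of the diff):

```python
import torch

batch_size, max_len = 3, 8

# a (batch_size, 2 * max_len) view of 0..2*max_len-1, one row per sequence
position_ids = torch.arange(2 * max_len, dtype=torch.long)[None, :].broadcast_to(batch_size, 2 * max_len)
attention_mask = torch.ones_like(position_ids)

print(position_ids.shape)   # torch.Size([3, 16]) -- twice max_len, room for later shifts
print(position_ids[0, :5])  # tensor([0, 1, 2, 3, 4])
```

`broadcast_to` only returns a view over the single `arange` row; a real per-row copy first appears once the generation loop starts modifying the positions.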
@@ -4452,6 +4458,31 @@ class GenerationMixin:
                 if this_peer_finished_flag.item() == 0.0:
                     break
 
+            # Rotate everything for bsz > 1
+            if n_matches is not None and position_ids is not None:
+                # compute by how much everything can be rotated
+                shift = unfinished_sequences * (n_matches.max() - n_matches) + (1 - unfinished_sequences) * (eos_tokens_mask.sum(-1) - 1)
+
+                for i in range(batch_size):
+                    if shift[i] > 0:
+                        input_ids[i][shift[i]:] = input_ids[i][:-shift[i]].clone()
+                        input_ids[i][:shift[i]] = self.config.pad_token_id
+
+                position_ids = position_ids.add(-shift[:, None]).clamp(min=0)
+                attention_mask[:, :-1] = position_ids[:, 1:] > 0
+
+                left_cut = (1 - attention_mask).sum(-1).min()
+
+                if left_cut > 0:
+                    position_ids = position_ids[:, left_cut:]
+                    attention_mask = attention_mask[:, left_cut:]
+                    input_ids = input_ids[:, left_cut:]
+
+                    model_kwargs["past_key_values"] = _crop_past_key_values(self, model_kwargs["past_key_values"], left_cut=left_cut)
+                    model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
+                        assistant_model, model_kwargs["assistant_past_key_values"], left_cut=left_cut
+                    )  # the assistant does not have the token after the last match, hence the -1
+
             # Assistant: main logic start
             cur_len = input_ids.shape[-1]
 
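This rotation is what keeps a batch with per-row acceptance lengths aligned: rows that accepted fewer candidate tokens are pushed to the right behind pad tokens, and columns that every row has padded can then be cut off on the left. A toy sketch of the shift only (made-up ids and `pad_id`; the real code also shifts `position_ids`, `attention_mask` and both caches):

```python
import torch

pad_id = 0
input_ids = torch.tensor([[11, 12, 13, 14, 15],
                          [21, 22, 23, 24, 25]])
n_matches = torch.tensor([2, 0])      # row 0 accepted 2 more tokens than row 1
shift = n_matches.max() - n_matches   # tensor([0, 2])

for i in range(input_ids.shape[0]):
    if shift[i] > 0:
        # push the lagging row to the right and pad the freed positions on the left
        input_ids[i, shift[i]:] = input_ids[i, :-shift[i]].clone()
        input_ids[i, :shift[i]] = pad_id

print(input_ids)
# tensor([[11, 12, 13, 14, 15],
#         [ 0,  0, 21, 22, 23]])
```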
@@ -4459,17 +4490,24 @@ class GenerationMixin:
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
-            for _ in range(int(num_assistant_tokens)):
+
+            for assist_idx in range(int(num_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
                     prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
 
+                    assist_position_ids = position_ids[:, cur_len - new_token_len + assist_idx:cur_len + assist_idx] if position_ids is not None else None
+                    assist_attention_mask = attention_mask[:, :cur_len + assist_idx] if attention_mask is not None else None
+
                 # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
                 if assistant_model.config.is_encoder_decoder:
                     assistant_model_outputs = assistant_model(
                         decoder_input_ids=assist_inputs,
+                        decoder_position_ids=assist_position_ids,
+                        decoder_attention_mask=assist_attention_mask,
                         past_key_values=model_kwargs["assistant_past_key_values"],
                         encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                     )
@@ -4505,16 +4543,19 @@ class GenerationMixin:
                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
+
+                if stopping_criteria(candidate_input_ids, assistant_model_outputs.logits[:, -1, :]):
+                    break
 
                 # 1.3. stop assistant generation on EOS
                 if eos_token_id_tensor is not None:
                     last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
                     last_assistant_token_is_eos = (
                         ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
                     )
-                    if last_assistant_token_is_eos:
+                    if torch.logical_or(~unfinished_sequences.bool(), last_assistant_token_is_eos).all():
                         break
                 else:
-                    last_assistant_token_is_eos = False
+                    last_assistant_token_is_eos = torch.zeros((1, 1), device=candidate_input_ids.device, dtype=torch.int8)
 
             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
 
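The `tile` / `ne` / `prod` idiom above answers, per batch row, "is the last assistant token equal to any of the allowed EOS ids?". A small sketch with made-up token ids:

```python
import torch

eos_token_id_tensor = torch.tensor([50256, 2])   # more than one EOS id is allowed
new_token = torch.tensor([50256, 17, 2])         # last assistant token of three sequences

tiled = new_token.tile(eos_token_id_tensor.shape[0], 1)                          # (num_eos_ids, batch)
last_token_is_eos = ~tiled.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
print(last_token_is_eos)  # tensor([ True, False,  True])
```

The product over the EOS dimension is 1 only for rows that differ from every EOS id, so the negation flags exactly the rows that just emitted an EOS.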
@@ -4527,6 +4568,12 @@ class GenerationMixin:
             candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
             candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
 
+            if self.config.is_encoder_decoder:
+                candidate_kwargs["decoder_position_ids"] = position_ids[:, :cur_len + candidate_length] if position_ids is not None else None
+                candidate_kwargs["decoder_attention_mask"] = attention_mask[:, :cur_len + candidate_length] if attention_mask is not None else None
+            else:
+                candidate_kwargs["position_ids"] = position_ids[:, :cur_len + candidate_length]
+
             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
 
             # 2.2. Run a forward pass on the candidate sequence
@@ -4555,7 +4602,7 @@ class GenerationMixin:
             # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
             # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
             candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum(-1)
 
             # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
             # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
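The only change in this hunk is the trailing `.sum(-1)`: `n_matches` becomes a per-row count of candidate tokens that agree with the main model's greedy picks before the first mismatch, instead of a single scalar for the previously size-1 batch. A sketch of the cumsum trick on toy tokens:

```python
import torch

candidate_new_tokens = torch.tensor([[5, 6, 7],
                                     [5, 9, 7]])
selected_tokens = torch.tensor([[5, 6, 7, 8],   # last column is the model's "bonus" token
                                [5, 6, 7, 8]])

mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])  # True where the candidate disagrees
n_matches = (mismatch.cumsum(dim=-1) < 1).sum(-1)              # positions before the first mismatch
print(n_matches)  # tensor([3, 1]) -- row 0 keeps all 3 candidates, row 1 only the first
```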
@@ -4563,32 +4610,50 @@ class GenerationMixin:
             # is no match.
 
             # 5.1. Ensure we don't generate beyond max_len or an EOS token
-            if last_assistant_token_is_eos and n_matches == candidate_length:
-                n_matches -= 1
-            n_matches = min(n_matches, max_len - cur_len - 1)
+            n_matches -= last_assistant_token_is_eos.int() * (n_matches == candidate_length).int()
+            # make sure than already finished sequences always match until longest "still active" sequence
+            n_matches = torch.clamp(n_matches, max=max_len - cur_len - 1)
+            # make sure that finished sentences cannot slow down valid tokens
+            n_matches = unfinished_sequences * n_matches + (1 - unfinished_sequences) * (unfinished_sequences * n_matches).max()
 
             # 5.2. Get the valid continuation, after the matching tokens
-            valid_tokens = selected_tokens[:, : n_matches + 1]
+            valid_tokens = selected_tokens[:, : n_matches.max() + 1]
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                eos_tokens = valid_tokens.eq(eos_token_id_tensor)
+                finished_seq_mask = ~(unfinished_sequences.bool()[:, None].broadcast_to(valid_tokens.shape))
+                eos_tokens_mask = torch.logical_or(eos_tokens.cumsum(-1).bool(), finished_seq_mask)
+                valid_tokens = torch.where(eos_tokens_mask, eos_token_id_tensor, valid_tokens)
+
+                # check which sentence has finished
+                unfinished_sequences = (1 - eos_tokens_mask.gather(1, n_matches[:, None]).squeeze(-1).int())
+
+                # stop when each sentence is finished
+                if unfinished_sequences.max() == 0:
+                    this_peer_finished = True
+
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
+
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
             new_cur_len = input_ids.shape[-1]
 
             # 5.3. Discard past key values relative to unused assistant tokens
             new_cache_size = new_cur_len - 1
-            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
+            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size, n_matches)
             model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
-                assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
+                assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1, n_matches
             )  # the assistant does not have the token after the last match, hence the -1
 
             # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
             # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
             # cost of forecasting incorrect assistant tokens.
             if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
-                if n_matches == int(num_assistant_tokens):
-                    num_assistant_tokens += 2.0
+                if n_matches.min() == int(num_assistant_tokens):
+                    num_assistant_tokens += 2
                 else:
-                    num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
+                    num_assistant_tokens = max(1, num_assistant_tokens - 1)
 
             # Assistant: main logic end
             if synced_gpus and this_peer_finished:
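A toy walk-through of the new EOS bookkeeping above (single EOS id, made-up values): once a row hits EOS inside the freshly accepted block, the rest of that block is overwritten with EOS, and `unfinished_sequences` is read off the mask at the row's last accepted position.

```python
import torch

eos_id = torch.tensor(2)
valid_tokens = torch.tensor([[7, 2, 9],    # row 0 emits EOS at position 1
                             [4, 5, 6]])   # row 1 does not finish
unfinished_sequences = torch.tensor([1, 1])
n_matches = torch.tensor([2, 2])

eos_tokens = valid_tokens.eq(eos_id)
finished_seq_mask = ~(unfinished_sequences.bool()[:, None].broadcast_to(valid_tokens.shape))
eos_tokens_mask = torch.logical_or(eos_tokens.cumsum(-1).bool(), finished_seq_mask)
valid_tokens = torch.where(eos_tokens_mask, eos_id, valid_tokens)
unfinished_sequences = 1 - eos_tokens_mask.gather(1, n_matches[:, None]).squeeze(-1).int()

print(valid_tokens)           # tensor([[7, 2, 2], [4, 5, 6]])
print(unfinished_sequences)   # tensor([0, 1]) -- row 0 finished, row 1 keeps generating
```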
@@ -4598,7 +4663,7 @@ class GenerationMixin:
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(n_matches.max() + 1))
 
                 if "past_key_values" not in model_kwargs:
                     added_len = new_cur_len
@@ -4639,19 +4704,6 @@ class GenerationMixin:
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
 
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    input_ids[:, -1]
-                    .tile(eos_token_id_tensor.shape[0], 1)
-                    .ne(eos_token_id_tensor.unsqueeze(1))
-                    .prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
             # stop if we exceed the maximum length
             if stopping_criteria(input_ids, scores):
                 this_peer_finished = True
@@ -4684,15 +4736,32 @@ class GenerationMixin:
             return input_ids
 
 
-def _crop_past_key_values(model, past_key_values, maximum_length):
+def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches=None, left_cut=None):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
     if model.config.is_encoder_decoder:
         for idx in range(len(past_key_values)):
+            if left_cut is None:
+                k_cache = past_key_values[idx][0][:, :, :maximum_length, :]
+                v_cache = past_key_values[idx][1][:, :, :maximum_length, :]
+            else:
+                k_cache = past_key_values[idx][0][:, :, left_cut:, :]
+                v_cache = past_key_values[idx][1][:, :, left_cut:, :]
+
+            if n_matches is not None:
+                for batch_idx in range(len(n_matches)):
+                    num_roll_left = n_matches.max() - n_matches[batch_idx]
+                    if num_roll_left > 0:
+                        # TODO(PVP) - check mem usage
+                        # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone())
+                        # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone())
+                        k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone()
+                        v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone()
+
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
+                    k_cache,
+                    v_cache,
                     past_key_values[idx][2],
                     past_key_values[idx][3],
                 )
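When `n_matches` is passed, the crop additionally rolls each batch row's cache to the right by its lag behind the best row, mirroring the shift applied to `input_ids`. A simplified sketch on a (batch, seq) tensor rather than the real (batch, num_heads, seq, head_dim) layout:

```python
import torch

k_cache = torch.tensor([[1, 2, 3, 4],
                        [5, 6, 7, 8]])
n_matches = torch.tensor([2, 0])

for batch_idx in range(len(n_matches)):
    num_roll_left = n_matches.max() - n_matches[batch_idx]
    if num_roll_left > 0:
        # shift the lagging row to the right; the stale left entries end up behind pad positions
        k_cache[batch_idx][num_roll_left:] = k_cache[batch_idx][:-num_roll_left].clone()

print(k_cache)
# tensor([[1, 2, 3, 4],
#         [5, 6, 5, 6]])
```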
@@ -4749,6 +4818,9 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
         cur_len += 1
         added_len -= cur_len
 
+    if torch.is_tensor(added_len):
+        added_len = added_len.max().item()
+
     for i in range(added_len):
         new_tuple = ()
         for layer in new_outputs:
@@ -1078,7 +1078,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, **kwargs
     ):
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -1092,7 +1092,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             input_ids = input_ids[:, remove_prefix_length:]
 
         position_ids = kwargs.get("position_ids", None)
+        if position_ids is not None and position_ids.shape[1] > input_ids.shape[1]:
+            position_ids = position_ids[:, remove_prefix_length:]
+
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
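For context, the fallback kept at the end of this hunk rebuilds `position_ids` from the attention mask when none are passed in, which the explicit `position_ids` path added above bypasses. A toy illustration of that cumsum (only the line shown in the hunk; the surrounding Llama code additionally fixes up the padded positions afterwards):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded row
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
print(position_ids)
# tensor([[-1, -1,  0,  1,  2],
#         [ 0,  1,  2,  3,  4]])
```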
@@ -305,8 +305,11 @@ class WhisperPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__(num_positions, embedding_dim)
 
-    def forward(self, input_ids, past_key_values_length=0):
-        return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
+    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
+        if position_ids is None:
+            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
+        else:
+            return self.weight[position_ids]
 
 
 class WhisperAttention(nn.Module):
@@ -1224,6 +1227,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1322,9 +1326,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
 
         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids)
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids)
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
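A minimal sketch of what the `WhisperPositionalEmbedding` change enables: with explicit `position_ids`, the learned positional table is indexed directly per batch row, instead of taking one contiguous slice starting at the cache length (illustrative sizes, a plain `nn.Embedding` standing in for the module):

```python
import torch
from torch import nn

embed_positions = nn.Embedding(10, 4)   # (num_positions, embedding_dim)

# old path: contiguous slice based on past length and the number of new tokens
past_key_values_length, new_tokens = 3, 2
by_slice = embed_positions.weight[past_key_values_length : past_key_values_length + new_tokens]

# new path: arbitrary, possibly shifted positions per batch row
position_ids = torch.tensor([[3, 4],
                             [2, 3]])
by_index = embed_positions.weight[position_ids]

print(by_slice.shape)  # torch.Size([2, 4])    -- shared across the batch
print(by_index.shape)  # torch.Size([2, 2, 4]) -- one set of positions per row
```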
@@ -1506,6 +1510,7 @@ class WhisperModel(WhisperPreTrainedModel):
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1564,6 +1569,7 @@ class WhisperModel(WhisperPreTrainedModel):
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1637,6 +1643,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1691,6 +1698,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1988,6 +1996,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         use_cache=None,
         encoder_outputs=None,
         attention_mask=None,
+        decoder_position_ids=None,
+        decoder_attention_mask=None,
         **kwargs,
     ):
         if past_key_values is not None:
@@ -2002,12 +2012,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
 
             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
 
+        if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
+            decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
+
         return {
             "encoder_outputs": encoder_outputs,
             "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "use_cache": use_cache,
-            "decoder_attention_mask": None,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
         }
 
     @staticmethod
@@ -1624,6 +1624,48 @@ class GenerationTesterMixin:
 
             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
 
+
+    def test_assisted_decoding_greedy_batch(self):
+        # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
+        # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
+
+        for model_class in self.all_generative_model_classes:
+            # won't fix: FSMT and Reformer have a different cache variable type (and format).
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                return
+            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+            ):
+                return
+
+            # enable cache
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=3)
+
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                return
+
+            config.use_cache = True
+            config.is_decoder = True
+            config.num_assistant_tokens = 2
+            model = model_class(config).to(torch_device).eval()
+            output_assisted = model.generate(
+                input_ids,
+                attention_mask=attention_mask,
+                max_length=max_length,
+                num_beams=1,
+                do_sample=False,
+                assistant_model=model,  # triggers assisted decoding
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+
+            self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
+
     def test_generate_with_head_masking(self):
         """Test designed for encoder-decoder models to ensure the attention head masking is used."""
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]