mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 17:48:57 +08:00
Compare commits
9 Commits
v4.52.2
...
assistant_
Author | SHA1 | Date | |
---|---|---|---|
996d948e51 | |||
b6ad3e925c | |||
51f0e4f1a7 | |||
daa7cc9575 | |||
7a071d7952 | |||
04a6581f9b | |||
60b5aab973 | |||
202d74bb02 | |||
3f26e69034 |
@ -1621,8 +1621,6 @@ class GenerationMixin:
|
|||||||
"num_return_sequences has to be 1 when doing assisted generate, "
|
"num_return_sequences has to be 1 when doing assisted generate, "
|
||||||
f"but is {generation_config.num_return_sequences}."
|
f"but is {generation_config.num_return_sequences}."
|
||||||
)
|
)
|
||||||
if batch_size > 1:
|
|
||||||
raise ValueError("assisted generate is only supported for batch_size = 1")
|
|
||||||
if not model_kwargs["use_cache"]:
|
if not model_kwargs["use_cache"]:
|
||||||
raise ValueError("assisted generate requires `use_cache=True`")
|
raise ValueError("assisted generate requires `use_cache=True`")
|
||||||
|
|
||||||
@ -4427,6 +4425,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# keep track of which sequences are already finished
|
# keep track of which sequences are already finished
|
||||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
|
||||||
# other auxiliary variables
|
# other auxiliary variables
|
||||||
max_len = stopping_criteria[0].max_length
|
max_len = stopping_criteria[0].max_length
|
||||||
@ -4441,6 +4440,13 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
|
|
||||||
|
# 2 * max_len to give us room to potentially left cut
|
||||||
|
position_ids = torch.arange(2 * max_len, device=input_ids.device, dtype=torch.long)[None, :].broadcast_to(batch_size, 2 * max_len) if batch_size > 1 else None
|
||||||
|
attention_mask = torch.ones_like(position_ids) if position_ids is not None else None
|
||||||
|
n_matches = None
|
||||||
|
eos_tokens_mask = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
@ -4452,6 +4458,31 @@ class GenerationMixin:
|
|||||||
if this_peer_finished_flag.item() == 0.0:
|
if this_peer_finished_flag.item() == 0.0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Rotate everything for bsz > 1
|
||||||
|
if n_matches is not None and position_ids is not None:
|
||||||
|
# compute by how much everything can be rotated
|
||||||
|
shift = unfinished_sequences * (n_matches.max() - n_matches) + (1 - unfinished_sequences) * (eos_tokens_mask.sum(-1) - 1)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
if shift[i] > 0:
|
||||||
|
input_ids[i][shift[i]:] = input_ids[i][:-shift[i]].clone()
|
||||||
|
input_ids[i][:shift[i]] = self.config.pad_token_id
|
||||||
|
|
||||||
|
position_ids = position_ids.add(-shift[:, None]).clamp(min=0)
|
||||||
|
attention_mask[:, :-1] = position_ids[:, 1:] > 0
|
||||||
|
|
||||||
|
left_cut = (1 - attention_mask).sum(-1).min()
|
||||||
|
|
||||||
|
if left_cut > 0:
|
||||||
|
position_ids = position_ids[:, left_cut:]
|
||||||
|
attention_mask = attention_mask[:, left_cut:]
|
||||||
|
input_ids = input_ids[:, left_cut:]
|
||||||
|
|
||||||
|
model_kwargs["past_key_values"] = _crop_past_key_values(self, model_kwargs["past_key_values"], left_cut=left_cut)
|
||||||
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
||||||
|
assistant_model, model_kwargs["assistant_past_key_values"], left_cut=left_cut
|
||||||
|
) # the assistant does not have the token after the last match, hence the -1
|
||||||
|
|
||||||
# Assistant: main logic start
|
# Assistant: main logic start
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
|
||||||
@ -4459,17 +4490,24 @@ class GenerationMixin:
|
|||||||
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||||
# need access to the assistant cache to secure strong speedups.
|
# need access to the assistant cache to secure strong speedups.
|
||||||
candidate_input_ids = input_ids
|
candidate_input_ids = input_ids
|
||||||
for _ in range(int(num_assistant_tokens)):
|
|
||||||
|
for assist_idx in range(int(num_assistant_tokens)):
|
||||||
# 1.1. use the assistant model to obtain the next candidate logits
|
# 1.1. use the assistant model to obtain the next candidate logits
|
||||||
if "assistant_past_key_values" in model_kwargs:
|
if "assistant_past_key_values" in model_kwargs:
|
||||||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
||||||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||||
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||||
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||||
|
|
||||||
|
assist_position_ids = position_ids[:, cur_len - new_token_len + assist_idx:cur_len + assist_idx] if position_ids is not None else None
|
||||||
|
assist_attention_mask = attention_mask[:, :cur_len + assist_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||||
if assistant_model.config.is_encoder_decoder:
|
if assistant_model.config.is_encoder_decoder:
|
||||||
assistant_model_outputs = assistant_model(
|
assistant_model_outputs = assistant_model(
|
||||||
decoder_input_ids=assist_inputs,
|
decoder_input_ids=assist_inputs,
|
||||||
|
decoder_position_ids=assist_position_ids,
|
||||||
|
decoder_attention_mask=assist_attention_mask,
|
||||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
)
|
)
|
||||||
@ -4505,16 +4543,19 @@ class GenerationMixin:
|
|||||||
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
||||||
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
||||||
|
|
||||||
|
if stopping_criteria(candidate_input_ids, assistant_model_outputs.logits[:, -1, :]):
|
||||||
|
break
|
||||||
|
|
||||||
# 1.3. stop assistant generation on EOS
|
# 1.3. stop assistant generation on EOS
|
||||||
if eos_token_id_tensor is not None:
|
if eos_token_id_tensor is not None:
|
||||||
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||||
last_assistant_token_is_eos = (
|
last_assistant_token_is_eos = (
|
||||||
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
|
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
|
||||||
)
|
)
|
||||||
if last_assistant_token_is_eos:
|
if torch.logical_or(~unfinished_sequences.bool(), last_assistant_token_is_eos).all():
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
last_assistant_token_is_eos = False
|
last_assistant_token_is_eos = torch.zeros((1, 1), device=candidate_input_ids.device, dtype=torch.int8)
|
||||||
|
|
||||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||||
|
|
||||||
@ -4527,6 +4568,12 @@ class GenerationMixin:
|
|||||||
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
|
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
|
||||||
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
candidate_kwargs["decoder_position_ids"] = position_ids[:, :cur_len + candidate_length] if position_ids is not None else None
|
||||||
|
candidate_kwargs["decoder_attention_mask"] = attention_mask[:, :cur_len + candidate_length] if attention_mask is not None else None
|
||||||
|
else:
|
||||||
|
candidate_kwargs["position_ids"] = position_ids[:, :cur_len + candidate_length]
|
||||||
|
|
||||||
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
||||||
|
|
||||||
# 2.2. Run a forward pass on the candidate sequence
|
# 2.2. Run a forward pass on the candidate sequence
|
||||||
@ -4555,7 +4602,7 @@ class GenerationMixin:
|
|||||||
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
||||||
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
||||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
||||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum(-1)
|
||||||
|
|
||||||
# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
||||||
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
|
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
|
||||||
@ -4563,32 +4610,50 @@ class GenerationMixin:
|
|||||||
# is no match.
|
# is no match.
|
||||||
|
|
||||||
# 5.1. Ensure we don't generate beyond max_len or an EOS token
|
# 5.1. Ensure we don't generate beyond max_len or an EOS token
|
||||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
n_matches -= last_assistant_token_is_eos.int() * (n_matches == candidate_length).int()
|
||||||
n_matches -= 1
|
# make sure than already finished sequences always match until longest "still active" sequence
|
||||||
n_matches = min(n_matches, max_len - cur_len - 1)
|
n_matches = torch.clamp(n_matches, max=max_len - cur_len - 1)
|
||||||
|
# make sure that finished sentences cannot slow down valid tokens
|
||||||
|
n_matches = unfinished_sequences * n_matches + (1 - unfinished_sequences) * (unfinished_sequences * n_matches).max()
|
||||||
|
|
||||||
# 5.2. Get the valid continuation, after the matching tokens
|
# 5.2. Get the valid continuation, after the matching tokens
|
||||||
valid_tokens = selected_tokens[:, : n_matches + 1]
|
valid_tokens = selected_tokens[:, : n_matches.max() + 1]
|
||||||
|
|
||||||
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
|
if eos_token_id_tensor is not None:
|
||||||
|
eos_tokens = valid_tokens.eq(eos_token_id_tensor)
|
||||||
|
finished_seq_mask = ~(unfinished_sequences.bool()[:, None].broadcast_to(valid_tokens.shape))
|
||||||
|
eos_tokens_mask = torch.logical_or(eos_tokens.cumsum(-1).bool(), finished_seq_mask)
|
||||||
|
valid_tokens = torch.where(eos_tokens_mask, eos_token_id_tensor, valid_tokens)
|
||||||
|
|
||||||
|
# check which sentence has finished
|
||||||
|
unfinished_sequences = (1 - eos_tokens_mask.gather(1, n_matches[:, None]).squeeze(-1).int())
|
||||||
|
|
||||||
|
# stop when each sentence is finished
|
||||||
|
if unfinished_sequences.max() == 0:
|
||||||
|
this_peer_finished = True
|
||||||
|
|
||||||
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
|
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
|
||||||
|
|
||||||
if streamer is not None:
|
if streamer is not None:
|
||||||
streamer.put(valid_tokens.cpu())
|
streamer.put(valid_tokens.cpu())
|
||||||
new_cur_len = input_ids.shape[-1]
|
new_cur_len = input_ids.shape[-1]
|
||||||
|
|
||||||
# 5.3. Discard past key values relative to unused assistant tokens
|
# 5.3. Discard past key values relative to unused assistant tokens
|
||||||
new_cache_size = new_cur_len - 1
|
new_cache_size = new_cur_len - 1
|
||||||
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size, n_matches)
|
||||||
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
||||||
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1, n_matches
|
||||||
) # the assistant does not have the token after the last match, hence the -1
|
) # the assistant does not have the token after the last match, hence the -1
|
||||||
|
|
||||||
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||||
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||||
# cost of forecasting incorrect assistant tokens.
|
# cost of forecasting incorrect assistant tokens.
|
||||||
if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
|
if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
|
||||||
if n_matches == int(num_assistant_tokens):
|
if n_matches.min() == int(num_assistant_tokens):
|
||||||
num_assistant_tokens += 2.0
|
num_assistant_tokens += 2
|
||||||
else:
|
else:
|
||||||
num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
|
num_assistant_tokens = max(1, num_assistant_tokens - 1)
|
||||||
|
|
||||||
# Assistant: main logic end
|
# Assistant: main logic end
|
||||||
if synced_gpus and this_peer_finished:
|
if synced_gpus and this_peer_finished:
|
||||||
@ -4598,7 +4663,7 @@ class GenerationMixin:
|
|||||||
# Assistant: modified to append one tuple element per token, as in the other generation methods.
|
# Assistant: modified to append one tuple element per token, as in the other generation methods.
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_scores:
|
if output_scores:
|
||||||
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
|
scores += tuple(new_logits[:, i, :] for i in range(n_matches.max() + 1))
|
||||||
|
|
||||||
if "past_key_values" not in model_kwargs:
|
if "past_key_values" not in model_kwargs:
|
||||||
added_len = new_cur_len
|
added_len = new_cur_len
|
||||||
@ -4639,19 +4704,6 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id_tensor is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
|
||||||
input_ids[:, -1]
|
|
||||||
.tile(eos_token_id_tensor.shape[0], 1)
|
|
||||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
|
||||||
.prod(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# stop when each sentence is finished
|
|
||||||
if unfinished_sequences.max() == 0:
|
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
# stop if we exceed the maximum length
|
# stop if we exceed the maximum length
|
||||||
if stopping_criteria(input_ids, scores):
|
if stopping_criteria(input_ids, scores):
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
@ -4684,15 +4736,32 @@ class GenerationMixin:
|
|||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
|
||||||
def _crop_past_key_values(model, past_key_values, maximum_length):
|
def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches=None, left_cut=None):
|
||||||
"""Crops the past key values up to a certain maximum length."""
|
"""Crops the past key values up to a certain maximum length."""
|
||||||
new_past = []
|
new_past = []
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
for idx in range(len(past_key_values)):
|
for idx in range(len(past_key_values)):
|
||||||
|
if left_cut is None:
|
||||||
|
k_cache = past_key_values[idx][0][:, :, :maximum_length, :]
|
||||||
|
v_cache = past_key_values[idx][1][:, :, :maximum_length, :]
|
||||||
|
else:
|
||||||
|
k_cache = past_key_values[idx][0][:, :, left_cut:, :]
|
||||||
|
v_cache = past_key_values[idx][1][:, :, left_cut:, :]
|
||||||
|
|
||||||
|
if n_matches is not None:
|
||||||
|
for batch_idx in range(len(n_matches)):
|
||||||
|
num_roll_left = n_matches.max() - n_matches[batch_idx]
|
||||||
|
if num_roll_left > 0:
|
||||||
|
# TODO(PVP) - check mem usage
|
||||||
|
# k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone())
|
||||||
|
# v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone())
|
||||||
|
k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone()
|
||||||
|
v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone()
|
||||||
|
|
||||||
new_past.append(
|
new_past.append(
|
||||||
(
|
(
|
||||||
past_key_values[idx][0][:, :, :maximum_length, :],
|
k_cache,
|
||||||
past_key_values[idx][1][:, :, :maximum_length, :],
|
v_cache,
|
||||||
past_key_values[idx][2],
|
past_key_values[idx][2],
|
||||||
past_key_values[idx][3],
|
past_key_values[idx][3],
|
||||||
)
|
)
|
||||||
@ -4749,6 +4818,9 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
|
|||||||
cur_len += 1
|
cur_len += 1
|
||||||
added_len -= cur_len
|
added_len -= cur_len
|
||||||
|
|
||||||
|
if torch.is_tensor(added_len):
|
||||||
|
added_len = added_len.max().item()
|
||||||
|
|
||||||
for i in range(added_len):
|
for i in range(added_len):
|
||||||
new_tuple = ()
|
new_tuple = ()
|
||||||
for layer in new_outputs:
|
for layer in new_outputs:
|
||||||
|
@ -1078,7 +1078,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, **kwargs
|
||||||
):
|
):
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
past_length = past_key_values[0][0].shape[2]
|
||||||
@ -1092,7 +1092,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
input_ids = input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
if position_ids is not None and position_ids.shape[1] > input_ids.shape[1]:
|
||||||
|
position_ids = position_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
@ -305,8 +305,11 @@ class WhisperPositionalEmbedding(nn.Embedding):
|
|||||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
super().__init__(num_positions, embedding_dim)
|
super().__init__(num_positions, embedding_dim)
|
||||||
|
|
||||||
def forward(self, input_ids, past_key_values_length=0):
|
def forward(self, input_ids, past_key_values_length=0, position_ids=None):
|
||||||
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
|
if position_ids is None:
|
||||||
|
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
|
||||||
|
else:
|
||||||
|
return self.weight[position_ids]
|
||||||
|
|
||||||
|
|
||||||
class WhisperAttention(nn.Module):
|
class WhisperAttention(nn.Module):
|
||||||
@ -1224,6 +1227,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
position_ids=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@ -1322,9 +1326,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids)
|
||||||
else:
|
else:
|
||||||
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids)
|
||||||
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@ -1506,6 +1510,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
decoder_position_ids: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
@ -1564,6 +1569,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=cross_attn_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
|
position_ids=decoder_position_ids,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@ -1637,6 +1643,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@ -1691,6 +1698,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=cross_attn_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@ -1988,6 +1996,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
decoder_position_ids=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
@ -2002,12 +2012,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
|
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
|
||||||
|
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"decoder_attention_mask": None,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"decoder_position_ids": decoder_position_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1624,6 +1624,48 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_assisted_decoding_greedy_batch(self):
|
||||||
|
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
|
||||||
|
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
|
return
|
||||||
|
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||||
|
if any(
|
||||||
|
model_name in model_class.__name__.lower()
|
||||||
|
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# enable cache
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=3)
|
||||||
|
|
||||||
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
|
if not hasattr(config, "use_cache"):
|
||||||
|
return
|
||||||
|
|
||||||
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
|
config.num_assistant_tokens = 2
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
output_assisted = model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_length=max_length,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
assistant_model=model, # triggers assisted decoding
|
||||||
|
output_scores=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_generate_with_head_masking(self):
|
def test_generate_with_head_masking(self):
|
||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
Reference in New Issue
Block a user