mirror of
				https://github.com/huggingface/transformers.git
				synced 2025-10-21 01:23:56 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			v4.43.3
			...
			assistant_
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 996d948e51 | |||
| b6ad3e925c | |||
| 51f0e4f1a7 | |||
| daa7cc9575 | |||
| 7a071d7952 | |||
| 04a6581f9b | |||
| 60b5aab973 | |||
| 202d74bb02 | |||
| 3f26e69034 | 
| @ -1621,8 +1621,6 @@ class GenerationMixin: | ||||
|                     "num_return_sequences has to be 1 when doing assisted generate, " | ||||
|                     f"but is {generation_config.num_return_sequences}." | ||||
|                 ) | ||||
|             if batch_size > 1: | ||||
|                 raise ValueError("assisted generate is only supported for batch_size = 1") | ||||
|             if not model_kwargs["use_cache"]: | ||||
|                 raise ValueError("assisted generate requires `use_cache=True`") | ||||
|  | ||||
| @ -4427,6 +4425,7 @@ class GenerationMixin: | ||||
|  | ||||
|         # keep track of which sequences are already finished | ||||
|         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) | ||||
|         batch_size = input_ids.shape[0] | ||||
|  | ||||
|         # other auxiliary variables | ||||
|         max_len = stopping_criteria[0].max_length | ||||
| @ -4441,6 +4440,13 @@ class GenerationMixin: | ||||
|         ) | ||||
|  | ||||
|         this_peer_finished = False  # used by synced_gpus only | ||||
|  | ||||
|         # 2 * max_len to give us room to potentially left cut | ||||
|         position_ids = torch.arange(2 * max_len, device=input_ids.device, dtype=torch.long)[None, :].broadcast_to(batch_size, 2 * max_len) if batch_size > 1 else None | ||||
|         attention_mask = torch.ones_like(position_ids) if position_ids is not None else  None | ||||
|         n_matches = None | ||||
|         eos_tokens_mask = None | ||||
|  | ||||
|         while True: | ||||
|             if synced_gpus: | ||||
|                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | ||||
| @ -4452,6 +4458,31 @@ class GenerationMixin: | ||||
|                 if this_peer_finished_flag.item() == 0.0: | ||||
|                     break | ||||
|  | ||||
|             # Rotate everything for bsz > 1 | ||||
|             if n_matches is not None and position_ids is not None: | ||||
|                 # compute by how much everything can be rotated | ||||
|                 shift = unfinished_sequences * (n_matches.max() - n_matches) + (1 - unfinished_sequences) * (eos_tokens_mask.sum(-1) - 1) | ||||
|  | ||||
|                 for i in range(batch_size): | ||||
|                     if shift[i] > 0: | ||||
|                         input_ids[i][shift[i]:] = input_ids[i][:-shift[i]].clone() | ||||
|                         input_ids[i][:shift[i]] = self.config.pad_token_id | ||||
|  | ||||
|                 position_ids = position_ids.add(-shift[:, None]).clamp(min=0) | ||||
|                 attention_mask[:, :-1] = position_ids[:, 1:] > 0 | ||||
|  | ||||
|                 left_cut = (1 - attention_mask).sum(-1).min() | ||||
|  | ||||
|                 if left_cut > 0: | ||||
|                     position_ids = position_ids[:, left_cut:] | ||||
|                     attention_mask = attention_mask[:, left_cut:] | ||||
|                     input_ids = input_ids[:, left_cut:] | ||||
|  | ||||
|                     model_kwargs["past_key_values"] = _crop_past_key_values(self, model_kwargs["past_key_values"], left_cut=left_cut) | ||||
|                     model_kwargs["assistant_past_key_values"] = _crop_past_key_values( | ||||
|                         assistant_model, model_kwargs["assistant_past_key_values"], left_cut=left_cut | ||||
|                     )  # the assistant does not have the token after the last match, hence the -1 | ||||
|  | ||||
|             # Assistant: main logic start | ||||
|             cur_len = input_ids.shape[-1] | ||||
|  | ||||
| @ -4459,17 +4490,24 @@ class GenerationMixin: | ||||
|             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we | ||||
|             # need access to the assistant cache to secure strong speedups. | ||||
|             candidate_input_ids = input_ids | ||||
|             for _ in range(int(num_assistant_tokens)): | ||||
|  | ||||
|             for assist_idx in range(int(num_assistant_tokens)): | ||||
|                 # 1.1. use the assistant model to obtain the next candidate logits | ||||
|                 if "assistant_past_key_values" in model_kwargs: | ||||
|                     prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] | ||||
|                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) | ||||
|                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len | ||||
|                     assist_inputs = candidate_input_ids[:, -new_token_len:] | ||||
|  | ||||
|                     assist_position_ids = position_ids[:, cur_len - new_token_len + assist_idx:cur_len + assist_idx] if position_ids is not None else None | ||||
|                     assist_attention_mask = attention_mask[:, :cur_len + assist_idx] if attention_mask is not None else None | ||||
|  | ||||
|                     # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 | ||||
|                     if assistant_model.config.is_encoder_decoder: | ||||
|                         assistant_model_outputs = assistant_model( | ||||
|                             decoder_input_ids=assist_inputs, | ||||
|                             decoder_position_ids=assist_position_ids, | ||||
|                             decoder_attention_mask=assist_attention_mask, | ||||
|                             past_key_values=model_kwargs["assistant_past_key_values"], | ||||
|                             encoder_outputs=model_kwargs["assistant_encoder_outputs"], | ||||
|                         ) | ||||
| @ -4505,16 +4543,19 @@ class GenerationMixin: | ||||
|                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) | ||||
|                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) | ||||
|  | ||||
|                 if stopping_criteria(candidate_input_ids, assistant_model_outputs.logits[:, -1, :]): | ||||
|                     break | ||||
|  | ||||
|                 # 1.3. stop assistant generation on EOS | ||||
|                 if eos_token_id_tensor is not None: | ||||
|                     last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) | ||||
|                     last_assistant_token_is_eos = ( | ||||
|                         ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() | ||||
|                     ) | ||||
|                     if last_assistant_token_is_eos: | ||||
|                     if torch.logical_or(~unfinished_sequences.bool(), last_assistant_token_is_eos).all(): | ||||
|                         break | ||||
|                 else: | ||||
|                     last_assistant_token_is_eos = False | ||||
|                     last_assistant_token_is_eos = torch.zeros((1, 1), device=candidate_input_ids.device, dtype=torch.int8) | ||||
|  | ||||
|             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] | ||||
|  | ||||
| @ -4527,6 +4568,12 @@ class GenerationMixin: | ||||
|             candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) | ||||
|             candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) | ||||
|  | ||||
|             if self.config.is_encoder_decoder: | ||||
|                 candidate_kwargs["decoder_position_ids"] = position_ids[:, :cur_len + candidate_length] if position_ids is not None else None | ||||
|                 candidate_kwargs["decoder_attention_mask"] = attention_mask[:, :cur_len + candidate_length] if attention_mask is not None else None | ||||
|             else: | ||||
|                 candidate_kwargs["position_ids"] = position_ids[:, :cur_len + candidate_length] | ||||
|  | ||||
|             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) | ||||
|  | ||||
|             # 2.2. Run a forward pass on the candidate sequence | ||||
| @ -4555,7 +4602,7 @@ class GenerationMixin: | ||||
|             # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep | ||||
|             # the assistant forecasted tokens until the first mismatch, or until the max length is reached. | ||||
|             candidate_new_tokens = candidate_input_ids[:, -candidate_length:] | ||||
|             n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() | ||||
|             n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum(-1) | ||||
|  | ||||
|             # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated | ||||
|             # by the model after the last candidate match is also valid, as it is generated from a correct sequence. | ||||
| @ -4563,32 +4610,50 @@ class GenerationMixin: | ||||
|             # is no match. | ||||
|  | ||||
|             # 5.1. Ensure we don't generate beyond max_len or an EOS token | ||||
|             if last_assistant_token_is_eos and n_matches == candidate_length: | ||||
|                 n_matches -= 1 | ||||
|             n_matches = min(n_matches, max_len - cur_len - 1) | ||||
|             n_matches -= last_assistant_token_is_eos.int() * (n_matches == candidate_length).int() | ||||
|             # make sure than already finished sequences always match until longest "still active" sequence | ||||
|             n_matches = torch.clamp(n_matches, max=max_len - cur_len - 1) | ||||
|             # make sure that finished sentences cannot slow down valid tokens | ||||
|             n_matches = unfinished_sequences * n_matches + (1 - unfinished_sequences) * (unfinished_sequences * n_matches).max() | ||||
|  | ||||
|             # 5.2. Get the valid continuation, after the matching tokens | ||||
|             valid_tokens = selected_tokens[:, : n_matches + 1] | ||||
|             valid_tokens = selected_tokens[:, : n_matches.max() + 1] | ||||
|  | ||||
|             # if eos_token was found in one sentence, set sentence to finished | ||||
|             if eos_token_id_tensor is not None: | ||||
|                 eos_tokens = valid_tokens.eq(eos_token_id_tensor) | ||||
|                 finished_seq_mask = ~(unfinished_sequences.bool()[:, None].broadcast_to(valid_tokens.shape)) | ||||
|                 eos_tokens_mask = torch.logical_or(eos_tokens.cumsum(-1).bool(), finished_seq_mask) | ||||
|                 valid_tokens = torch.where(eos_tokens_mask, eos_token_id_tensor, valid_tokens) | ||||
|  | ||||
|                 # check which sentence has finished | ||||
|                 unfinished_sequences = (1 - eos_tokens_mask.gather(1, n_matches[:, None]).squeeze(-1).int()) | ||||
|  | ||||
|                 # stop when each sentence is finished | ||||
|                 if unfinished_sequences.max() == 0: | ||||
|                     this_peer_finished = True | ||||
|  | ||||
|             input_ids = torch.cat((input_ids, valid_tokens), dim=-1) | ||||
|  | ||||
|             if streamer is not None: | ||||
|                 streamer.put(valid_tokens.cpu()) | ||||
|             new_cur_len = input_ids.shape[-1] | ||||
|  | ||||
|             # 5.3. Discard past key values relative to unused assistant tokens | ||||
|             new_cache_size = new_cur_len - 1 | ||||
|             outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) | ||||
|             outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size, n_matches) | ||||
|             model_kwargs["assistant_past_key_values"] = _crop_past_key_values( | ||||
|                 assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 | ||||
|                 assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1, n_matches | ||||
|             )  # the assistant does not have the token after the last match, hence the -1 | ||||
|  | ||||
|             # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, | ||||
|             # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the | ||||
|             # cost of forecasting incorrect assistant tokens. | ||||
|             if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": | ||||
|                 if n_matches == int(num_assistant_tokens): | ||||
|                     num_assistant_tokens += 2.0 | ||||
|                 if n_matches.min() == int(num_assistant_tokens): | ||||
|                     num_assistant_tokens += 2 | ||||
|                 else: | ||||
|                     num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0) | ||||
|                     num_assistant_tokens = max(1, num_assistant_tokens - 1) | ||||
|  | ||||
|             # Assistant: main logic end | ||||
|             if synced_gpus and this_peer_finished: | ||||
| @ -4598,7 +4663,7 @@ class GenerationMixin: | ||||
|             # Assistant: modified to append one tuple element per token, as in the other generation methods. | ||||
|             if return_dict_in_generate: | ||||
|                 if output_scores: | ||||
|                     scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) | ||||
|                     scores += tuple(new_logits[:, i, :] for i in range(n_matches.max() + 1)) | ||||
|  | ||||
|                 if "past_key_values" not in model_kwargs: | ||||
|                     added_len = new_cur_len | ||||
| @ -4639,19 +4704,6 @@ class GenerationMixin: | ||||
|                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder | ||||
|             ) | ||||
|  | ||||
|             # if eos_token was found in one sentence, set sentence to finished | ||||
|             if eos_token_id_tensor is not None: | ||||
|                 unfinished_sequences = unfinished_sequences.mul( | ||||
|                     input_ids[:, -1] | ||||
|                     .tile(eos_token_id_tensor.shape[0], 1) | ||||
|                     .ne(eos_token_id_tensor.unsqueeze(1)) | ||||
|                     .prod(dim=0) | ||||
|                 ) | ||||
|  | ||||
|                 # stop when each sentence is finished | ||||
|                 if unfinished_sequences.max() == 0: | ||||
|                     this_peer_finished = True | ||||
|  | ||||
|             # stop if we exceed the maximum length | ||||
|             if stopping_criteria(input_ids, scores): | ||||
|                 this_peer_finished = True | ||||
| @ -4684,15 +4736,32 @@ class GenerationMixin: | ||||
|             return input_ids | ||||
|  | ||||
|  | ||||
| def _crop_past_key_values(model, past_key_values, maximum_length): | ||||
| def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches=None, left_cut=None): | ||||
|     """Crops the past key values up to a certain maximum length.""" | ||||
|     new_past = [] | ||||
|     if model.config.is_encoder_decoder: | ||||
|         for idx in range(len(past_key_values)): | ||||
|             if left_cut is None: | ||||
|                 k_cache = past_key_values[idx][0][:, :, :maximum_length, :] | ||||
|                 v_cache = past_key_values[idx][1][:, :, :maximum_length, :] | ||||
|             else: | ||||
|                 k_cache = past_key_values[idx][0][:, :, left_cut:, :] | ||||
|                 v_cache = past_key_values[idx][1][:, :, left_cut:, :] | ||||
|  | ||||
|             if n_matches is not None: | ||||
|                 for batch_idx in range(len(n_matches)): | ||||
|                     num_roll_left = n_matches.max() - n_matches[batch_idx] | ||||
|                     if num_roll_left > 0: | ||||
|                         # TODO(PVP) - check mem usage | ||||
|                         # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone()) | ||||
|                         # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone()) | ||||
|                         k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone() | ||||
|                         v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone() | ||||
|  | ||||
|             new_past.append( | ||||
|                 ( | ||||
|                     past_key_values[idx][0][:, :, :maximum_length, :], | ||||
|                     past_key_values[idx][1][:, :, :maximum_length, :], | ||||
|                     k_cache, | ||||
|                     v_cache, | ||||
|                     past_key_values[idx][2], | ||||
|                     past_key_values[idx][3], | ||||
|                 ) | ||||
| @ -4749,6 +4818,9 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at | ||||
|         cur_len += 1 | ||||
|         added_len -= cur_len | ||||
|  | ||||
|     if torch.is_tensor(added_len): | ||||
|         added_len = added_len.max().item() | ||||
|  | ||||
|     for i in range(added_len): | ||||
|         new_tuple = () | ||||
|         for layer in new_outputs: | ||||
|  | ||||
| @ -1078,7 +1078,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): | ||||
|         ) | ||||
|  | ||||
|     def prepare_inputs_for_generation( | ||||
|         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs | ||||
|         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, **kwargs | ||||
|     ): | ||||
|         if past_key_values is not None: | ||||
|             past_length = past_key_values[0][0].shape[2] | ||||
| @ -1092,7 +1092,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel): | ||||
|  | ||||
|             input_ids = input_ids[:, remove_prefix_length:] | ||||
|  | ||||
|         position_ids = kwargs.get("position_ids", None) | ||||
|             if position_ids is not None and position_ids.shape[1] > input_ids.shape[1]: | ||||
|                 position_ids = position_ids[:, remove_prefix_length:] | ||||
|  | ||||
|         if attention_mask is not None and position_ids is None: | ||||
|             # create position_ids on the fly for batch generation | ||||
|             position_ids = attention_mask.long().cumsum(-1) - 1 | ||||
|  | ||||
| @ -305,8 +305,11 @@ class WhisperPositionalEmbedding(nn.Embedding): | ||||
|     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): | ||||
|         super().__init__(num_positions, embedding_dim) | ||||
|  | ||||
|     def forward(self, input_ids, past_key_values_length=0): | ||||
|         return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] | ||||
|     def forward(self, input_ids, past_key_values_length=0, position_ids=None): | ||||
|         if position_ids is None: | ||||
|             return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] | ||||
|         else: | ||||
|             return self.weight[position_ids] | ||||
|  | ||||
|  | ||||
| class WhisperAttention(nn.Module): | ||||
| @ -1224,6 +1227,7 @@ class WhisperDecoder(WhisperPreTrainedModel): | ||||
|         cross_attn_head_mask=None, | ||||
|         past_key_values=None, | ||||
|         inputs_embeds=None, | ||||
|         position_ids=None, | ||||
|         use_cache=None, | ||||
|         output_attentions=None, | ||||
|         output_hidden_states=None, | ||||
| @ -1322,9 +1326,9 @@ class WhisperDecoder(WhisperPreTrainedModel): | ||||
|  | ||||
|         # embed positions | ||||
|         if input_ids is not None: | ||||
|             positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) | ||||
|             positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids) | ||||
|         else: | ||||
|             positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) | ||||
|             positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids) | ||||
|  | ||||
|         hidden_states = inputs_embeds + positions | ||||
|         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) | ||||
| @ -1506,6 +1510,7 @@ class WhisperModel(WhisperPreTrainedModel): | ||||
|         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||||
|         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||||
|         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, | ||||
|         decoder_position_ids: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
| @ -1564,6 +1569,7 @@ class WhisperModel(WhisperPreTrainedModel): | ||||
|             cross_attn_head_mask=cross_attn_head_mask, | ||||
|             past_key_values=past_key_values, | ||||
|             inputs_embeds=decoder_inputs_embeds, | ||||
|             position_ids=decoder_position_ids, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
| @ -1637,6 +1643,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): | ||||
|         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||||
|         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||||
|         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, | ||||
|         decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, | ||||
|         labels: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
| @ -1691,6 +1698,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): | ||||
|             cross_attn_head_mask=cross_attn_head_mask, | ||||
|             past_key_values=past_key_values, | ||||
|             decoder_inputs_embeds=decoder_inputs_embeds, | ||||
|             decoder_position_ids=decoder_position_ids, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
| @ -1988,6 +1996,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): | ||||
|         use_cache=None, | ||||
|         encoder_outputs=None, | ||||
|         attention_mask=None, | ||||
|         decoder_position_ids=None, | ||||
|         decoder_attention_mask=None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if past_key_values is not None: | ||||
| @ -2002,12 +2012,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): | ||||
|  | ||||
|             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] | ||||
|  | ||||
|             if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: | ||||
|                 decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] | ||||
|  | ||||
|         return { | ||||
|             "encoder_outputs": encoder_outputs, | ||||
|             "past_key_values": past_key_values, | ||||
|             "decoder_input_ids": decoder_input_ids, | ||||
|             "use_cache": use_cache, | ||||
|             "decoder_attention_mask": None, | ||||
|             "decoder_attention_mask": decoder_attention_mask, | ||||
|             "decoder_position_ids": decoder_position_ids, | ||||
|         } | ||||
|  | ||||
|     @staticmethod | ||||
|  | ||||
| @ -1624,6 +1624,48 @@ class GenerationTesterMixin: | ||||
|  | ||||
|             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) | ||||
|  | ||||
|  | ||||
|     def test_assisted_decoding_greedy_batch(self): | ||||
|         # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the | ||||
|         # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking). | ||||
|  | ||||
|         for model_class in self.all_generative_model_classes: | ||||
|             # won't fix: FSMT and Reformer have a different cache variable type (and format). | ||||
|             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): | ||||
|                 return | ||||
|             # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes | ||||
|             if any( | ||||
|                 model_name in model_class.__name__.lower() | ||||
|                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] | ||||
|             ): | ||||
|                 return | ||||
|  | ||||
|             # enable cache | ||||
|             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=3) | ||||
|  | ||||
|             # NOTE: assisted generation only works with cache on at the moment. | ||||
|             if not hasattr(config, "use_cache"): | ||||
|                 return | ||||
|  | ||||
|             config.use_cache = True | ||||
|             config.is_decoder = True | ||||
|             config.num_assistant_tokens = 2 | ||||
|             model = model_class(config).to(torch_device).eval() | ||||
|             output_assisted = model.generate( | ||||
|                 input_ids, | ||||
|                 attention_mask=attention_mask, | ||||
|                 max_length=max_length, | ||||
|                 num_beams=1, | ||||
|                 do_sample=False, | ||||
|                 assistant_model=model,  # triggers assisted decoding | ||||
|                 output_scores=True, | ||||
|                 output_hidden_states=True, | ||||
|                 output_attentions=True, | ||||
|                 return_dict_in_generate=True, | ||||
|             ) | ||||
|  | ||||
|             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) | ||||
|  | ||||
|     def test_generate_with_head_masking(self): | ||||
|         """Test designed for encoder-decoder models to ensure the attention head masking is used.""" | ||||
|         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	