Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

🚨 [v5] generate delegates default cache initialization to the model (#41505)
@@ -2167,7 +2167,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):


 class WhisperNoSpeechDetection(LogitsProcessor):
-    r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""
+    """
+    This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits
+    to follow the original implementation
+    """

     def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
         self.no_speech_token = no_speech_token
@@ -2188,6 +2191,10 @@ class WhisperNoSpeechDetection(LogitsProcessor):
         self.model = model

     def set_inputs(self, inputs):
+        # build `cache_position` on the fly
+        seq_length = inputs["input_ids"].shape[1]
+        inputs = self.model._get_initial_cache_position(seq_length, self.model.device, inputs)
+        # prepare other inputs
         self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
         self.inputs["input_features"] = self.inputs.pop("inputs")

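For orientation, a minimal sketch of what "building `cache_position` on the fly" amounts to for a fresh cache, as in the `set_inputs` hunk above. The helper below is illustrative only (not the actual `_get_initial_cache_position` implementation) and assumes an empty cache:

```python
import torch

def build_initial_cache_position(input_ids: torch.Tensor, past_length: int = 0) -> torch.Tensor:
    # With an empty cache, cache_position simply enumerates the prompt tokens:
    # [0, 1, ..., seq_len - 1]. A pre-filled cache would shift the range by past_length.
    seq_len = input_ids.shape[1]
    return torch.arange(past_length, past_length + seq_len, device=input_ids.device)

# Example: a batch of 2 prompts, 5 tokens each, no cache yet.
input_ids = torch.ones(2, 5, dtype=torch.long)
print(build_initial_cache_position(input_ids))  # tensor([0, 1, 2, 3, 4])
```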
@@ -548,6 +548,12 @@ class GenerationMixin(ContinuousMixin):
         # 2. Generic cache-dependent input preparation
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
+        # We check `use_cache` below because some stateful models (like `recurrent_gemma`) expect input slicing if
+        # their caching mechanism is used. To define `use_cache`, the user-defined argument takes precedence.
+        use_cache = kwargs.get("use_cache")
+        if use_cache is None:
+            use_cache = getattr(self.config, "use_cache", False)
+        if past_key_values is None or use_cache:
             # TODO (joao): handle the case where cache length == input_ids length. The function below results in an
             # exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
             # token to recompute the logits for the first token to be generated (but not all caches support roll backs)
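The `use_cache` resolution order introduced above (explicit kwarg first, then `config.use_cache`, defaulting to False) can be sketched in isolation; `_Cfg` is a stand-in for the model config, not the real `PretrainedConfig`:

```python
def resolve_use_cache(kwargs: dict, config) -> bool:
    # The user-passed argument wins; otherwise fall back to the config, defaulting to False.
    use_cache = kwargs.get("use_cache")
    if use_cache is None:
        use_cache = getattr(config, "use_cache", False)
    return bool(use_cache)

class _Cfg:
    use_cache = True

print(resolve_use_cache({}, _Cfg()))                    # True  (taken from the config)
print(resolve_use_cache({"use_cache": False}, _Cfg()))  # False (explicit kwarg takes precedence)
```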
@@ -589,7 +595,7 @@ class GenerationMixin(ContinuousMixin):
         for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
-                if past_key_values is not None:
+                if past_key_values is not None or use_cache:
                     current_input_length = (
                         model_inputs["inputs_embeds"].shape[1]
                         if model_inputs.get("inputs_embeds") is not None
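A hedged sketch of the slicing this condition gates: when caching is active, per-step inputs such as `position_ids` are trimmed to the tokens actually fed in the current step. Names and shapes are illustrative, not the exact `GenerationMixin` code:

```python
import torch

def slice_model_input(model_input: torch.Tensor, current_input_length: int) -> torch.Tensor:
    # Keep only the trailing positions that correspond to the tokens processed this step.
    return model_input[:, -current_input_length:]

position_ids = torch.arange(10).unsqueeze(0)  # full prompt positions, shape (1, 10)
print(slice_model_input(position_ids, 1))     # tensor([[9]]) -> decoding a single new token
```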
@@ -1999,17 +2005,15 @@ class GenerationMixin(ContinuousMixin):
         elif "dynamic" in generation_config.cache_implementation:
             model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)

-        # TODO (joao): remove this `else` when we remove the last traces of the legacy cache format (v4.58.0, search
-        # for `instance(past_key_values, Cache)` as well). In general, if `cache_implementation` is unset, cache
-        # initialization should happen inside the model at prefill time.
-        else:
-            model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
-
         # TODO (joao): this logic is incomplete, e.g. `offloaded` should apply to both caches. Refactor this function
         # to correctly pass parameterization to both caches.
-        if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache):
-            model_kwargs[cache_name] = EncoderDecoderCache(
-                model_kwargs[cache_name],  # self-attention cache
+        if (
+            requires_cross_attention_cache
+            and "past_key_values" in model_kwargs
+            and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
+        ):
+            model_kwargs["past_key_values"] = EncoderDecoderCache(
+                model_kwargs["past_key_values"],  # self-attention cache
                 DynamicCache(**dynamic_cache_kwargs),  # cross-attention cache
             )
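A minimal sketch of the wrapping performed above, assuming the public `DynamicCache` and `EncoderDecoderCache` classes from `transformers` (both shown with default constructor arguments; the real code forwards `dynamic_cache_kwargs`):

```python
from transformers import DynamicCache, EncoderDecoderCache

# Encoder-decoder generation keeps a self-attention cache and a separate
# cross-attention cache, bundled into one EncoderDecoderCache object.
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

print(isinstance(past_key_values, EncoderDecoderCache))  # True
```

With the legacy `else` branch removed, a user who does not set `cache_implementation` no longer receives a cache built by `generate()` itself; per the removed TODO, cache initialization now happens inside the model at prefill time, which is the point of this PR.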
@@ -3335,7 +3339,7 @@ class GenerationMixin(ContinuousMixin):

             # pluck the cache from the beam indices that will be used in the next iteration
             # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
-            if model_kwargs.get("past_key_values", None) is not None:
+            if model_kwargs.get("past_key_values") is not None:
                 beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
                 if hasattr(self, "_reorder_cache"):
                     model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
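For intuition, a hedged sketch of what reordering a cache by beam indices means at the tensor level (a toy key tensor, not the library's `Cache` API):

```python
import torch

# Toy cached keys for one layer: (batch * num_beams, num_heads, seq_len, head_dim)
cached_keys = torch.randn(4, 2, 3, 8)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams selected to continue in the next iteration

# Re-indexing along the batch*beam dimension keeps the cache aligned with surviving beams.
reordered_keys = cached_keys.index_select(0, beam_idx)
print(reordered_keys.shape)  # torch.Size([4, 2, 3, 8])
```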
@@ -802,7 +802,7 @@ class BartDecoder(BartPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

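The same one-line change recurs in most of the model hunks below: the decoder now also checks `config.is_encoder_decoder`, not just whether `encoder_hidden_states` was passed. A hedged sketch of the selection logic as a standalone function (the models inline it; `config=self.config` is omitted from the cache constructors here):

```python
from transformers import DynamicCache, EncoderDecoderCache

def default_cache(config, encoder_hidden_states=None):
    # Encoder-decoder models, or any call that provides encoder states for cross-attention,
    # get a paired self-/cross-attention cache; plain decoders get a single DynamicCache.
    if encoder_hidden_states is not None or getattr(config, "is_encoder_decoder", False):
        return EncoderDecoderCache(DynamicCache(), DynamicCache())
    return DynamicCache()
```

With this default living in the model's forward, `generate()` can leave `past_key_values` unset and still get the right cache type at prefill time: a decoder that belongs to an encoder-decoder checkpoint picks the paired cache even on a call where `encoder_hidden_states` is not passed.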
@@ -671,7 +671,11 @@ class BertModel(BertPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -533,7 +533,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -1976,7 +1976,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -752,7 +752,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -739,7 +739,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -651,7 +651,11 @@ class CamembertModel(CamembertPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -611,7 +611,11 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -610,7 +610,11 @@ class ElectraModel(ElectraPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -647,7 +647,11 @@ class ErnieModel(ErniePreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -234,7 +234,11 @@ class ErnieModel(BertModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -616,10 +616,6 @@ class FSMTDecoder(nn.Module):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        # initialize `past_key_values`
-        if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-
         x += positions
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -912,6 +908,9 @@ class FSMTModel(PretrainedFSMTModel):
         if decoder_input_ids is None and decoder_inputs_embeds is None:
             raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")

+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_ids=input_ids,
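The two FSMT hunks above hoist default cache creation from `FSMTDecoder.forward` up into `FSMTModel.forward`. A toy sketch of that hoisting pattern; the `Toy*` classes are hypothetical and only mirror the shape of the change:

```python
from transformers import DynamicCache, EncoderDecoderCache

class ToyDecoder:
    def forward(self, past_key_values=None, use_cache=True):
        # After the change, the decoder expects its caller to have provided the cache.
        assert past_key_values is not None or not use_cache
        return past_key_values

class ToyModel:
    def __init__(self, decoder):
        self.decoder = decoder

    def forward(self, past_key_values=None, use_cache=True):
        # The model now owns default cache initialization for the whole stack.
        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        return self.decoder.forward(past_key_values=past_key_values, use_cache=use_cache)

print(type(ToyModel(ToyDecoder()).forward()).__name__)  # EncoderDecoderCache
```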
@@ -1029,7 +1029,7 @@ class Kosmos2TextTransformer(nn.Module):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -750,7 +750,7 @@ class MarianDecoder(MarianPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -800,7 +800,7 @@ class MBartDecoder(MBartPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -802,7 +802,7 @@ class MvpDecoder(MvpPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -802,7 +802,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -715,7 +715,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -1182,7 +1182,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -621,7 +621,11 @@ class RobertaModel(RobertaPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -645,7 +645,11 @@ class RobertaPreLayerNormModel(RobertaPreLayerNormPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -430,9 +430,6 @@ class RoFormerEncoder(nn.Module):
                 )
                 use_cache = False

-        if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -736,6 +733,13 @@ class RoFormerModel(RoFormerPreTrainedModel):
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device

+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )
+
         past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

         if attention_mask is None:
@@ -550,7 +550,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -808,12 +808,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

         if use_cache and past_key_values is None:
-            if self.config.is_encoder_decoder:
-                past_key_values = EncoderDecoderCache(
-                    DynamicCache(config=self.config), DynamicCache(config=self.config)
-                )
-            else:
-                past_key_values = DynamicCache(config=self.config)
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         past_key_values_length = 0
         if cache_position is not None:
@@ -469,7 +469,7 @@ class XGLMModel(XGLMPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = (
                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                if encoder_hidden_states is not None
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                 else DynamicCache(config=self.config)
             )

@@ -640,7 +640,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -627,7 +627,11 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -742,7 +742,11 @@ class XmodModel(XmodPreTrainedModel):
             use_cache = False

         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                else DynamicCache(config=self.config)
+            )

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch RoFormer model."""

+import copy
 import unittest

 from transformers import RoFormerConfig, is_torch_available
@@ -209,6 +210,8 @@ class RoFormerModelTester:
         token_labels,
         choice_labels,
     ):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
         model = RoFormerForCausalLM(config=config).to(torch_device).eval()
         torch.manual_seed(0)
         output_without_past_cache = model.generate(
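A short note on the `copy.deepcopy(config)` added above: flipping `is_decoder` on a config object that other checks also receive would leak state between tests, so the check mutates a private copy instead. A generic illustration (a stand-in `Cfg` class, not the real `RoFormerConfig`):

```python
import copy

class Cfg:
    def __init__(self):
        self.is_decoder = False

shared = Cfg()
local = copy.deepcopy(shared)
local.is_decoder = True  # flip only the per-test copy

print(shared.is_decoder, local.is_decoder)  # False True
```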