Compare commits

...

2 Commits

SHA1 Message Date
5409ae5df0 default padding side to right for llama 2023-07-24 09:51:08 +02:00
d0734192c6 remove persistent tensor 2023-07-21 17:41:46 +02:00
4 changed files with 8 additions and 4 deletions

View File

@@ -107,7 +107,7 @@ class OpenLlamaRotaryEmbedding(torch.nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -171,7 +171,7 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

View File

@@ -97,7 +97,7 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -159,7 +159,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

View File

@@ -107,6 +107,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
         bos_token="<s>",
         eos_token="</s>",
         pad_token=None,
+        padding_side="right",
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
         add_bos_token=True,
         add_eos_token=False,
@@ -124,6 +125,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
             eos_token=eos_token,
             unk_token=unk_token,
             pad_token=pad_token,
+            padding_side=padding_side,
             add_bos_token=add_bos_token,
             add_eos_token=add_eos_token,
             sp_model_kwargs=self.sp_model_kwargs,

View File

@@ -108,6 +108,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
+        padding_side="right",
         add_bos_token=True,
         add_eos_token=False,
         **kwargs,
@@ -116,6 +117,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            padding_side=padding_side,
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
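
Both LLaMA tokenizers now expose `padding_side` as a constructor argument, default it to "right", and forward the value to the base class. A rough usage sketch, assuming a local checkpoint path (placeholder) and assigning a pad token since LLaMA ships without one:

```python
from transformers import LlamaTokenizer

# "path/to/llama-checkpoint" is a placeholder, not a real model id.
tok = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint")
tok.pad_token = tok.eos_token  # LLaMA has no pad token by default

batch = tok(["short", "a somewhat longer input"], padding=True)
# With the new padding_side="right" default, pad ids are appended after the
# real tokens, and attention_mask ends in zeros for the shorter sequence.

# Callers can still override the default at load time:
tok_left = LlamaTokenizer.from_pretrained(
    "path/to/llama-checkpoint", padding_side="left"
)
```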