Compare commits

...

7 Commits

9 changed files with 351 additions and 168 deletions

View File

@ -621,6 +621,7 @@ class PretrainedConfig(PushToHubMixin):
commit_hash = kwargs.pop("_commit_hash", None)
gguf_file = kwargs.get("gguf_file", None)
dduf_entries = kwargs.pop("dduf_entries", None)
if trust_remote_code is True:
logger.warning(
@ -634,7 +635,7 @@ class PretrainedConfig(PushToHubMixin):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
@ -644,26 +645,31 @@ class PretrainedConfig(PushToHubMixin):
resolved_config_file = download_url(pretrained_model_name_or_path)
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_file(
pretrained_model_name_or_path,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
if resolved_config_file is None:
return None, kwargs
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if dduf_entries:
resolved_config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
commit_hash = extract_commit_hash(dduf_entries[resolved_config_file].dduf_path, commit_hash)
else:
resolved_config_file = cached_file(
pretrained_model_name_or_path,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
if resolved_config_file is None:
return None, kwargs
else:
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@ -682,7 +688,10 @@ class PretrainedConfig(PushToHubMixin):
config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
else:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
if dduf_entries:
config_dict = json.loads(dduf_entries[resolved_config_file].read_text())
else:
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):

View File

@ -544,17 +544,22 @@ class SpmConverter(Converter):
SpmExtractor = SentencePieceExtractor
special_tokens = {}
def __init__(self, *args):
def __init__(self, *args, **kwargs):
requires_backends(self, "protobuf")
super().__init__(*args)
dduf_entries = kwargs.get("dduf_entries", None)
# from .utils import sentencepiece_model_pb2 as model_pb2
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
with open(self.original_tokenizer.vocab_file, "rb") as f:
m.ParseFromString(f.read())
if dduf_entries:
with dduf_entries[self.original_tokenizer.vocab_file].as_mmap() as mm:
m.ParseFromString(mm)
else:
with open(self.original_tokenizer.vocab_file, "rb") as f:
m.ParseFromString(f.read())
self.proto = m
if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
@ -1606,7 +1611,7 @@ SLOW_TO_FAST_CONVERTERS = {
}
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False, **kwargs) -> Tokenizer:
"""
Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
@ -1625,7 +1630,7 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokeni
tokenizer_class_name = transformer_tokenizer.__class__.__name__
if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(transformer_tokenizer).converted()
return converter_class(transformer_tokenizer, **kwargs).converted()
else:
try:

View File

@ -131,6 +131,7 @@ if is_accelerate_available():
from accelerate.utils.modeling import get_state_dict_from_offload
if is_safetensors_available():
import safetensors
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
@ -495,21 +496,29 @@ def load_state_dict(
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = None,
weights_only: bool = True,
dduf_entries=None,
):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
"""
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
return safe_load_file(checkpoint_file)
if dduf_entries:
# TODO: Find a way to only open the metadata
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
else:
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
return safe_load_file(checkpoint_file)
try:
if dduf_entries:
raise ValueError("DDUF format is not supported yet with torch format. Please use safetensors")
if map_location is None:
if (
(
@ -3410,6 +3419,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None)
dduf_entries = kwargs.pop("dduf_entries", None)
gguf_file = kwargs.pop("gguf_file", None)
# Cache path to the GGUF file
@ -3450,22 +3460,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if commit_hash is None:
if not isinstance(config, PretrainedConfig):
# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if dduf_entries:
# files are in an archive, so I'm assuming the commit hash of the archive is enough.
resolved_config_file = next(iter(dduf_entries.items()))[1].dduf_path
else:
# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
else:
commit_hash = getattr(config, "_commit_hash", None)
@ -3474,16 +3489,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if dduf_entries:
# TODO: use the global var from peft utils
if os.path.join(pretrained_model_name_or_path, "adapter_config.json") in dduf_entries:
_adapter_model_path = os.path.join(pretrained_model_name_or_path, "adapter_config.json")
else:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
@ -3554,7 +3574,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
@ -3571,6 +3590,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
dduf_entries=dduf_entries,
**kwargs,
)
else:
@ -3636,36 +3656,69 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise ValueError(
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
)
if pretrained_model_name_or_path is not None and gguf_file is None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
# passing a dduf_entries means that we already knows where the file
is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries
if is_local:
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
if from_tf and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
or (
dduf_entries
and os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
in dduf_entries
)
):
# Load from a TF 1.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
elif from_tf and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
or (
dduf_entries
and os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
)
):
# Load from a TF 2.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
elif from_flax and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif from_flax and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
or (
dduf_entries
and os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
)
):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
elif use_safetensors is not False and (
os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
)
or (
dduf_entries
and os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
in dduf_entries
)
):
# Load from a safetensors checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif use_safetensors is not False and os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
elif use_safetensors is not False and (
os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
)
or (
dduf_entries
and os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
in dduf_entries
)
):
# Load from a sharded safetensors checkpoint
@ -3673,15 +3726,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
elif not use_safetensors and (
os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
)
or (
dduf_entries
and os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
in dduf_entries
)
):
# Load from a PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
elif not use_safetensors and (
os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
)
or (
dduf_entries
and os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
in dduf_entries
)
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
@ -3692,14 +3763,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif not use_safetensors and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
or (
dduf_entries
and (
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
in dduf_entries
or os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
)
)
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
" `from_tf=True` to load this model from those weights."
)
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif not use_safetensors and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
or (
dduf_entries
and os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
)
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
@ -3936,6 +4019,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
if (
@ -3943,8 +4027,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
and isinstance(resolved_archive_file, str)
and resolved_archive_file.endswith(".safetensors")
):
with safe_open(resolved_archive_file, framework="pt") as f:
metadata = f.metadata()
if dduf_entries:
# TODO: Find a way to better deal with that. We shouldn't have to read the entire file
metadata = {"format": "pt"}
# with dduf_entries[resolved_archive_file].as_mmap() as mm:
# metadata = safetensors.torch.load(mm).metadata()
else:
with safe_open(resolved_archive_file, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
# Assume it's a pytorch checkpoint (introduced for timm checkpoints)
@ -3972,7 +4062,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
state_dict = load_state_dict(
resolved_archive_file, weights_only=weights_only, dduf_entries=dduf_entries
)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
@ -3993,7 +4085,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
one_state_dict = load_state_dict(
resolved_archive_file[0], weights_only=weights_only, dduf_entries=dduf_entries
)
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
logger.info(
@ -4218,6 +4312,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
dduf_entries=dduf_entries,
)
# make sure token embedding weights are still tied if needed
@ -4400,6 +4495,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=None,
gguf_path=None,
weights_only=True,
dduf_entries=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
@ -4742,7 +4838,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
shard_file,
is_quantized=is_quantized,
map_location=map_location,
weights_only=weights_only,
dduf_entries=dduf_entries,
)
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not

View File

@ -291,6 +291,8 @@ class CLIPTokenizer(PreTrainedTokenizer):
pad_token="<|endoftext|>", # hack to enable padding
**kwargs,
):
dduf_entries = kwargs.get("dduf_entries", None)
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
@ -302,15 +304,21 @@ class CLIPTokenizer(PreTrainedTokenizer):
logger.info("ftfy or spacy is not installed using custom BasicTokenizer instead of ftfy.")
self.nlp = BasicTokenizer(strip_accents=False, do_split_on_punc=False)
self.fix_text = None
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
if dduf_entries:
self.encoder = json.loads(dduf_entries[vocab_file].read_text())
else:
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
if dduf_entries:
bpe_merges = dduf_entries[merges_file].read_text().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
else:
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}

View File

@ -137,6 +137,7 @@ class T5Tokenizer(PreTrainedTokenizer):
add_prefix_space=True,
**kwargs,
) -> None:
dduf_entries = kwargs.get("dduf_entries", None)
pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
@ -145,9 +146,12 @@ class T5Tokenizer(PreTrainedTokenizer):
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
if dduf_entries:
with dduf_entries[self.vocab_file].as_mmap() as mm:
self.sp_model.load_from_serialized_proto(mm)
else:
self.sp_model.Load(vocab_file)
if additional_special_tokens is not None:
extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
@ -181,7 +185,7 @@ class T5Tokenizer(PreTrainedTokenizer):
legacy = True
self.legacy = legacy
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False), dduf_entries=dduf_entries)
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.add_prefix_space = add_prefix_space
@ -199,21 +203,29 @@ class T5Tokenizer(PreTrainedTokenizer):
)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
def get_spm_processor(self, from_slow=False):
def get_spm_processor(self, from_slow=False, dduf_entries=None):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy or from_slow: # no dependency on protobuf
tokenizer.Load(self.vocab_file)
if dduf_entries:
with dduf_entries[self.vocab_file].as_mmap() as mm:
tokenizer.load_from_serialized_proto(mm)
else:
tokenizer.Load(self.vocab_file)
return tokenizer
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
if dduf_entries:
with dduf_entries[self.vocab_file].as_mmap() as mm:
tokenizer.load_from_serialized_proto(mm)
else:
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
@staticmethod

View File

@ -95,6 +95,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
add_prefix_space=None,
**kwargs,
):
self.dduf_entries = kwargs.get("dduf_entries", None)
# Add extra_ids to the special token list
if additional_special_tokens is not None:
extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
@ -132,7 +133,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
@property
def can_save_slow_tokenizer(self) -> bool:
return os.path.isfile(self.vocab_file) if self.vocab_file else False
# TODO: update this. Putting it to True for now
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
return True
@staticmethod
def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
@ -170,9 +173,15 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
logger.info(f"Copy vocab file to {out_vocab_file}")
# copyfile don't work with binary content e.g when we load file from an archive
elif not os.path.isfile(self.vocab_file):
with self.dduf_entries[self.vocab_file].as_mmap() as mm:
with open(out_vocab_file, "wb") as out_file:
out_file.write(mm)
logger.info(f"Copy vocab file to {out_vocab_file}")
return (out_vocab_file,)

View File

@ -1894,6 +1894,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
gguf_file = kwargs.get("gguf_file", None)
dduf_entries = kwargs.get("dduf_entries", None)
if use_auth_token is not None:
warnings.warn(
@ -1952,36 +1953,51 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
if dduf_entries:
tokenizer_config = json.loads(
dduf_entries[
os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
].read_text()
)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
else:
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(
tokenizer_config["fast_tokenizer_files"]
)
vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
if file_path is None:
if dduf_entries:
# We don't necessarily have the file
if os.path.join(pretrained_model_name_or_path, file_path) in dduf_entries:
resolved_vocab_files[file_id] = os.path.join(pretrained_model_name_or_path, file_path)
elif file_path is None:
resolved_vocab_files[file_id] = None
elif single_file_id == file_id:
if os.path.isfile(file_path):
@ -2007,6 +2023,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
# if dduf_entries:
# # preload files that are going to be opened in the tokenizer so that we don't need to modify each tokenizer with dduf ?
# for file in cls.vocab_files_names:
# # e.g vocab_file, merges ...
if len(unresolved_files) > 0:
logger.info(
@ -2066,6 +2086,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# file or if `from_slow` is set to True.
from_slow = kwargs.get("from_slow", False)
gguf_file = kwargs.get("gguf_file", None)
dduf_entries = kwargs.get("dduf_entries", None)
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
@ -2089,8 +2110,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Did we saved some inputs and kwargs to reload ?
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
init_kwargs = json.load(tokenizer_config_handle)
if dduf_entries:
init_kwargs = json.loads(dduf_entries[tokenizer_config_file].read_text())
else:
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
init_kwargs = json.load(tokenizer_config_handle)
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None)
@ -2106,8 +2130,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# If an independent chat template file exists, it takes priority over template entries in the tokenizer config
chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
if chat_template_file is not None:
with open(chat_template_file) as chat_template_handle:
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config
if dduf_entries:
# TODO: check that it works
init_kwargs["chat_template"] = dduf_entries[chat_template_file].read_text()
else:
with open(chat_template_file) as chat_template_handle:
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config
if not _is_local:
if "auto_map" in init_kwargs:
@ -2207,26 +2235,29 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else:
# begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
for key, value in special_tokens_map.items():
if key in kwargs and kwargs[key]:
# This value has already been redefined by the kwargs
# We keep this new value and ignore the one stored in the special_tokens_map_file
continue
if isinstance(value, dict):
value["special"] = True
value = AddedToken(**value)
elif key == "additional_special_tokens" and isinstance(value, list):
additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
for token in value:
if isinstance(token, dict):
token["special"] = True
token = AddedToken(**token)
if token not in additional_special_tokens:
additional_special_tokens.append(token)
value = additional_special_tokens
init_kwargs[key] = value
if dduf_entries:
special_tokens_map = json.loads(dduf_entries[special_tokens_map_file].read_text())
else:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
for key, value in special_tokens_map.items():
if key in kwargs and kwargs[key]:
# This value has already been redefined by the kwargs
# We keep this new value and ignore the one stored in the special_tokens_map_file
continue
if isinstance(value, dict):
value["special"] = True
value = AddedToken(**value)
elif key == "additional_special_tokens" and isinstance(value, list):
additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
for token in value:
if isinstance(token, dict):
token["special"] = True
token = AddedToken(**token)
if token not in additional_special_tokens:
additional_special_tokens.append(token)
value = additional_special_tokens
init_kwargs[key] = value
# slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
# this is for legacy purpose. We don't add the tokens after init for efficiency.
@ -2239,8 +2270,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else:
special_tokens.append(str(init_kwargs[key]))
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
if dduf_entries:
added_tok_encoder = json.loads(dduf_entries[added_tokens_file].read_text())
else:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
for str_token, index in added_tok_encoder.items():
# if index not in added_tokens_decoder and str_token not in added_tokens_map:
special = str_token in special_tokens
@ -2253,9 +2287,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# if `tokenizer_config.json` is `None`
if tokenizer_file is not None:
# This is for slow so can be done before
with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
tokenizer_file_handle = json.load(tokenizer_file_handle)
added_tokens = tokenizer_file_handle.pop("added_tokens")
if dduf_entries:
tokenizer_file_handle = json.loads(dduf_entries[tokenizer_file].read_text())
else:
with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
tokenizer_file_handle = json.load(tokenizer_file_handle)
added_tokens = tokenizer_file_handle.pop("added_tokens")
for serialized_tokens in added_tokens:
idx = serialized_tokens.pop("id")
added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
@ -2482,6 +2519,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
tokenizer_config.pop("tokenizer_file", None)
if "device_map" in tokenizer_config:
tokenizer_config.pop("device_map")
if "dduf_entries" in tokenizer_config:
tokenizer_config.pop("dduf_entries", None)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"

View File

@ -108,7 +108,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
"have sentencepiece installed."
)
if tokenizer_object is not None:
fast_tokenizer = copy.deepcopy(tokenizer_object)
elif fast_tokenizer_file is not None and not from_slow:
@ -130,7 +129,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
# We need to create and convert a slow tokenizer to build the backend
slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
fast_tokenizer = convert_slow_tokenizer(slow_tokenizer, **kwargs)
elif not slow_tokenizer:
# We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
self.vocab_file = kwargs.get("vocab_file", None)

View File

@ -1044,6 +1044,7 @@ def get_checkpoint_shard_files(
revision=None,
subfolder="",
_commit_hash=None,
dduf_entries=None,
**deprecated_kwargs,
):
"""
@ -1068,22 +1069,23 @@ def get_checkpoint_shard_files(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f:
index = json.loads(f.read())
if dduf_entries:
index = json.loads(dduf_entries[index_filename].read_text())
else:
if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values()))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
# First, let's deal with local folder and dduf
if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
# Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of