Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Use removeprefix and removesuffix (#41240)
* Use removeprefix and removesuffix

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* More fixes

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

---------

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
This commit is contained in:
@@ -443,8 +443,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
# 2. b. strings should be quoted
|
||||
def is_number(s: str) -> bool:
|
||||
# handle negative numbers
|
||||
if s.startswith("-"):
|
||||
s = s[1:]
|
||||
s = s.removeprefix("-")
|
||||
return s.replace(".", "", 1).isdigit()
|
||||
|
||||
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
|
||||
|
@@ -1066,8 +1066,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
for result in streamer:
|
||||
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
if result.endswith("<|return|>"):
|
||||
result = result[: -len("<|return|>")]
|
||||
result = result.removesuffix("<|return|>")
|
||||
results += result
|
||||
|
||||
# (related to temporary hack 2)
|
||||
@@ -1325,8 +1324,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
for result in streamer:
|
||||
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
if result.endswith("<|return|>"):
|
||||
result = result[: -len("<|return|>")]
|
||||
result = result.removesuffix("<|return|>")
|
||||
results += result
|
||||
|
||||
# (related to temporary hack 2)
|
||||
|
@@ -285,8 +285,7 @@ def get_class_in_module(
|
||||
`typing.Type`: The class looked for.
|
||||
"""
|
||||
name = os.path.normpath(module_path)
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.removesuffix(".py")
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||
with _HF_REMOTE_CODE_LOCK:
|
||||
|
@@ -667,8 +667,7 @@ def parse_log_history(log_history):
|
||||
if idx > 0:
|
||||
eval_results = {}
|
||||
for key, value in log_history[idx].items():
|
||||
if key.startswith("eval_"):
|
||||
key = key[5:]
|
||||
key = key.removeprefix("eval_")
|
||||
if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
|
||||
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
|
||||
eval_results[camel_cased_key] = value
|
||||
|
@@ -5397,7 +5397,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
for name, module in self.named_modules():
|
||||
if remove_prefix:
|
||||
_prefix = f"{self.base_model_prefix}."
|
||||
name = name[len(_prefix) :] if name.startswith(_prefix) else name
|
||||
name = name.removeprefix(_prefix)
|
||||
elif add_prefix:
|
||||
name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
|
||||
|
||||
@@ -5717,7 +5717,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model
|
||||
if loading_task_model_from_base_state_dict:
|
||||
_prefix = f"{self.base_model_prefix}."
|
||||
unexpected_keys = [k[len(_prefix) :] if k.startswith(_prefix) else k for k in unexpected_keys]
|
||||
unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys]
|
||||
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
|
@@ -566,9 +566,7 @@ class AutoImageProcessor:
|
||||
)
|
||||
image_processor_class = get_image_processor_class_from_name(image_processor_type)
|
||||
else:
|
||||
image_processor_type_slow = (
|
||||
image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type
|
||||
)
|
||||
image_processor_type_slow = image_processor_type.removesuffix("Fast")
|
||||
image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
|
||||
if image_processor_class is None and image_processor_type.endswith("Fast"):
|
||||
raise ValueError(
|
||||
|
@@ -105,7 +105,7 @@ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, c
|
||||
hf_model = ChineseCLIPModel(config).eval()
|
||||
|
||||
pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
|
||||
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
|
||||
pt_weights = {(name.removeprefix("module.")): value for name, value in pt_weights.items()}
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_weights)
|
||||
copy_vision_model_and_projection(hf_model, pt_weights)
|
||||
|
@@ -1243,8 +1243,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _try_load_missing_tied_module(self, key):
|
||||
module = self
|
||||
if key.endswith(".weight"):
|
||||
key = key[: -len(".weight")]
|
||||
key = key.removesuffix(".weight")
|
||||
for sub_key in key.split("."):
|
||||
if not hasattr(module, sub_key):
|
||||
return
|
||||
|
@@ -509,10 +509,7 @@ class RagRetriever:
|
||||
def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
|
||||
# TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
|
||||
# TODO(piktus): better handling of truncation
|
||||
if doc_title.startswith('"'):
|
||||
doc_title = doc_title[1:]
|
||||
if doc_title.endswith('"'):
|
||||
doc_title = doc_title[:-1]
|
||||
doc_title = doc_title.removeprefix('"').removesuffix('"')
|
||||
if prefix is None:
|
||||
prefix = ""
|
||||
out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
|
||||
|
@@ -1185,8 +1185,7 @@ def get_checkpoint_from_config_class(config_class):
|
||||
# For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
|
||||
for ckpt_name, ckpt_link in checkpoints:
|
||||
# allow the link to end with `/`
|
||||
if ckpt_link.endswith("/"):
|
||||
ckpt_link = ckpt_link[:-1]
|
||||
ckpt_link = ckpt_link.removesuffix("/")
|
||||
|
||||
# verify the checkpoint name corresponds to the checkpoint link
|
||||
ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
|
||||
|
@@ -64,8 +64,7 @@ def get_checkpoint_from_config_class(config_class):
|
||||
# For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
|
||||
for ckpt_name, ckpt_link in checkpoints:
|
||||
# allow the link to end with `/`
|
||||
if ckpt_link.endswith("/"):
|
||||
ckpt_link = ckpt_link[:-1]
|
||||
ckpt_link = ckpt_link.removesuffix("/")
|
||||
|
||||
# verify the checkpoint name corresponds to the checkpoint link
|
||||
ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
|
||||
|
@@ -797,8 +797,7 @@ def is_copy_consistent(
|
||||
orig_idx = -1
|
||||
observed_code = ""
|
||||
for name, code in observed_code_blocks.items():
|
||||
if code.endswith("\n"):
|
||||
code = code[:-1]
|
||||
code = code.removesuffix("\n")
|
||||
for code_line in code.split("\n"):
|
||||
orig_idx += 1
|
||||
if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):
|
||||
|
Reference in New Issue
Block a user