From 3016717f0d904910fd9a55b3708e76024e71fe72 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 1 Oct 2025 21:13:04 +0800 Subject: [PATCH] Use removeprefix and removesuffix (#41240) * Use removeprefix and removesuffix Signed-off-by: Yuanyuan Chen * More fixes Signed-off-by: Yuanyuan Chen --------- Signed-off-by: Yuanyuan Chen --- src/transformers/commands/chat.py | 3 +-- src/transformers/commands/serving.py | 6 ++---- src/transformers/dynamic_module_utils.py | 3 +-- src/transformers/modelcard.py | 3 +-- src/transformers/modeling_utils.py | 4 ++-- src/transformers/models/auto/image_processing_auto.py | 4 +--- .../convert_chinese_clip_original_pytorch_to_hf.py | 2 +- src/transformers/models/longt5/modeling_longt5.py | 3 +-- src/transformers/models/rag/retrieval_rag.py | 5 +---- src/transformers/utils/auto_docstring.py | 3 +-- utils/check_config_docstrings.py | 3 +-- utils/check_copies.py | 3 +-- 12 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 37f606d00dd..bc421f7d84d 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -443,8 +443,7 @@ class ChatCommand(BaseTransformersCLICommand): # 2. b. strings should be quoted def is_number(s: str) -> bool: # handle negative numbers - if s.startswith("-"): - s = s[1:] + s = s.removeprefix("-") return s.replace(".", "", 1).isdigit() generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()} diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index ee9009090f9..970d59c96e7 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -1066,8 +1066,7 @@ class ServeCommand(BaseTransformersCLICommand): for result in streamer: # Temporary hack for GPTOS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): - if result.endswith("<|return|>"): - result = result[: -len("<|return|>")] + result = result.removesuffix("<|return|>") results += result # (related to temporary hack 2) @@ -1325,8 +1324,7 @@ class ServeCommand(BaseTransformersCLICommand): for result in streamer: # Temporary hack for GPTOS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): - if result.endswith("<|return|>"): - result = result[: -len("<|return|>")] + result = result.removesuffix("<|return|>") results += result # (related to temporary hack 2) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 5b541c076f6..6e31557c39a 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -285,8 +285,7 @@ def get_class_in_module( `typing.Type`: The class looked for. 
""" name = os.path.normpath(module_path) - if name.endswith(".py"): - name = name[:-3] + name = name.removesuffix(".py") name = name.replace(os.path.sep, ".") module_file: Path = Path(HF_MODULES_CACHE) / module_path with _HF_REMOTE_CODE_LOCK: diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 2a1c77b1282..cbd148153ca 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -667,8 +667,7 @@ def parse_log_history(log_history): if idx > 0: eval_results = {} for key, value in log_history[idx].items(): - if key.startswith("eval_"): - key = key[5:] + key = key.removeprefix("eval_") if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]: camel_cased_key = " ".join([part.capitalize() for part in key.split("_")]) eval_results[camel_cased_key] = value diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1c57072a0c7..4175b349ea0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5397,7 +5397,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH for name, module in self.named_modules(): if remove_prefix: _prefix = f"{self.base_model_prefix}." - name = name[len(_prefix) :] if name.startswith(_prefix) else name + name = name.removeprefix(_prefix) elif add_prefix: name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix @@ -5717,7 +5717,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH # in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model if loading_task_model_from_base_state_dict: _prefix = f"{self.base_model_prefix}." - unexpected_keys = [k[len(_prefix) :] if k.startswith(_prefix) else k for k in unexpected_keys] + unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys] return missing_keys, unexpected_keys diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index a272735af20..4b71712dfc7 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -566,9 +566,7 @@ class AutoImageProcessor: ) image_processor_class = get_image_processor_class_from_name(image_processor_type) else: - image_processor_type_slow = ( - image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type - ) + image_processor_type_slow = image_processor_type.removesuffix("Fast") image_processor_class = get_image_processor_class_from_name(image_processor_type_slow) if image_processor_class is None and image_processor_type.endswith("Fast"): raise ValueError( diff --git a/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py b/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py index adc9300ef51..2ec838a7da6 100644 --- a/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py +++ b/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py @@ -105,7 +105,7 @@ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, c hf_model = ChineseCLIPModel(config).eval() pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"] - pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()} + pt_weights = 
{(name.removeprefix("module.")): value for name, value in pt_weights.items()} copy_text_model_and_projection(hf_model, pt_weights) copy_vision_model_and_projection(hf_model, pt_weights) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 3361e7aafe8..bfde37f2c86 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1243,8 +1243,7 @@ class LongT5PreTrainedModel(PreTrainedModel): def _try_load_missing_tied_module(self, key): module = self - if key.endswith(".weight"): - key = key[: -len(".weight")] + key = key.removesuffix(".weight") for sub_key in key.split("."): if not hasattr(module, sub_key): return diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py index 6fb924c8b7b..49c782e5b3f 100644 --- a/src/transformers/models/rag/retrieval_rag.py +++ b/src/transformers/models/rag/retrieval_rag.py @@ -509,10 +509,7 @@ class RagRetriever: def cat_input_and_doc(doc_title, doc_text, input_string, prefix): # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation # TODO(piktus): better handling of truncation - if doc_title.startswith('"'): - doc_title = doc_title[1:] - if doc_title.endswith('"'): - doc_title = doc_title[:-1] + doc_title = doc_title.removeprefix('"').removesuffix('"') if prefix is None: prefix = "" out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace( diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 9bf44c8bb42..15882c6b63c 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1185,8 +1185,7 @@ def get_checkpoint_from_config_class(config_class): # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')` for ckpt_name, ckpt_link in checkpoints: # allow the link to end with `/` - if ckpt_link.endswith("/"): - ckpt_link = ckpt_link[:-1] + ckpt_link = ckpt_link.removesuffix("/") # verify the checkpoint name corresponds to the checkpoint link ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}" diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index 4a5d0395a20..d344bf42601 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -64,8 +64,7 @@ def get_checkpoint_from_config_class(config_class): # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')` for ckpt_name, ckpt_link in checkpoints: # allow the link to end with `/` - if ckpt_link.endswith("/"): - ckpt_link = ckpt_link[:-1] + ckpt_link = ckpt_link.removesuffix("/") # verify the checkpoint name corresponds to the checkpoint link ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}" diff --git a/utils/check_copies.py b/utils/check_copies.py index 2bb00776af9..28b743beab5 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -797,8 +797,7 @@ def is_copy_consistent( orig_idx = -1 observed_code = "" for name, code in observed_code_blocks.items(): - if code.endswith("\n"): - code = code[:-1] + code = code.removesuffix("\n") for code_line in code.split("\n"): orig_idx += 1 if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):