[repo utils] Update models_to_deprecate.py (#41231)

* update models_to_deprecate

* exclude this file

* handle typos and aliases

* don't commit files

* PR suggestions; make fixup
This commit is contained in:
Joao Gante
2025-10-01 13:01:52 +01:00
committed by GitHub
parent bcec3e2175
commit 1d1ac07893

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last commit.
Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last
commit.
"""
import argparse
@ -25,6 +26,9 @@ from pathlib import Path
from git import Repo
from huggingface_hub import HfApi
from tqdm import tqdm
from transformers.models.auto.configuration_auto import DEPRECATED_MODELS, MODEL_NAMES_MAPPING
api = HfApi()
@ -33,6 +37,97 @@ PATH_TO_REPO = Path(__file__).parent.parent.resolve()
repo = Repo(PATH_TO_REPO)
# Used when the folder name on the hub does not match the folder name in `transformers/models`
# format = {folder name in `transformers/models`: expected tag on the hub}
MODEL_FOLDER_NAME_TO_TAG_MAPPING = {
"audio_spectrogram_transformer": "audio-spectrogram-transformer",
"bert_generation": "bert-generation",
"blenderbot_small": "blenderbot-small",
"blip_2": "blip-2",
"dab_detr": "dab-detr",
"data2vec": "data2vec-audio", # actually, the base model is never used as a tag, but the sub models are
"deberta_v2": "deberta-v2",
"donut": "donut-swin",
"encoder_decoder": "encoder-decoder",
"grounding_dino": "grounding-dino",
"kosmos2": "kosmos-2",
"kosmos2_5": "kosmos-2.5",
"megatron_bert": "megatron-bert",
"mgp_str": "mgp-str",
"mm_grounding_dino": "mm-grounding-dino",
"modernbert_decoder": "modernbert-decoder",
"nllb_moe": "nllb-moe",
"omdet_turbo": "omdet-turbo",
"openai": "openai-gpt",
"roberta_prelayernorm": "roberta-prelayernorm",
"sew_d": "sew-d",
"speech_encoder_decoder": "speech-encoder-decoder",
"table_transformer": "table-transformer",
"unispeech_sat": "unispeech-sat",
"vision_encoder_decoder": "vision-encoder-decoder",
"vision_text_dual_encoder": "vision-text-dual-encoder",
"wav2vec2_bert": "wav2vec2-bert",
"wav2vec2_conformer": "wav2vec2-conformer",
"x_clip": "xclip",
"xlm_roberta": "xlm-roberta",
"xlm_roberta_xl": "xlm-roberta-xl",
}
# Used on model architectures with multiple tags on the hub (e.g. on VLMs, we often support a text-only model).
# Applied after the model folder name mapping. format = {base model tag: [extra tags]}
EXTRA_TAGS_MAPPING = {
"aimv2": ["aimv2_vision_model"],
"aria": ["aria_text"],
"bart": ["barthez", "bartpho"],
"bert": ["bert-japanese", "bertweet", "herbert", "phobert"],
"beit": ["dit"],
"blip-2": ["blip_2_qformer"],
"chinese_clip": ["chinese_clip_vision_model"],
"clip": ["clip_text_model", "clip_vision_model"],
"data2vec-audio": ["data2vec-text", "data2vec-vision"],
"depth_anything": ["depth_anything_v2"],
"donut-swin": ["nougat"],
"edgetam": ["edgetam_vision_model"],
"fastspeech2_conformer": ["fastspeech2_conformer_with_hifigan"],
"gemma3": ["gemma3_text"],
"gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"],
"gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"],
"glm4v_moe": ["glm4v_moe_text"],
"glm4v": ["glm4v_text"],
"idefics3": ["idefics3_vision"],
"internvl": ["internvl_vision"],
"layoutlmv2": ["layoutxlm"],
"llama": ["code_llama", "falcon3", "llama2", "llama3"],
"llama4": ["llama4_text"],
"llava_next": ["granitevision"],
"luke": ["mluke"],
"m2m_100": ["nllb"],
"maskformer": ["maskformer-swin"],
"mbart": ["mbart50"],
"parakeet": ["parakeet_ctc", "parakeet_encoder"],
"perception_lm": ["perception_encoder"],
"pix2struct": ["deplot", "matcha"],
"qwen2_5_vl": ["qwen2_5_vl_text"],
"qwen2_audio": ["qwen2_audio_encoder"],
"qwen2_vl": ["qwen2_vl_text"],
"qwen3_vl_moe": ["qwen3_vl_moe_text"],
"qwen3_vl": ["qwen3_vl_text"],
"rt_detr": ["rt_detr_resnet"],
"sam2": ["sam2_hiera_det_model", "sam2_vision_model"],
"sam": ["sam_hq_vision_model", "sam_vision_model"],
"siglip2": ["siglip2_vision_model"],
"siglip": ["siglip_vision_model"],
"smolvlm": ["smolvlm_vision"],
"t5": ["byt5", "flan-t5", "flan-ul2", "madlad-400", "myt5", "t5v1.1", "ul2"],
"voxtral": ["voxtral_encoder"],
"wav2vec2": ["mms", "wav2vec2_phoneme", "xls_r", "xlsr_wav2vec2"],
"xlm-roberta": ["xlm-v"],
}
# Similar to `DEPRECATED_MODELS`, but containing the tags when the model tag does not match the model folder name :'(
DEPRECATED_MODELS_TAGS = {"gptsan-japanese", "open-llama", "transfo-xl", "xlm-prophetnet"}
class HubModelLister:
"""
Utility for getting models from the hub based on tags. Handles errors without crashing the script.
@ -40,7 +135,7 @@ class HubModelLister:
def __init__(self, tags):
self.tags = tags
self.model_list = api.list_models(tags=tags)
self.model_list = api.list_models(filter=tags)
def __iter__(self):
try:
@ -97,9 +192,11 @@ def get_list_of_models_to_deprecate(
info["first_commit_datetime"] = datetime.fromisoformat(info["first_commit_datetime"])
else:
# Build a dictionary of model info: first commit datetime, commit hash, model path
print("Building a dictionary of basic model info...")
models_info = defaultdict(dict)
for model_path in model_paths:
for i, model_path in enumerate(tqdm(sorted(model_paths))):
if max_num_models != -1 and i > max_num_models:
break
model = model_path.split("/")[-2]
if model in models_info:
continue
@ -111,12 +208,41 @@ def get_list_of_models_to_deprecate(
models_info[model]["first_commit_datetime"] = committed_datetime
models_info[model]["model_path"] = model_path
models_info[model]["downloads"] = 0
models_info[model]["tags"] = [model]
# Some tags on the hub are formatted differently than in the library
tags = [model]
if "_" in model:
tags.append(model.replace("_", "-"))
models_info[model]["tags"] = tags
# The keys in the dictionary above are the model folder names. In some cases, the model tag on the hub does not
# match the model folder name. We replace the key and append the expected tag.
for folder_name, expected_tag in MODEL_FOLDER_NAME_TO_TAG_MAPPING.items():
if folder_name in models_info:
models_info[expected_tag] = models_info[folder_name]
models_info[expected_tag]["tags"] = [expected_tag]
del models_info[folder_name]
# Some models have multiple tags on the hub. We add the expected tag to the list of tags.
for model_name, extra_tags in EXTRA_TAGS_MAPPING.items():
if model_name in models_info:
models_info[model_name]["tags"].extend(extra_tags)
# Sanity check for the case with all models: the model tags must match the keys in the MODEL_NAMES_MAPPING
# (= actual model tags on the hub)
if max_num_models == -1:
all_model_tags = set()
for model_name in models_info:
all_model_tags.update(models_info[model_name]["tags"])
non_deprecated_model_tags = (
set(MODEL_NAMES_MAPPING.keys()) - set(DEPRECATED_MODELS_TAGS) - set(DEPRECATED_MODELS)
)
if all_model_tags != non_deprecated_model_tags:
raise ValueError(
"The tags of the `models_info` dictionary must match the keys in the `MODEL_NAMES_MAPPING`!"
"\nMissing tags in `model_info`: "
+ str(sorted(non_deprecated_model_tags - all_model_tags))
+ "\nExtra tags in `model_info`: "
+ str(sorted(all_model_tags - non_deprecated_model_tags))
+ "\n\nYou need to update one or more of the following: `MODEL_NAMES_MAPPING`, "
"`EXTRA_TAGS_MAPPING` or `DEPRECATED_MODELS_TAGS`."
)
# Filter out models which were added less than a year ago
models_info = {
@ -124,19 +250,21 @@ def get_list_of_models_to_deprecate(
}
# We make successive calls to the hub, filtering based on the model tags
n_seen = 0
for model, model_info in models_info.items():
print("Making calls to the hub to find models below the threshold number of downloads...")
num_models = len(models_info)
for i, (model, model_info) in enumerate(models_info.items()):
print(f"{i + 1}/{num_models}: getting hub downloads for model='{model}' (tags={model_info['tags']})")
for model_tag in model_info["tags"]:
if model_info["downloads"] > thresh_num_downloads:
break
model_list = HubModelLister(tags=model_tag)
for i, hub_model in enumerate(model_list):
n_seen += 1
if i % 100 == 0:
print(f"Processing model {i} for tag {model_tag}")
if max_num_models != -1 and i > n_seen:
break
for hub_model in model_list:
if hub_model.private:
continue
model_info["downloads"] += hub_model.downloads
# No need to make further hub calls, it's above the set threshold
if model_info["downloads"] > thresh_num_downloads:
break
if save_model_info and not (use_cache and os.path.exists("models_info.json")):
# Make datetimes serializable
@ -156,7 +284,11 @@ def get_list_of_models_to_deprecate(
print(f"\nModel: {model}")
print(f"Downloads: {n_downloads}")
print(f"Date: {info['first_commit_datetime']}")
print("\nModels to deprecate: ", "\n" + "\n".join(models_to_deprecate.keys()))
# sort models to deprecate by downloads (lowest downloads first)
models_to_deprecate = sorted(models_to_deprecate.items(), key=lambda x: x[1]["downloads"])
print("\nModels to deprecate: ", "\n" + "\n".join([model[0] for model in models_to_deprecate]))
print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")
@ -171,19 +303,25 @@ if __name__ == "__main__":
"--thresh_num_downloads",
type=int,
default=5_000,
help="Threshold number of downloads below which a model should be deprecated. Default is 5,000.",
help=(
"Threshold number of downloads below which a model should be deprecated. Default is 5,000. If you are "
"considering a sweep and using a cache, set this to the highest number of the sweep."
),
)
parser.add_argument(
"--thresh_date",
type=str,
default=None,
help="Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from today.",
help=(
"Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from "
"today."
),
)
parser.add_argument(
"--max_num_models",
type=int,
default=-1,
help="Maximum number of models to consider from the hub. -1 means all models. Useful for testing.",
help="Maximum number of models architectures to consider. -1 means all models. Useful for testing.",
)
args = parser.parse_args()