Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
[repo utils] Update models_to_deprecate.py (#41231)
* update models_to_deprecate
* exclude this file
* handle typos and aliases
* don't commit files
* PR suggestions; make fixup
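For orientation: the diff below drops the old "append the `-` variant of any folder name containing `_`" heuristic and introduces two explicit tables instead. `MODEL_FOLDER_NAME_TO_TAG_MAPPING` maps a folder name under `transformers/models` to the tag actually used on the hub, and `EXTRA_TAGS_MAPPING` lists alias tags whose downloads should be counted together with the base tag. A minimal sketch of how the two combine; the helper name and the trimmed-down dictionaries are illustrative only, not part of the commit:

# Sketch only: the real script applies the same idea in place on its `models_info` dict.

# Folder names whose hub tag is spelled differently (subset of MODEL_FOLDER_NAME_TO_TAG_MAPPING).
FOLDER_TO_TAG = {
    "blip_2": "blip-2",
    "kosmos2_5": "kosmos-2.5",
    "x_clip": "xclip",
}

# Alias tags counted together with the base tag (subset of EXTRA_TAGS_MAPPING).
EXTRA_TAGS = {
    "blip-2": ["blip_2_qformer"],
    "t5": ["byt5", "flan-t5", "flan-ul2", "madlad-400", "myt5", "t5v1.1", "ul2"],
}


def hub_tags_for_folder(folder_name: str) -> list[str]:
    """Return the hub tags whose downloads should be summed for a model folder."""
    base_tag = FOLDER_TO_TAG.get(folder_name, folder_name)
    return [base_tag] + EXTRA_TAGS.get(base_tag, [])


print(hub_tags_for_folder("blip_2"))  # ['blip-2', 'blip_2_qformer']
print(hub_tags_for_folder("t5"))  # ['t5', 'byt5', 'flan-t5', 'flan-ul2', 'madlad-400', 'myt5', 't5v1.1', 'ul2']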
utils/models_to_deprecate.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last commit.
+Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last
+commit.
 """
 
 import argparse
@@ -25,6 +26,9 @@ from pathlib import Path
 
 from git import Repo
 from huggingface_hub import HfApi
+from tqdm import tqdm
+
+from transformers.models.auto.configuration_auto import DEPRECATED_MODELS, MODEL_NAMES_MAPPING
 
 
 api = HfApi()
@@ -33,6 +37,97 @@ PATH_TO_REPO = Path(__file__).parent.parent.resolve()
 repo = Repo(PATH_TO_REPO)
 
 
+# Used when the folder name on the hub does not match the folder name in `transformers/models`
+# format = {folder name in `transformers/models`: expected tag on the hub}
+MODEL_FOLDER_NAME_TO_TAG_MAPPING = {
+    "audio_spectrogram_transformer": "audio-spectrogram-transformer",
+    "bert_generation": "bert-generation",
+    "blenderbot_small": "blenderbot-small",
+    "blip_2": "blip-2",
+    "dab_detr": "dab-detr",
+    "data2vec": "data2vec-audio",  # actually, the base model is never used as a tag, but the sub models are
+    "deberta_v2": "deberta-v2",
+    "donut": "donut-swin",
+    "encoder_decoder": "encoder-decoder",
+    "grounding_dino": "grounding-dino",
+    "kosmos2": "kosmos-2",
+    "kosmos2_5": "kosmos-2.5",
+    "megatron_bert": "megatron-bert",
+    "mgp_str": "mgp-str",
+    "mm_grounding_dino": "mm-grounding-dino",
+    "modernbert_decoder": "modernbert-decoder",
+    "nllb_moe": "nllb-moe",
+    "omdet_turbo": "omdet-turbo",
+    "openai": "openai-gpt",
+    "roberta_prelayernorm": "roberta-prelayernorm",
+    "sew_d": "sew-d",
+    "speech_encoder_decoder": "speech-encoder-decoder",
+    "table_transformer": "table-transformer",
+    "unispeech_sat": "unispeech-sat",
+    "vision_encoder_decoder": "vision-encoder-decoder",
+    "vision_text_dual_encoder": "vision-text-dual-encoder",
+    "wav2vec2_bert": "wav2vec2-bert",
+    "wav2vec2_conformer": "wav2vec2-conformer",
+    "x_clip": "xclip",
+    "xlm_roberta": "xlm-roberta",
+    "xlm_roberta_xl": "xlm-roberta-xl",
+}
+
+# Used on model architectures with multiple tags on the hub (e.g. on VLMs, we often support a text-only model).
+# Applied after the model folder name mapping. format = {base model tag: [extra tags]}
+EXTRA_TAGS_MAPPING = {
+    "aimv2": ["aimv2_vision_model"],
+    "aria": ["aria_text"],
+    "bart": ["barthez", "bartpho"],
+    "bert": ["bert-japanese", "bertweet", "herbert", "phobert"],
+    "beit": ["dit"],
+    "blip-2": ["blip_2_qformer"],
+    "chinese_clip": ["chinese_clip_vision_model"],
+    "clip": ["clip_text_model", "clip_vision_model"],
+    "data2vec-audio": ["data2vec-text", "data2vec-vision"],
+    "depth_anything": ["depth_anything_v2"],
+    "donut-swin": ["nougat"],
+    "edgetam": ["edgetam_vision_model"],
+    "fastspeech2_conformer": ["fastspeech2_conformer_with_hifigan"],
+    "gemma3": ["gemma3_text"],
+    "gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"],
+    "gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"],
+    "glm4v_moe": ["glm4v_moe_text"],
+    "glm4v": ["glm4v_text"],
+    "idefics3": ["idefics3_vision"],
+    "internvl": ["internvl_vision"],
+    "layoutlmv2": ["layoutxlm"],
+    "llama": ["code_llama", "falcon3", "llama2", "llama3"],
+    "llama4": ["llama4_text"],
+    "llava_next": ["granitevision"],
+    "luke": ["mluke"],
+    "m2m_100": ["nllb"],
+    "maskformer": ["maskformer-swin"],
+    "mbart": ["mbart50"],
+    "parakeet": ["parakeet_ctc", "parakeet_encoder"],
+    "perception_lm": ["perception_encoder"],
+    "pix2struct": ["deplot", "matcha"],
+    "qwen2_5_vl": ["qwen2_5_vl_text"],
+    "qwen2_audio": ["qwen2_audio_encoder"],
+    "qwen2_vl": ["qwen2_vl_text"],
+    "qwen3_vl_moe": ["qwen3_vl_moe_text"],
+    "qwen3_vl": ["qwen3_vl_text"],
+    "rt_detr": ["rt_detr_resnet"],
+    "sam2": ["sam2_hiera_det_model", "sam2_vision_model"],
+    "sam": ["sam_hq_vision_model", "sam_vision_model"],
+    "siglip2": ["siglip2_vision_model"],
+    "siglip": ["siglip_vision_model"],
+    "smolvlm": ["smolvlm_vision"],
+    "t5": ["byt5", "flan-t5", "flan-ul2", "madlad-400", "myt5", "t5v1.1", "ul2"],
+    "voxtral": ["voxtral_encoder"],
+    "wav2vec2": ["mms", "wav2vec2_phoneme", "xls_r", "xlsr_wav2vec2"],
+    "xlm-roberta": ["xlm-v"],
+}
+
+# Similar to `DEPRECATED_MODELS`, but containing the tags when the model tag does not match the model folder name :'(
+DEPRECATED_MODELS_TAGS = {"gptsan-japanese", "open-llama", "transfo-xl", "xlm-prophetnet"}
+
+
 class HubModelLister:
     """
     Utility for getting models from the hub based on tags. Handles errors without crashing the script.
@@ -40,7 +135,7 @@ class HubModelLister:
 
     def __init__(self, tags):
         self.tags = tags
-        self.model_list = api.list_models(tags=tags)
+        self.model_list = api.list_models(filter=tags)
 
     def __iter__(self):
         try:
@@ -97,9 +192,11 @@ def get_list_of_models_to_deprecate(
             info["first_commit_datetime"] = datetime.fromisoformat(info["first_commit_datetime"])
-
     else:
         # Build a dictionary of model info: first commit datetime, commit hash, model path
+        print("Building a dictionary of basic model info...")
         models_info = defaultdict(dict)
-        for model_path in model_paths:
+        for i, model_path in enumerate(tqdm(sorted(model_paths))):
+            if max_num_models != -1 and i > max_num_models:
+                break
             model = model_path.split("/")[-2]
             if model in models_info:
                 continue
@@ -111,12 +208,41 @@ def get_list_of_models_to_deprecate(
             models_info[model]["first_commit_datetime"] = committed_datetime
             models_info[model]["model_path"] = model_path
             models_info[model]["downloads"] = 0
+            models_info[model]["tags"] = [model]
 
-            # Some tags on the hub are formatted differently than in the library
-            tags = [model]
-            if "_" in model:
-                tags.append(model.replace("_", "-"))
-            models_info[model]["tags"] = tags
+        # The keys in the dictionary above are the model folder names. In some cases, the model tag on the hub does not
+        # match the model folder name. We replace the key and append the expected tag.
+        for folder_name, expected_tag in MODEL_FOLDER_NAME_TO_TAG_MAPPING.items():
+            if folder_name in models_info:
+                models_info[expected_tag] = models_info[folder_name]
+                models_info[expected_tag]["tags"] = [expected_tag]
+                del models_info[folder_name]
+
+        # Some models have multiple tags on the hub. We add the expected tag to the list of tags.
+        for model_name, extra_tags in EXTRA_TAGS_MAPPING.items():
+            if model_name in models_info:
+                models_info[model_name]["tags"].extend(extra_tags)
+
+        # Sanity check for the case with all models: the model tags must match the keys in the MODEL_NAMES_MAPPING
+        # (= actual model tags on the hub)
+        if max_num_models == -1:
+            all_model_tags = set()
+            for model_name in models_info:
+                all_model_tags.update(models_info[model_name]["tags"])
+
+            non_deprecated_model_tags = (
+                set(MODEL_NAMES_MAPPING.keys()) - set(DEPRECATED_MODELS_TAGS) - set(DEPRECATED_MODELS)
+            )
+            if all_model_tags != non_deprecated_model_tags:
+                raise ValueError(
+                    "The tags of the `models_info` dictionary must match the keys in the `MODEL_NAMES_MAPPING`!"
+                    "\nMissing tags in `model_info`: "
+                    + str(sorted(non_deprecated_model_tags - all_model_tags))
+                    + "\nExtra tags in `model_info`: "
+                    + str(sorted(all_model_tags - non_deprecated_model_tags))
+                    + "\n\nYou need to update one or more of the following: `MODEL_NAMES_MAPPING`, "
+                    "`EXTRA_TAGS_MAPPING` or `DEPRECATED_MODELS_TAGS`."
+                )
 
     # Filter out models which were added less than a year ago
     models_info = {
@@ -124,19 +250,21 @@ def get_list_of_models_to_deprecate(
     }
 
     # We make successive calls to the hub, filtering based on the model tags
-    n_seen = 0
-    for model, model_info in models_info.items():
+    print("Making calls to the hub to find models below the threshold number of downloads...")
+    num_models = len(models_info)
+    for i, (model, model_info) in enumerate(models_info.items()):
+        print(f"{i + 1}/{num_models}: getting hub downloads for model='{model}' (tags={model_info['tags']})")
         for model_tag in model_info["tags"]:
+            if model_info["downloads"] > thresh_num_downloads:
+                break
             model_list = HubModelLister(tags=model_tag)
-            for i, hub_model in enumerate(model_list):
-                n_seen += 1
-                if i % 100 == 0:
-                    print(f"Processing model {i} for tag {model_tag}")
-                if max_num_models != -1 and i > n_seen:
-                    break
+            for hub_model in model_list:
                 if hub_model.private:
                     continue
                 model_info["downloads"] += hub_model.downloads
+                # No need to make further hub calls, it's above the set threshold
+                if model_info["downloads"] > thresh_num_downloads:
+                    break
 
     if save_model_info and not (use_cache and os.path.exists("models_info.json")):
         # Make datetimes serializable
@@ -156,7 +284,11 @@ def get_list_of_models_to_deprecate(
         print(f"\nModel: {model}")
         print(f"Downloads: {n_downloads}")
         print(f"Date: {info['first_commit_datetime']}")
-    print("\nModels to deprecate: ", "\n" + "\n".join(models_to_deprecate.keys()))
+
+    # sort models to deprecate by downloads (lowest downloads first)
+    models_to_deprecate = sorted(models_to_deprecate.items(), key=lambda x: x[1]["downloads"])
+
+    print("\nModels to deprecate: ", "\n" + "\n".join([model[0] for model in models_to_deprecate]))
     print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
     print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")
 
@@ -171,19 +303,25 @@ if __name__ == "__main__":
         "--thresh_num_downloads",
         type=int,
         default=5_000,
-        help="Threshold number of downloads below which a model should be deprecated. Default is 5,000.",
+        help=(
+            "Threshold number of downloads below which a model should be deprecated. Default is 5,000. If you are "
+            "considering a sweep and using a cache, set this to the highest number of the sweep."
+        ),
     )
     parser.add_argument(
         "--thresh_date",
         type=str,
        default=None,
-        help="Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from today.",
+        help=(
+            "Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from "
+            "today."
+        ),
    )
     parser.add_argument(
         "--max_num_models",
         type=int,
         default=-1,
-        help="Maximum number of models to consider from the hub. -1 means all models. Useful for testing.",
+        help="Maximum number of models architectures to consider. -1 means all models. Useful for testing.",
     )
     args = parser.parse_args()
 
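Two smaller behavioural notes in the diff: `HubModelLister` now calls `api.list_models(filter=tags)` instead of `api.list_models(tags=tags)`, and the download-counting loop stops querying the hub for a model as soon as its running total passes the threshold. A self-contained sketch of that early-exit pattern for a single tag; the function name and the default threshold are illustrative, not part of the script:

# Sketch only: counts public downloads for one hub tag and stops early once the
# deprecation threshold is exceeded, mirroring the loop added in this commit.
from huggingface_hub import HfApi

api = HfApi()


def tag_is_above_threshold(tag: str, thresh_num_downloads: int = 5_000) -> bool:
    """Return True as soon as the models carrying `tag` exceed the download threshold."""
    total = 0
    for hub_model in api.list_models(filter=tag):  # `filter=` is the argument the diff switches to
        if hub_model.private:
            continue
        total += hub_model.downloads or 0
        if total > thresh_num_downloads:
            return True  # no need to page through the remaining models for this tag
    return False


if __name__ == "__main__":
    # Hits the Hugging Face Hub API; the real script aggregates this over all tags of a model.
    print(tag_is_above_threshold("blip-2"))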