Compare commits

...

3 Commits

Author SHA1 Message Date
02649eaa4c fix 2023-06-21 13:53:55 +02:00
706461ef5a fix 2023-06-21 13:41:45 +02:00
4c57171c82 fix 2023-06-21 13:30:37 +02:00

View File

@ -711,6 +711,29 @@ def check_all_auto_mapping_names_in_config_mapping_names():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def check_all_auto_mappings_importable():
"""Check all auto mappings could be imported."""
check_missing_backends()
failures = []
mappings_to_check = {}
# Each auto modeling files contains multiple mappings. Let's get them in a dynamic way.
for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]:
module = getattr(transformers.models.auto, module_name, None)
if module is None:
continue
# all mappings in a single auto modeling file
mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, _ in mappings_to_check.items():
name = name.replace("_MAPPING_NAMES", "_MAPPING")
if not hasattr(transformers, name):
failures.append(f"`{name}` should be defined in the main `__init__` file.")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
@ -993,6 +1016,8 @@ def check_repo_quality():
check_all_auto_object_names_being_defined()
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
check_all_auto_mapping_names_in_config_mapping_names()
print("Checking all auto mappings could be imported.")
check_all_auto_mappings_importable()
if __name__ == "__main__":