Compare commits

...

1 Commits

2 changed files with 78 additions and 31 deletions

View File

@ -1,7 +1,45 @@
import ast
import os
from collections import defaultdict
# Filename prefixes recognized by `get_model_name_from_filename`; whatever follows the
# prefix is taken as the model name.
FILE_TYPE_PREFIXES = (
    "modular_",
    "modeling_",
    "configuration_",
    "tokenization_",
    "processing_",
    "image_processing_",
    "feature_extractor_",
)


def get_model_name_from_filename(filename: str) -> str:
    """Extract the model name from the name of a model file.

    `filename` may be a bare filename, a filesystem path, or a dotted import path, with or
    without the trailing `.py`. Only the last component is inspected.

    Raises:
        ValueError: if the last component does not start with one of `FILE_TYPE_PREFIXES`.
    """
    # Drop the extension first, so that the later split on `.` only sees import-style separators
    stem = filename[:-3] if filename.endswith(".py") else filename
    # Keep the last path component, then the last dotted component (import form)
    stem = os.path.basename(stem).split(".")[-1]

    # First prefix (in declaration order) that the stem starts with, if any
    matched_prefix = next((prefix for prefix in FILE_TYPE_PREFIXES if stem.startswith(prefix)), None)
    if matched_prefix is None:
        raise ValueError(f"It looks like `{filename}` is not a Transformers model file!")

    # The stem starts with the prefix, so slicing it off equals removing its first occurrence
    model_name = stem[len(matched_prefix) :]
    # Only image processing files get their "_fast" suffix stripped
    if matched_prefix == "image_processing_" and model_name.endswith("_fast"):
        model_name = model_name[: -len("_fast")]
    return model_name
# Function to perform topological sorting
def topological_sort(dependencies: dict):
# Nodes are the name of the models to convert (we only add those to the graph)
@ -11,7 +49,7 @@ def topological_sort(dependencies: dict):
name_mapping = {}
for node, deps in dependencies.items():
node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
dep_names = {dep.split(".")[-2] for dep in deps}
dep_names = {get_model_name_from_filename(dep) for dep in deps}
dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
graph[node_name] = dependencies
name_mapping[node_name] = node

View File

@ -23,7 +23,7 @@ from typing import Dict, Optional, Set, Union
import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from create_dependency_mapping import find_priority_list, get_model_name_from_filename
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
@ -893,7 +893,7 @@ class ModelFileMapper(ModuleMapper):
def common_partial_suffix(str1: str, str2: str) -> str:
"""Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
we do not consider it a common suffix and return `""`"""
we do not consider it a common suffix and return `""`."""
common_suffix = ""
for i in range(1, min(len(str1), len(str2)) + 1):
if str1[-i] == str2[-i]:
@ -1280,7 +1280,7 @@ class ModularFileMapper(ModuleMapper):
self.renamers = {}
name_prefixes = self.infer_new_model_name()
for file, module in self.model_specific_modules.items():
file_model_name = file.split(".")[-2]
file_model_name = get_model_name_from_filename(file)
new_name = name_prefixes[file]
renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
renamed_module = module.visit(renamer)
@ -1425,6 +1425,17 @@ class ModularFileMapper(ModuleMapper):
# Check if we found multiple prefixes for some modeling files
final_name_mapping = {}
for file, prefixes_counter in prefix_model_name_mapping.items():
old_cased_model_name = get_cased_name(get_model_name_from_filename(file))
# Edge-case: if for some reason the new and old model name have a common prefix, it was wrongly removed
# before, so we need to add it back (if e.g. both model names end with `ResNet`)
model_common_suffix = common_partial_suffix(cased_default_name, old_cased_model_name)
prefixes_counter = Counter(
{
k + model_common_suffix if k + model_common_suffix == cased_default_name else k: v
for k, v in prefixes_counter.items()
}
)
# Find the final name for the file, based on the most frequent prefix in the Counter
if len(prefixes_counter) > 1:
_, total = prefixes_counter.most_common(1)[0]
most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
@ -1433,7 +1444,6 @@ class ModularFileMapper(ModuleMapper):
else:
final_name = list(prefixes_counter)[0]
# Check if the prefix can be used without collisions in the names
old_cased_model_name = get_cased_name(file.split(".")[-2])
old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
# Raise adequate warning depending on the situation
has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
@ -1649,34 +1659,33 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
def convert_modular_file(modular_file):
    """Convert a modular file into the source code of the files generated from it.

    Args:
        modular_file: Path to the modular file to convert. The model name is derived from it
            with `get_model_name_from_filename`.

    Returns:
        A dict mapping each key yielded by `create_modules` to the list
        `[formatted_code, ruffed_code]`, i.e. the results of the second and first `run_ruff`
        passes respectively. The dict is empty when no model name can be extracted from
        `modular_file`.
    """
    output = {}
    try:
        model_name = get_model_name_from_filename(modular_file)
    except ValueError:
        # `get_model_name_from_filename` signals an unrecognized filename with `ValueError`
        print(f"The file you are trying to convert (`{modular_file}`) is not a modular file!")
        return output

    # Parse the Python file
    with open(modular_file, "r", encoding="utf-8") as source_file:
        code = source_file.read()
    module = cst.parse_module(code)
    wrapper = MetadataWrapper(module)
    cst_transformers = ModularFileMapper(module, model_name)
    wrapper.visit(cst_transformers)

    for file, generated_module in create_modules(cst_transformers).items():
        if generated_module != {}:
            # Get relative path starting from src/transformers/ (or examples/)
            relative_path = re.search(
                r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
            ).group(1)
            header = AUTO_GENERATED_MESSAGE.format(
                relative_path=relative_path, short_name=os.path.basename(relative_path)
            )
            # Two successive ruff passes; both results are kept for the caller
            ruffed_code = run_ruff(header + generated_module.code, True)
            formatted_code = run_ruff(ruffed_code, False)
            output[file] = [formatted_code, ruffed_code]
    return output
def save_modeling_file(modular_file, converted_file):