mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
* Apply ruff FURB rules Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * Enable ruff FURB rules Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * More fixes Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * More fixes Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * Revert changes Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> * More fixes Signed-off-by: Yuanyuan Chen <cyyever@outlook.com> --------- Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
185 lines
7.0 KiB
Python
185 lines
7.0 KiB
Python
import os
|
|
|
|
import libcst as cst
|
|
|
|
|
|
# Modules coming from external libraries whose dependencies must not be tracked.
# Keys are external library names; each value lists the excluded files as
# `{"name": ..., "type": ...}` entries.
# E.g. for habana, `modeling_all_models.py` is not part of the transformers library,
# so its dependencies are left untracked.
EXCLUDED_EXTERNAL_FILES = {
    "habana": [
        {"name": "modeling_all_models", "type": "modeling"},
    ],
}
|
|
|
|
|
|
def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Convert a relative libcst.ImportFrom node into an absolute one,
    using the file path and package name.

    The package root is taken to be the *rightmost* directory of `file_path` named `package_name`.
    Because the path is made absolute first, ancestor directories take part in the lookup; picking
    the rightmost match keeps a checkout directory that happens to share the package name
    (e.g. `/home/me/transformers/src/transformers/...`) from being mistaken for the package itself.

    Args:
        import_node: A relative import node (e.g. `from ..utils import helper`)
        file_path: Path to the file containing the import (can be absolute or relative)
        package_name: The top-level package name (e.g. 'myproject')

    Returns:
        A new ImportFrom node with the absolute import path. If `import_node` is already absolute,
        it is returned unchanged.

    Raises:
        ValueError: If no directory of `file_path` is named `package_name`, or if the relative
            import level climbs above the package root.
    """
    if not import_node.relative:
        return import_node  # Already absolute

    file_path = os.path.abspath(file_path)
    rel_level = len(import_node.relative)

    # Strip file extension and split into parts; the last part is the module's own name
    file_parts = file_path.removesuffix(".py").split(os.path.sep)

    # Locate the package root among the directories only (never the module's own name),
    # scanning from the right so that the innermost match wins.
    dir_parts = file_parts[:-1]
    try:
        pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)
    except ValueError:
        raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'") from None

    # Everything below the package root, e.g. ['module', 'submodule', 'foo']
    module_parts = file_parts[pkg_index + 1 :]
    if len(module_parts) < rel_level:
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # One dot refers to the directory containing the file, so dropping `rel_level` trailing parts
    # (the module name itself, then one directory per extra dot) yields the anchor package.
    base_parts = module_parts[:-rel_level]

    # Flatten the module being imported (if any) into its dotted components
    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        names: list[str] = []
        # Attribute chains nest leftwards: `a.b.c` is Attribute(Attribute(a, b), c)
        while isinstance(module, cst.Attribute):
            names.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            names.append(module.value)
        names.reverse()
        return names

    import_parts = flatten_module(import_node.module)

    # Combine to get the full absolute import path
    full_parts = [package_name, *base_parts, *import_parts]

    # Handle special case where the import comes from a namespace package
    # (e.g. optimum with `optimum.habana`, `optimum.intel` instead of `src.optimum`):
    # the directory right above the package root becomes the leading component.
    parent_dir = file_parts[pkg_index - 1]
    if package_name != "transformers" and parent_dir != "src":
        full_parts = [parent_dir, *full_parts]

    # Build the dotted module path as nested Attribute nodes
    dotted_module: cst.BaseExpression | None = None
    for part in full_parts:
        name = cst.Name(part)
        dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)

    # Return a new ImportFrom node with absolute import
    return import_node.with_changes(module=dotted_module, relative=[])
|
|
|
|
|
|
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    An import is considered to belong to the package when its module is exactly `package_name`,
    starts with `package_name.`, or starts with `optimum.package_name.`. Leading directories shared
    between the importing file and the imported module are folded into the relative level, so the
    result uses as few dots and as short a module path as possible.

    Parameters:
    - import_node: The ImportFrom node to possibly transform.
    - file_path: Absolute path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. The node is returned unchanged when it is already relative,
      when it does not belong to `package_name`, or when no directory of `file_path` is named
      `package_name`.
    """
    if import_node.relative:
        return import_node  # Already relative import

    # Flatten the imported module into its dotted components (empty list if it cannot be resolved)
    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        names: list[str] = []
        # Attribute chains nest leftwards: `a.b.c` is Attribute(Attribute(a, b), c)
        while isinstance(module, cst.Attribute):
            names.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            names.append(module.value)
        names.reverse()
        return names

    module_parts = flatten_module(import_node.module)
    module_name = ".".join(module_parts)

    # Check if it's from the target package, and count how many leading components form the
    # package prefix (2 for `optimum.<package>.`, 1 for `<package>` / `<package>.`).
    if module_name.startswith(f"optimum.{package_name}."):
        prefix_len = 2
    elif module_name == package_name or module_name.startswith(f"{package_name}."):
        prefix_len = 1
    else:
        return import_node  # Not from target package

    # Locate the package root inside the file path. The rightmost match is used so that an
    # ancestor directory sharing the package name (e.g. the checkout directory) is not mistaken
    # for the package itself.
    path_parts = os.path.normpath(file_path).split(os.sep)
    try:
        pkg_index = len(path_parts) - 1 - path_parts[::-1].index(package_name)
    except ValueError:
        # Package name not found in path — assume we can't resolve relative depth
        return import_node

    below_package = path_parts[pkg_index + 1 :]  # enclosing directories followed by the file name
    enclosing_dirs = below_package[:-1]
    module_tail = module_parts[prefix_len:]  # imported module path, relative to the package root

    # Count the leading directories the file and the imported module have in common; `zip` stops at
    # the shorter sequence, so short module paths can never be indexed out of range.
    shared = 0
    for directory, component in zip(enclosing_dirs, module_tail):
        if directory != component:
            break
        shared += 1

    # One dot per path component below the package root (file included), minus the shared
    # directories that do not need to be climbed. Always emit at least one dot.
    depth = len(below_package) - shared
    # Build a fresh node per dot rather than repeating one shared instance
    relative = [cst.Dot() for _ in range(max(depth, 1))]

    # Build new module node from what is left after the shared prefix; it stays None when the
    # import targets the anchor package itself (e.g. `from .. import foo`).
    new_module: cst.BaseExpression | None = None
    for part in module_tail[shared:]:
        name = cst.Name(part)
        new_module = name if new_module is None else cst.Attribute(value=new_module, attr=name)

    return import_node.with_changes(module=new_module, relative=relative)
|
|
|
|
|
|
class AbsoluteImportTransformer(cst.CSTTransformer):
    """
    CST transformer that passes every `from ... import ...` statement it visits through
    `convert_relative_import_to_absolute`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed; forwarded as the `file_path` argument.
            source_library: Name of the library the file belongs to; forwarded as the `package_name` argument.
        """
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Return the replacement for the visited import, computed from `updated_node`."""
        absolute_import = convert_relative_import_to_absolute(updated_node, self.relative_path, self.source_library)
        return absolute_import
|
|
|
|
|
|
class RelativeImportTransformer(cst.CSTTransformer):
    """
    CST transformer that passes every `from ... import ...` statement it visits through
    `convert_to_relative_import`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed; forwarded as the `file_path` argument.
            source_library: Name of the library the file belongs to; forwarded as the `package_name` argument.
        """
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Return the replacement for the visited import, computed from `updated_node`."""
        relative_import = convert_to_relative_import(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return relative_import
|