Files
transformers/utils/modular_integrations.py
Yuanyuan Chen 12a50f294d Enable FURB rules in ruff (#41395)
* Apply ruff FURB rules

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* Enable ruff FURB rules

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* More fixes

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* More fixes

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* Revert changes

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

* More fixes

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

---------

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-10-17 15:00:40 +00:00

185 lines
7.0 KiB
Python

import os
import libcst as cst
# Files of external libraries whose dependencies must not be tracked, keyed by library name.
# Each entry records the file's `name` (without extension) and its `type`.
# E.g. habana's `modeling_all_models.py` is not part of the transformers library, so its dependencies are skipped.
EXCLUDED_EXTERNAL_FILES = dict(
    habana=[dict(name="modeling_all_models", type="modeling")],
)
def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Rewrite a relative `from ... import ...` statement as an absolute one.

    The absolute module path is derived from the location of the importing file: the absolute path is cut
    right after the first component equal to `package_name`, the last `N` remaining components (file name
    included, `N` being the number of leading dots) are dropped, and the module written after the dots (if any)
    is appended. For packages other than `transformers` whose parent directory is not `src` (namespace packages
    such as `optimum.habana`), the parent directory name is prepended as well.

    Args:
        import_node: The `ImportFrom` node to convert (e.g. `from ..utils import helper`). Nodes that are
            already absolute are returned unchanged.
        file_path: Path of the file containing the import; it is made absolute with `os.path.abspath`.
        package_name: Name of the top-level package directory (e.g. `"transformers"`).

    Returns:
        An `ImportFrom` node with an empty `relative` list and a fully dotted `module`.

    Raises:
        ValueError: If `package_name` is not a component of the absolute file path, or if the number of
            leading dots exceeds the depth of the file inside the package.
    """
    dots = import_node.relative
    if not dots:
        return import_node  # nothing to do, the import is already absolute

    level = len(dots)
    absolute_path = os.path.abspath(file_path)
    components = absolute_path.removesuffix(".py").split(os.path.sep)
    if package_name not in components:
        raise ValueError(f"Package name '{package_name}' not found in file path '{absolute_path}'")

    root = components.index(package_name)
    inside_package = components[root + 1 :]  # e.g. ['models', 'llama', 'modeling_llama']
    if level > len(inside_package):
        raise ValueError(f"Relative import level ({level}) goes beyond package root.")
    parent_parts = inside_package[:-level]

    # Unroll the (possibly dotted) module written after the dots, innermost name last
    imported_parts: list[str] = []
    node = import_node.module
    while isinstance(node, cst.Attribute):
        imported_parts.append(node.attr.value)
        node = node.value
    if isinstance(node, cst.Name):
        imported_parts.append(node.value)
    imported_parts.reverse()

    dotted_names = [package_name, *parent_parts, *imported_parts]
    # Namespace packages (e.g. `optimum.habana`, `optimum.intel`) live under a parent package that is not `src`,
    # so that parent must be part of the absolute path too
    if package_name != "transformers" and components[root - 1] != "src":
        dotted_names.insert(0, components[root - 1])

    module_expr: cst.BaseExpression = cst.Name(dotted_names[0])
    for segment in dotted_names[1:]:
        module_expr = cst.Attribute(value=module_expr, attr=cst.Name(segment))

    return import_node.with_changes(module=module_expr, relative=[])
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    Modules named `package_name`, `package_name.<...>` or `optimum.package_name.<...>` (namespace packages
    such as `optimum.habana`) are rewritten; anything else is returned unchanged. The leading part of the
    imported module path that coincides with the importing file's own directories is folded into the dots,
    e.g. `from pkg.models.llama.configuration_llama import X` inside `pkg/models/llama/modeling_llama.py`
    becomes `from .configuration_llama import X`. When the whole module path is folded away, the result is a
    bare `from .. import X`.

    Parameters:
        - import_node: The ImportFrom node to possibly transform.
        - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
        - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
        - The original node if it is already relative, does not belong to `package_name`, or if
          `package_name` is not a component of `file_path`; otherwise a new relative ImportFrom node.
    """
    if import_node.relative:
        return import_node  # Already relative import

    # Flatten the dotted module expression into its name components, e.g. ['optimum', 'habana', 'utils']
    module_parts: list[str] = []
    node = import_node.module
    while isinstance(node, cst.Attribute):
        module_parts.append(node.attr.value)
        node = node.value
    if not isinstance(node, cst.Name):
        return import_node  # Not a plain dotted module name, leave it untouched
    module_parts.append(node.value)
    module_parts.reverse()
    module_name = ".".join(module_parts)

    # Determine which leading prefix names the package itself (the namespaced form takes precedence)
    namespace_prefix = "optimum." + package_name
    if module_name.startswith(namespace_prefix + "."):
        package_prefix = namespace_prefix
    elif module_name == package_name or module_name.startswith(package_name + "."):
        package_prefix = package_name
    else:
        return import_node  # Not from target package

    # Module components below the package root, e.g. ['transformers', 'models', 'llama']
    inner_parts = module_parts[package_prefix.count(".") + 1 :]

    # Locate the package root inside the file path
    path_parts = os.path.normpath(file_path).split(os.sep)
    try:
        pkg_index = path_parts.index(package_name)
    except ValueError:
        # Package name not found in path — we can't resolve the relative depth
        return import_node

    # Directories between the package root and the importing file (the file name itself is excluded)
    file_dirs = path_parts[pkg_index + 1 : -1]

    # Count how many leading module components match the file's own directories; those are expressed by dots.
    # `zip` stops at the shorter sequence, so importing an ancestor package cannot index out of range.
    shared = 0
    for directory, module_part in zip(file_dirs, inner_parts):
        if directory != module_part:
            break
        shared += 1

    # One dot is the file's own directory; every unmatched directory adds one more level to climb
    depth = max(len(file_dirs) + 1 - shared, 1)
    relative = [cst.Dot() for _ in range(depth)]  # distinct node instances rather than one shared node

    remaining = inner_parts[shared:]
    if not remaining:
        new_module = None  # e.g. `from .. import X`
    else:
        new_module = cst.Name(remaining[0])
        for part in remaining[1:]:
            new_module = cst.Attribute(value=new_module, attr=cst.Name(part))

    return import_node.with_changes(module=new_module, relative=relative)
class AbsoluteImportTransformer(cst.CSTTransformer):
    """
    LibCST transformer that turns every relative `from ... import ...` statement into its absolute form
    by delegating to `convert_relative_import_to_absolute`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed, used to resolve the leading dots.
            source_library: Top-level package name the absolute imports are rooted at.
        """
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # `original_node` is unused; the conversion is applied to `updated_node`
        rewritten = convert_relative_import_to_absolute(updated_node, self.relative_path, self.source_library)
        return rewritten
class RelativeImportTransformer(cst.CSTTransformer):
    """
    LibCST transformer that turns absolute imports of `source_library` into relative ones
    by delegating to `convert_to_relative_import`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed, used to compute how many dots are needed.
            source_library: Top-level package name whose absolute imports should become relative.
        """
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # `original_node` is unused; the conversion is applied to `updated_node`
        rewritten = convert_to_relative_import(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return rewritten