Compare commits

...

5 Commits

Author SHA1 Message Date
e7173d9a28 fix: lint correction 2025-08-07 18:16:34 +00:00
e487945cf6 fix: help mypy type checking 2025-08-07 18:14:07 +00:00
338752c367 fix: improve metaclass repr 2025-08-07 17:59:32 +00:00
53723ea986 fix: improve name 2025-08-07 17:53:51 +00:00
95d3a758a9 feat: add custom type for kernel modules 2025-08-07 17:48:29 +00:00

View File

@ -71,6 +71,24 @@ def universal_build_variant() -> str:
return "torch-universal"
# Metaclass to allow overriding the `__repr__` method for kernel modules.
class _KernelModuleMeta(type):
def __repr__(self):
return "<class 'kernel_module'>"
# Custom module type to identify dynamically loaded kernel modules.
# Using a subclass lets us distinguish these from regular imports.
class _KernelModuleType(ModuleType, metaclass=_KernelModuleMeta):
"""Marker class for modules loaded dynamically from a path."""
module_name: str
is_kernel: bool = True
def __repr__(self):
return f"<kernel_module '{self.module_name}' from '{self.__file__}'>"
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
# We cannot use the module name as-is, after adding it to `sys.modules`,
# it would also be used for other imports. So, we make a module name that
@ -84,6 +102,9 @@ def import_from_path(module_name: str, file_path: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
if module is None:
raise ImportError(f"Cannot load module {module_name} from spec")
module.__class__ = _KernelModuleType
assert isinstance(module, _KernelModuleType) # for mypy type checking
module.module_name = module_name
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore
return module