Compare commits

..

5 Commits

Author SHA1 Message Date
e7173d9a28 fix: lint correction 2025-08-07 18:16:34 +00:00
e487945cf6 fix: help mypy type checking 2025-08-07 18:14:07 +00:00
338752c367 fix: improve metaclass repr 2025-08-07 17:59:32 +00:00
53723ea986 fix: improve name 2025-08-07 17:53:51 +00:00
95d3a758a9 feat: add custom type for kernel modules 2025-08-07 17:48:29 +00:00
2 changed files with 24 additions and 58 deletions

View File

@@ -71,6 +71,24 @@ def universal_build_variant() -> str:
return "torch-universal"
# Metaclass to allow overriding the `__repr__` method for kernel modules.
class _KernelModuleMeta(type):
def __repr__(self):
return "<class 'kernel_module'>"
# ModuleType subclass that tags modules imported from an explicit file
# path, so callers can tell them apart from ordinary `import` results.
class _KernelModuleType(ModuleType, metaclass=_KernelModuleMeta):
    """Marker class for modules loaded dynamically from a path."""

    # Original (un-namespaced) kernel name; assigned by the loader after
    # the module's __class__ is swapped to this type.
    module_name: str
    # Flag consumers can check to detect dynamically loaded kernel modules.
    is_kernel: bool = True

    def __repr__(self) -> str:
        # __file__ is populated by the import machinery when the module is
        # created from its spec — TODO confirm it is always set here.
        location = self.__file__
        return f"<kernel_module '{self.module_name}' from '{location}'>"
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
# We cannot use the module name as-is, after adding it to `sys.modules`,
# it would also be used for other imports. So, we make a module name that
@@ -84,6 +102,9 @@ def import_from_path(module_name: str, file_path: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
if module is None:
raise ImportError(f"Cannot load module {module_name} from spec")
module.__class__ = _KernelModuleType
assert isinstance(module, _KernelModuleType) # for mypy type checking
module.module_name = module_name
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore
return module
@@ -248,24 +269,8 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
Returns:
`ModuleType`: The imported kernel module.
"""
variant = build_variant()
universal_variant = universal_build_variant()
# Presume we were given the top level path of the kernel repository.
for base_path in [repo_path, repo_path / "build"]:
# Prefer the universal variant if it exists.
for v in [universal_variant, variant]:
package_path = base_path / v / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
# If we didn't find the package in the repo we may have a explicit
# package path.
package_path = repo_path / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(

View File

@@ -10,16 +10,10 @@ def kernel():
@pytest.fixture
def local_kernel_path():
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return package_name, path
@pytest.fixture
def local_kernel(local_kernel_path):
package_name, path = local_kernel_path
return get_local_kernel(path.parent.parent, package_name)
@@ -72,39 +66,6 @@ def test_local_kernel(local_kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
package_name, path = local_kernel_path
# Top-level repo path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
kernel = get_local_kernel(path.parent.parent, package_name)
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
# Build directory path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
kernel = get_local_kernel(path.parent.parent / "build", package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
# Explicit package path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
kernel = get_local_kernel(path, package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):