Compare commits


3 Commits

| SHA1 | Message | Date |
| --- | --- | --- |
| ae43772a67 | Export `__version__` for doc-builder | 2025-04-14 14:36:09 +00:00 |
| c02f88cd2a | CI: build docs | 2025-04-14 14:36:09 +00:00 |
| a3db6f437c | Move documentation in preparation for doc-builder. Also document public functions better. | 2025-04-14 14:36:09 +00:00 |
16 changed files with 332 additions and 29 deletions

@@ -0,0 +1,19 @@
name: Build documentation

on:
  push:
    paths:
      - "docs/source/**"
    branches:
      - main
      - doc-builder*
      - v*-release

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    with:
      commit_sha: ${{ github.sha }}
      package: kernels
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

@@ -0,0 +1,18 @@
name: Build PR Documentation

on:
  pull_request:
    paths:
      - "docs/source/**"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
      package: kernels

docs/source/_toctree.yml (new file)

@@ -0,0 +1,26 @@
- sections:
    - local: index
      title: Introduction
    - local: installation
      title: Installation
  title: Getting started
- sections:
    - local: basic_usage
      title: Basic Usage
    - local: layers
      title: Using Layers
    - local: locking
      title: Locking Kernel Versions
    - local: env
      title: Environment Variables
  title: Usage Guide
- sections:
    - local: api/kernels
      title: Kernels
    - local: api/layers
      title: Layers
  title: API Reference
- sections:
    - local: kernel_requirements
      title: Kernel Requirements
  title: Developer Guide

docs/source/api/kernels.md (new file)

@@ -0,0 +1,21 @@
# Kernels API Reference
## Main Functions
### get_kernel
[[autodoc]] kernels.get_kernel
### has_kernel
[[autodoc]] kernels.has_kernel
## Loading locked kernels
### load_kernel
[[autodoc]] kernels.load_kernel
### get_locked_kernel
[[autodoc]] kernels.get_locked_kernel
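As a quick orientation, here is a minimal sketch contrasting the two locked-loading entry points (the repo id is illustrative, and the project is assumed to have been locked with `kernels lock` and fetched with `kernels download`):

```python
from kernels import get_locked_kernel, load_kernel

# Load the locked revision from the local cache; fails if the kernel
# was not downloaded beforehand.
activation = load_kernel("kernels-community/activation")

# Same, but downloads the locked revision on demand when it is missing
# from the cache.
activation = get_locked_kernel("kernels-community/activation")
```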

docs/source/api/layers.md (new file)

@@ -0,0 +1,31 @@
# Layers API Reference
## Making layers kernel-aware
### use_kernel_forward_from_hub
[[autodoc]] kernels.use_kernel_forward_from_hub
### replace_kernel_forward_from_hub
[[autodoc]] kernels.replace_kernel_forward_from_hub
## Registering kernel mappings
### use_kernel_mapping
[[autodoc]] kernels.use_kernel_mapping
### register_kernel_mapping
[[autodoc]] kernels.register_kernel_mapping
## Classes
### LayerRepository
[[autodoc]] kernels.LayerRepository
### Device
[[autodoc]] kernels.Device
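A small sketch of constructing these classes directly (field values are illustrative; `Device` accepting a `type` keyword is an assumption mirroring the plain string keys such as `"cuda"` used in mappings):

```python
from kernels import Device, LayerRepository

# Reference to a layer class inside a kernel repository on the Hub.
norm_repo = LayerRepository(
    repo_id="kernels-community/activation",  # illustrative repo id
    layer_name="RmsNorm",
)

# Mapping keys can be plain strings like "cuda"; Device(type="cuda") is
# assumed here as the explicit equivalent.
cuda = Device(type="cuda")
```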

docs/source/basic_usage.md (new file)

@@ -0,0 +1,34 @@
# Basic Usage
## Loading Kernels
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
```python
import torch
from kernels import get_kernel
# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")
# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)
print(y)
```
## Checking Kernel Availability
You can check if a specific kernel is available for your environment:
```python
from kernels import has_kernel
# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```
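The two functions compose naturally. Here is a sketch of guarding kernel loading behind the availability check, with a plain PyTorch fallback (the fallback op is our choice, not part of the kernel API):

```python
import torch
import torch.nn.functional as F

from kernels import get_kernel, has_kernel

x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)

if has_kernel("kernels-community/activation"):
    # A build matches this environment: use the optimized Hub kernel.
    activation = get_kernel("kernels-community/activation")
    activation.gelu_fast(y, x)
else:
    # No compatible build: fall back to stock PyTorch.
    y = F.gelu(x)
```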

docs/source/index.md (new file)

@@ -0,0 +1,20 @@
# Kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
same Python process.
- **Compatible**: kernels must support all recent versions of Python and
the different PyTorch build configurations (various CUDA versions
and C++ ABIs). Furthermore, older C library versions must be supported.
You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.

docs/source/installation.md (new file)

@@ -0,0 +1,16 @@
# Installation
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
```bash
pip install kernels
```
## Using kernels in a Docker container

Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

pyproject.toml

@@ -31,7 +31,9 @@ dev = [
]
[project.optional-dependencies]
torch = ["torch"]
docs = [
"hf-doc-builder",
]
[project.scripts]
kernels = "kernels.cli:main"
@@ -39,7 +41,6 @@ kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"]
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
[tool.ruff]
exclude = [
".eggs",

@@ -1,3 +1,5 @@
import importlib.metadata
from kernels.layer import (
Device,
LayerRepository,
@@ -27,3 +29,5 @@ __all__ = [
"LayerRepository",
"Device",
]
__version__ = importlib.metadata.version("kernels")
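With this export, the installed version is inspectable at runtime, which is what doc-builder needs:

```python
import kernels

print(kernels.__version__)
```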

@@ -64,11 +64,18 @@ def use_kernel_mapping(
inherit_mapping: bool = True,
):
"""
- Context manager that sets a mapping for a duration of the context.
+ Context manager that sets a kernel mapping for the duration of the context.
When `inherit_mapping` is set to `True`, the current mapping will be
extended by `mapping` inside the context. If it is `False`, only
`mapping` is used inside the context.
Args:
mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`):
A mapping between layer names and their corresponding kernel repositories.
inherit_mapping (`bool`, *optional*, defaults to `True`):
The current mapping will be extended by `mapping` when set to `True`.
When set to `False`, the current mapping will be replaced by `mapping`
for the duration of the context.
Returns:
`ContextManager`: Context manager that sets up the mapping.
"""
class ContextManager:
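Going by the signature and docstring above, a usage sketch of scoping a mapping to a context (repo details are illustrative):

```python
from kernels import LayerRepository, use_kernel_mapping

mapping = {
    "LlamaRMSNorm": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="RmsNorm",
        ),
    },
}

# Extends the currently registered mapping for the duration of the context;
# pass inherit_mapping=False to use only `mapping` inside the context.
with use_kernel_mapping(mapping):
    ...
```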
@@ -87,27 +94,31 @@ def use_kernel_mapping(
def register_kernel_mapping(
- mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
+ mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
):
"""
- Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
+ Register a mapping between a layer name and the corresponding kernel to use, depending on the device.
- This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
- Exemple usage:
- ```python
- from kernels import LayerRepository, register_kernel_mapping
+ Args:
+     mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`):
+         A mapping between layer names and their corresponding kernel repositories.
- kernel_layer_mapping = {
-     "LlamaRMSNorm": {
-         "cuda": LayerRepository(
-             repo_id="kernels-community/activation",
-             layer_name="RmsNorm",
-             revision="layers",
-         ),
-     },
- }
- register_kernel_mapping(kernel_layer_mapping)
- ```
+ Example:
+ ```python
+ from kernels import LayerRepository, register_kernel_mapping
+ kernel_layer_mapping = {
+     "LlamaRMSNorm": {
+         "cuda": LayerRepository(
+             repo_id="kernels-community/activation",
+             layer_name="RmsNorm",
+             revision="layers",
+         ),
+     },
+ }
+ register_kernel_mapping(kernel_layer_mapping)
+ ```
"""
# Merge with existing mappings.
for new_kernel, new_device_repos in mapping.items():
@@ -125,8 +136,18 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
- `register_layer_mapping`. The device type is inferred from the first
+ [`register_kernel_mapping`]. The device type is inferred from the first
argument to `forward`.
Args:
cls (`nn.Module`):
The layer class to replace the forward function of.
layer_name (`str`):
The name to assign to the layer.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the fallback forward function if no kernel mapping
is found. If set to `False`, a `ValueError` will be raised if no kernel
mapping is found.
"""
fallback_forward = cls.forward
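A sketch of monkeypatching an existing class with this function; the layer class here is a stand-in:

```python
import torch.nn as nn

from kernels import replace_kernel_forward_from_hub


class CustomRMSNorm(nn.Module):  # stand-in for a real layer class
    def forward(self, x):
        return x


# Replace the forward once a matching "CustomRMSNorm" mapping is registered;
# the original forward is kept as a fallback.
replace_kernel_forward_from_hub(CustomRMSNorm, "CustomRMSNorm", use_fallback=True)
```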
@@ -195,11 +216,34 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
- `register_layer_mapping`. The device type is inferred from the first
+ [`register_kernel_mapping`]. The device type is inferred from the first
argument to `forward`.
Args:
layer_name (`str`):
The name to assign to the layer.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the fallback forward function if no kernel mapping
is found. If set to `False`, a `ValueError` will be raised if no kernel
mapping is found.
Example:
```python
import torch.nn as nn

from kernels import use_kernel_forward_from_hub
@use_kernel_forward_from_hub(layer_name="LlamaRMSNorm")
class LlamaRMSNorm(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
# Original forward implementation
pass
```
"""
def decorator(cls):

@@ -157,6 +157,27 @@
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Load a kernel from the kernel hub.
This function downloads a kernel to the local Hugging Face Hub cache
directory (if it was not downloaded before) and then loads the kernel.
Args:
repo_id (`str`): The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`): The specific
revision (branch, tag, or commit) to download.
Returns:
`ModuleType`: The imported kernel module.
Example:
```python
from kernels import get_kernel
kernel = get_kernel("username/my-kernel")
result = kernel.kernel_function(input_data)
```
"""
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
@@ -164,7 +185,20 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
This function checks whether there exists a kernel build for the current
environment (Torch version, compute framework and architecture).
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`):
The kernel revision.
Returns:
`bool`:
`True` if a compatible kernel build exists for the current environment,
`False` otherwise.
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
@@ -186,10 +220,25 @@ def has_kernel(repo_id: str, revision: str = "main") -> bool:
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
"""
- Get a pre-downloaded, locked kernel.
+ Loads a pre-downloaded, locked kernel module from the local cache.
- If `lockfile` is not specified, the lockfile will be loaded from the
- caller's package metadata.
This function retrieves a kernel that was locked at a specific revision with
`kernels lock <project>` and then downloaded with `kernels download <project>`.
This function will fail if the kernel was not locked or downloaded. If you want
the kernel to be downloaded when it is not in the cache, use [`get_locked_kernel`]
instead.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
lockfile (`Optional[Path]`, *optional*, defaults to `None`):
Path to a lockfile containing the commit SHA for the kernel. If `None`,
the lock information is automatically retrieved from the metadata of the
calling package.
Returns:
`ModuleType`: The imported kernel module corresponding to the locked version.
"""
if lockfile is None:
locked_sha = _get_caller_locked_kernel(repo_id)
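A sketch of passing an explicit lockfile instead of relying on the caller's package metadata (the path is illustrative):

```python
from pathlib import Path

from kernels import load_kernel

# Assumes `kernels lock` and `kernels download` were run for this project.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("kernels.lock"),  # illustrative lockfile path
)
```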
@@ -234,7 +283,27 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""Get a kernel using a lock file."""
"""
Loads a locked kernel module.
This function retrieves a kernel that was locked at a specific revision with
`kernels lock <project>`.
This function will download the locked kernel when it is not available in the
cache. If you want loading to fail if the kernel is not in the cache, use
[`load_kernel`] instead.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
local_files_only (`bool`, *optional*, defaults to `False`):
    Whether to only load the kernel from the local cache, without
    attempting to download it from the Hub.
Returns:
`ModuleType`: The imported kernel module corresponding to the locked version.
"""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:
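And the download-on-demand counterpart; `local_files_only=True` presumably restricts loading to the cache, per the signature above:

```python
from kernels import get_locked_kernel

# Downloads the locked revision if it is not cached yet, then imports it.
activation = get_locked_kernel("kernels-community/activation")

# Assumption: only use an already-cached kernel, never download.
activation = get_locked_kernel("kernels-community/activation", local_files_only=True)
```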