Compare commits


4 Commits

SHA1 Message Date
f994749cdc Add use_hub_kernel decorator
This decorator replaces a layer with a layer from the Kernel Hub.
2025-03-11 13:37:44 +00:00
cf0413efe5 Add Nix flake devshell (#44) 2025-03-11 10:59:12 +01:00
851c13f666 Set version to 0.2.1 (#43) 2025-03-10 15:20:34 +01:00
b6a393612f Pass through locked sha again when loading locked kernels (#42)
This bit got removed accidentally when adding support for universal
kernels. Also add a test to ensure that we'd catch this in the future.
2025-03-10 15:10:47 +01:00
13 changed files with 401 additions and 18 deletions

README.md

@@ -45,6 +45,7 @@ the Hub.
## 📚 Documentation
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)

docs/kernel-requirements.md

@@ -76,6 +76,26 @@ might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
## Layers
Kernels can provide a `layers` attribute containing a Python module with
layers. Such layers can be used directly by downstream users or through
the `use_hub_kernel` decorator. For Torch, each layer must be a subclass of
[`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:
```python
from . import layers

__all__ = [
    # ...
    "layers",
    # ...
]
```
## Python requirements
- Python code must be compatible with Python 3.9 and later.

docs/layers.md (new file)

@@ -0,0 +1,39 @@
# Layers
A kernel can provide layers in addition to kernel functions. For Torch
kernels, layers are subclasses of [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
## Getting a layer
Layers are exposed through the `layers` attribute of a kernel,
so they can be used directly after loading a kernel with `get_kernel`.
For example:
```python
import kernels
import torch
activation = kernels.get_kernel("kernels-community/activation", revision="layers")
layer = activation.layers.SiluAndMul()
out = layer(torch.randn((64, 64), device='cuda:0'))
```
## Using a kernel layer as a replacement for an existing layer
An existing layer in a library can be made Kernel Hub-enabled with the
`use_hub_kernel` decorator. The decorator replaces the existing layer
when the kernel's layer can be loaded successfully.
For example:
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import use_hub_kernel


@use_hub_kernel(
    "kernels-community/activation",
    layer_name="SiluAndMul",
    revision="layers",
)
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```
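The decorator also accepts a `fallback_on_error` flag (defined in `src/kernels/decorators.py` below, default `True`). A minimal sketch of strict loading, mirroring the new `tests/test_layer.py`, where any failure to fetch the Hub layer raises instead of silently keeping the decorated class:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import use_hub_kernel


# fallback_on_error=False: errors such as a missing repo, a kernel without a
# `layers` module, or an unknown layer name are raised instead of falling
# back to the class defined below.
@use_hub_kernel(
    "kernels-community/activation",
    layer_name="SiluAndMul",
    revision="layers",
    fallback_on_error=False,
)
class SiluAndMulStrict(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```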

flake.lock (generated, new file)

@@ -0,0 +1,134 @@
{
"nodes": {
"flake-compat": {
"locked": {
"lastModified": 1733328505,
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_2": {
"inputs": {
"systems": "systems_2"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": [
"tgi-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_2": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

flake.nix (new file)

@@ -0,0 +1,54 @@
{
  inputs = {
    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
          inherit (tgi-nix.lib) config;
          overlays = [
            tgi-nix.overlays.default
          ];
        };
      in
      {
        formatter = pkgs.nixfmt-rfc-style;
        devShells = with pkgs; rec {
          default = mkShell {
            buildInputs =
              [
                black
                mypy
                pyright
                ruff
              ]
              ++ (with python3.pkgs; [
                huggingface-hub
                pytest
                pytest-benchmark
                torch
                venvShellHook
              ]);
            venvDir = "./.venv";
            postVenvCreation = ''
              unset SOURCE_DATE_EPOCH
              ( python -m pip install --no-build-isolation --no-dependencies -e . )
            '';
          };
        };
      }
    );
}

pyproject.toml

@@ -1,7 +1,7 @@
[project]
name = "kernels"
version = "0.2.0"
description = "Download cuda kernels"
version = "0.2.1"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },

src/kernels/__init__.py

@@ -1,3 +1,15 @@
from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
from kernels.decorators import use_hub_kernel
from kernels.utils import (
    get_kernel,
    get_locked_kernel,
    install_kernel,
    load_kernel,
)
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
__all__ = [
    "get_kernel",
    "get_locked_kernel",
    "load_kernel",
    "install_kernel",
    "use_hub_kernel",
]
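A quick sanity check of the updated public surface (a sketch; the names are exactly those listed in the new `__all__` above):

```python
import kernels

# After this change the decorator is re-exported from the package root
# alongside the existing loader helpers.
expected = {"get_kernel", "get_locked_kernel", "load_kernel", "install_kernel", "use_hub_kernel"}
assert expected <= set(kernels.__all__)
```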

src/kernels/decorators.py (new file)

@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING

from .utils import get_kernel

if TYPE_CHECKING:
    from torch import nn


def use_hub_kernel(
    repo_id: str,
    *,
    layer_name: str,
    revision: str = "main",
    fallback_on_error: bool = True,
):
    """
    Replace a layer with a layer from the kernel hub.

    When `fallback_on_error` is True, the original layer will be used if
    the kernel's layer could not be loaded.
    """

    def decorator(cls):
        try:
            return _get_kernel_layer(
                repo_id=repo_id, layer_name=layer_name, revision=revision
            )
        except Exception as e:
            if not fallback_on_error:
                raise e
            return cls

    return decorator


def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
    """Get a layer from a kernel."""
    from torch import nn

    kernel = get_kernel(repo_id, revision=revision)

    if getattr(kernel, "layers", None) is None:
        raise ValueError(
            f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
        )

    layer = getattr(kernel.layers, layer_name, None)
    if layer is None:
        raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
    if not issubclass(layer, nn.Module):
        raise TypeError(f"Layer `{layer_name}` is not a Torch layer.")

    return layer

src/kernels/utils.py

@@ -144,9 +144,18 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
    return import_from_path(package_name, package_path / package_name / "__init__.py")


def load_kernel(repo_id: str) -> ModuleType:
    """Get a pre-downloaded, locked kernel."""
    locked_sha = _get_caller_locked_kernel(repo_id)
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
    """
    Get a pre-downloaded, locked kernel.

    If `lockfile` is not specified, the lockfile will be loaded from the
    caller's package metadata.
    """
    if lockfile is None:
        locked_sha = _get_caller_locked_kernel(repo_id)
    else:
        with open(lockfile, "r") as f:
            locked_sha = _get_locked_kernel(repo_id, f.read())

    if locked_sha is None:
        raise ValueError(
@@ -163,6 +172,7 @@ def load_kernel(repo_id: str) -> ModuleType:
            repo_id,
            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
            cache_dir=CACHE_DIR,
            revision=locked_sha,
            local_files_only=True,
        )
    )
@@ -200,11 +210,19 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
    for dist in _get_caller_distributions():
        lock_json = dist.read_text("kernels.lock")
        if lock_json is not None:
            for kernel_lock_json in json.loads(lock_json):
                kernel_lock = KernelLock.from_json(kernel_lock_json)
                if kernel_lock.repo_id == repo_id:
                    return kernel_lock.sha
        if lock_json is None:
            continue
        locked_sha = _get_locked_kernel(repo_id, lock_json)
        if locked_sha is not None:
            return locked_sha
    return None


def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
    for kernel_lock_json in json.loads(lock_json):
        kernel_lock = KernelLock.from_json(kernel_lock_json)
        if kernel_lock.repo_id == repo_id:
            return kernel_lock.sha
    return None
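With this change, `load_kernel` can resolve the pinned revision from an explicit lockfile instead of the caller's package metadata. A minimal sketch (the lockfile path is illustrative; as in the new `test_load_locked` test below, the kernel must already be in the local cache because the download uses `local_files_only=True`):

```python
from pathlib import Path

from kernels import load_kernel

# Resolve the locked SHA from an explicit kernels.lock and import the
# pre-downloaded kernel from the local cache.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("kernels.lock"),
)
```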


@@ -1,6 +1,7 @@
from dataclasses import dataclass
from pathlib import Path

from kernels import load_kernel
from kernels.cli import download_kernels
@@ -11,11 +12,13 @@ class DownloadArgs:
    project_dir: Path


def test_download_hash_validation():
    project_dir = Path(__file__).parent / "hash_validation"
    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))


def test_download_all_hash_validation():
    project_dir = Path(__file__).parent / "hash_validation"
    project_dir = Path(__file__).parent / "kernel_locking"
    download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))


def test_load_locked():
    project_dir = Path(__file__).parent / "kernel_locking"
    # Also validates that hashing works correctly.
    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")

tests/test_layer.py (new file)

@@ -0,0 +1,48 @@
import re

import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import use_hub_kernel


class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


def test_activation_layer():
    @use_hub_kernel(
        "kernels-community/activation",
        layer_name="SiluAndMul",
        revision="layers",
        fallback_on_error=False,
    )
    class SiluAndMulWithKernel(SiluAndMul):
        pass

    torch.random.manual_seed(0)
    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device="cuda")
    Y = silu_and_mul(X)

    # Verify that the Hub kernel was loaded.
    assert SiluAndMulWithKernel.__name__ == "SiluAndMul"
    assert re.match(r"activation.*layers", SiluAndMulWithKernel.__module__)

    silu_and_mul_with_kernel = SiluAndMulWithKernel()
    Y_kernel = silu_and_mul_with_kernel(X)
    torch.testing.assert_close(Y_kernel, Y)


def test_layer_fallback_works():
    @use_hub_kernel("kernels-community/non-existing", layer_name="SiluAndMul")
    class SiluAndMulWithKernelFallback(SiluAndMul):
        pass

    # Check that we don't raise an exception for a non-existing kernel.
    SiluAndMulWithKernelFallback()