Compare commits

3 Commits

1c7c87c960 Set version to 0.3.0 (#47) 2025-03-19 12:02:02 +01:00
df45cf2795 Add use_kernel_forward_from_hub decorator (#46)
* Add `use_kernel_forward_from_hub` decorator

This decorator replaces a layer's `forward` with the `forward` of
a layer on the hub.

* Add support for registering a mapping for the duration of a context

This change makes `_KERNEL_MAPPING` a context variable and adds a
`use_kernel_mapping` context manager. This allows users to register
a mapping for the duration of a context.

* Update layer docs

* ruff fix

* Remove an old bit from the docs

* Extend layer mapping example

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Support stringly-typed device type

* Forward-reference `register_kernel_mapping` in monkeypatching section

* Use stringly-typed device name in layer mapping example

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-03-19 11:03:18 +01:00
cf0413efe5 Add Nix flake devshell (#44) 2025-03-11 10:59:12 +01:00
9 changed files with 764 additions and 3 deletions

README.md

@@ -45,6 +45,7 @@ the Hub.
## 📚 Documentation
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)

docs/kernel-requirements.md

@@ -76,6 +76,80 @@ might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
## Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers. See the [layers documentation](layers.md) for more information
on how to use layers.
### Writing layers
To make the extension of layers safe, the layers must fulfill the following
requirements:
- The layers are subclasses of `torch.nn.Module`.
- The layers are pure, meaning that they do not have their own state. This
means that:
- The layer must not define its own constructor.
- The layer must not use class variables.
- No methods other than `forward` must be defined.
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
This is an example of a pure layer:
```python
class SiluAndMul(nn.Module):
def forward(self, x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out
```
For some layers, the `forward` method has to use state from the adopting class.
In these cases, we recommend using type annotations to indicate what member
variables are expected. For instance:
```python
class LlamaRMSNorm(nn.Module):
weight: torch.Tensor
variance_epsilon: float
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return rms_norm_fn(
hidden_states,
self.weight,
bias=None,
residual=None,
eps=self.variance_epsilon,
dropout_p=0.0,
prenorm=False,
residual_in_fp32=False,
)
```
This layer expects the adopting layer to have `weight` and `variance_epsilon`
member variables and uses them in the `forward` method.
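For illustration, a minimal sketch of such an adopting layer, with a reference
`forward` implementation (the constructor and the reference math here are
illustrative, not part of this repository):

```python
import torch
from torch import nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # The member variables that the Hub layer's `forward` expects.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
```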
### Exporting layers
To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:
```python
from . import layers
__all__ = [
# ...
"layers"
# ...
]
```
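The `layers` module itself could be as simple as the following sketch (a
hypothetical `layers.py`; a real kernel would typically dispatch to its native
ops rather than plain PyTorch):

```python
import torch
import torch.nn.functional as F
from torch import nn

class SiluAndMul(nn.Module):
    # A pure layer: no constructor, no class variables, only `forward`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

__all__ = ["SiluAndMul"]
```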
## Python requirements
- Python code must be compatible with Python 3.9 and later.

docs/layers.md (new file, 79 lines)

@@ -0,0 +1,79 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information on the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
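A decorated layer still works when no mapping is registered; the original
`forward` is then used as a fallback. A minimal usage sketch, assuming the
decorated class above:

```python
import torch

layer = SiluAndMul()
x = torch.randn(32, 128)
# With no kernel mapping registered (or on an unmapped device type),
# this call falls back to the original `forward` defined above.
y = layer(x)
```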
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by monkeypatching it with the
`replace_kernel_forward_from_hub` function:
```python
from kernels import register_kernel_mapping
from kernels.layer import replace_kernel_forward_from_hub
from somelibrary import SiluAndMul

replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")

# `kernel_layer_mapping` is defined as in the next section.
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
```python
from kernels import LayerRepository, register_kernel_mapping

kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to the code where it
is used, using the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
...
```
This ensures that the mapping is no longer active outside the `with` block.
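Mappings stack: registering inside the context extends whatever was registered
before, and the previous state is restored on exit. A sketch, assuming
`kernel_layer_mapping` from above (note that in this version
`use_kernel_mapping` is imported from `kernels.layer`):

```python
from kernels.layer import use_kernel_mapping

with use_kernel_mapping(kernel_layer_mapping):
    # The mapping is active here, on top of any previously
    # registered mappings.
    ...
# The previous mapping state is restored here.
```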

flake.lock (generated, new file, 134 lines)

@@ -0,0 +1,134 @@
{
"nodes": {
"flake-compat": {
"locked": {
"lastModified": 1733328505,
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_2": {
"inputs": {
"systems": "systems_2"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": [
"tgi-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_2": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

flake.nix (new file, 54 lines)

@@ -0,0 +1,54 @@
{
inputs = {
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
{
self,
nixpkgs,
flake-utils,
tgi-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
overlays = [
tgi-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = mkShell {
buildInputs =
[
black
mypy
pyright
ruff
]
++ (with python3.pkgs; [
huggingface-hub
pytest
pytest-benchmark
torch
venvShellHook
]);
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( python -m pip install --no-build-isolation --no-dependencies -e . )
'';
};
};
}
);
}

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.2.1"
version = "0.3.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },

src/kernels/__init__.py

@@ -1,3 +1,23 @@
-from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
+from kernels.layer import (
+    Device,
+    LayerRepository,
+    register_kernel_mapping,
+    use_kernel_forward_from_hub,
+)
+from kernels.utils import (
+    get_kernel,
+    get_locked_kernel,
+    install_kernel,
+    load_kernel,
+)
-__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
+__all__ = [
+    "get_kernel",
+    "get_locked_kernel",
+    "load_kernel",
+    "install_kernel",
+    "use_kernel_forward_from_hub",
+    "register_kernel_mapping",
+    "LayerRepository",
+    "Device",
+]

src/kernels/layer.py (new file, 231 lines)

@@ -0,0 +1,231 @@
import inspect
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Dict, Union
from .utils import get_kernel
if TYPE_CHECKING:
from torch import nn
@dataclass(frozen=True)
class Device:
type: str
# In the future we might add compute capabilities, etc.
def __eq__(self, other):
return isinstance(other, Device) and self.type == other.type
def __hash__(self):
return hash(self.type)
@dataclass
class LayerRepository:
"""
Repository and name of a layer.
"""
layer_name: str = field(
metadata={"help": "The name of the layer in the kernel repository."}
)
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
revision: str = field(
default="main", metadata={"help": "The revision of the layer."}
)
def __eq__(self, other):
return (
isinstance(other, LayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
and self.revision == other.revision
)
def __hash__(self):
return hash((self.layer_name, self.repo_id, self.revision))
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
)
def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
class ContextManager:
def __enter__(self):
# Mappings always stack on previous mappings.
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
register_kernel_mapping(mapping)
def __exit__(self, exc_type, exc_value, traceback):
_KERNEL_MAPPING.reset(self.token)
return ContextManager()
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
):
"""
    Allows one to register a mapping between a layer name and the corresponding kernel to use, depending on the device.
    This should be used in conjunction with the `use_kernel_forward_from_hub` decorator on the class.
    Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
kernel_layer_mapping = {
"LlamaRMSNorm": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="RmsNorm",
revision="layers",
),
},
}
register_kernel_mapping(kernel_layer_mapping)
```
"""
# Merge with existing mappings.
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
else:
device_repo[new_device] = new_repo
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
"""
fallback_forward = cls.forward
cached_forward: Dict[LayerRepository, Callable] = {}
def forward(self, x, **args):
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, **args)
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, **args)
repo = kernel.get(Device(type=device.type))
if repo is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
)
return fallback_forward(self, x, **args)
# Short-circuit if we already loaded the layer.
layer_forward = cached_forward.get(repo, None)
if layer_forward is not None:
return layer_forward(self, x, **args)
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
# We have to validate against the original signature.
orig_forward = cls.forward
try:
cls.forward = fallback_forward
_validate_layer(check_cls=cls, cls=layer)
finally:
cls.forward = orig_forward
layer_forward = layer.forward
cached_forward[repo] = layer_forward
return layer_forward(self, x, **args)
cls.forward = forward
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
    `register_kernel_mapping`. The device type is inferred from the first
argument to `forward`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
return cls
return decorator
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
if getattr(kernel, "layers", None) is None:
raise ValueError(
f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
)
layer = getattr(kernel.layers, layer_name, None)
if layer is None:
raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
return layer
def _validate_layer(*, check_cls, cls):
    # The layer must at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
from torch import nn
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
    # We verify statelessness by checking that the class does not have its own
# constructor (since the constructor could add member variables)...
if cls.__init__ is not nn.Module.__init__:
raise TypeError("Layer must not override nn.Module constructor.")
# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
if cls_members - torch_module_members != set():
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
params = inspect.signature(cls.forward).parameters
ref_params = inspect.signature(check_cls.forward).parameters
if len(params) != len(ref_params):
raise TypeError(
"Forward signature does not match: different number of arguments."
)
for param, ref_param in zip(params.values(), ref_params.values()):
if param.kind != ref_param.kind:
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)
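Taken together, a minimal end-to-end sketch of the API added in this file (the
repository, revision, and device below are illustrative):

```python
import torch
from torch import nn
from torch.nn import functional as F

from kernels import LayerRepository, use_kernel_forward_from_hub
from kernels.layer import use_kernel_mapping

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]

mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    }
}

with use_kernel_mapping(mapping):
    layer = SiluAndMul()
    # On CUDA this dispatches to the Hub kernel; elsewhere it
    # falls back to the `forward` defined above.
    y = layer(torch.randn(32, 128, device="cuda"))
```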

tests/test_layer.py (new file, 168 lines)

@@ -0,0 +1,168 @@
import pytest
import torch
import torch.nn as nn
from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
kernel_layer_mapping = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
}
register_kernel_mapping(kernel_layer_mapping)
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
        # Used to check that we called the hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMulStringDevice")
class SiluAndMulStringDevice(SiluAndMul):
pass
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
torch.random.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
if device == "cuda":
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
def test_layer_fallback_works():
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
class SiluAndMulWithKernelFallback(SiluAndMul):
pass
# Check that we don't raise an exception for a non-existing kernel.
SiluAndMulWithKernelFallback()
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
extra_mapping1 = {
"TestKernel": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
with use_kernel_mapping(extra_mapping1):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"TestKernel",
}
extra_mapping2 = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/non-existing",
layer_name="SiluAndMul",
revision="layers",
)
}
}
with use_kernel_mapping(extra_mapping2):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/activation"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
}
def test_validate_kernel_layer():
class BadLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.foo = 42
with pytest.raises(TypeError, match="not override"):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
class BadLayer2(nn.Module):
foo: int = 42
with pytest.raises(TypeError, match="not contain additional members"):
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
class BadLayer3(nn.Module):
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different number of arguments"):
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
class BadLayer4(nn.Module):
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different kind of arguments"):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)