Compare commits


1 commit

f994749cdc  Add use_hub_kernel decorator
This decorator replaces a layer with a layer from the Kernel Hub.
2025-03-11 13:37:44 +00:00
6 changed files with 176 additions and 2 deletions


@@ -45,6 +45,7 @@ the Hub.
## 📚 Documentation
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)


@@ -76,6 +76,26 @@ might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
## Layers
Kernels can provide a `layers` attribute containing a Python module with
layers. Such layers can be used directly by downstream users or through
the `use_hub_kernel` decorator. For Torch, layers must be a subclass of
[`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:
```python
from . import layers

__all__ = [
    # ...
    "layers",
    # ...
]
```
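For illustration, the `layers` module of such a kernel might look like the
following minimal sketch (the pure-PyTorch `SiluAndMul` body is only a
stand-in; a real kernel layer would typically dispatch to the kernel's
compiled ops):
```python
import torch
from torch import nn
from torch.nn import functional as F


class SiluAndMul(nn.Module):
    # Stand-in forward pass; a real kernel layer would call into the
    # kernel's own compiled ops here.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```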
## Python requirements
- Python code must be compatible with Python 3.9 and later.

docs/layers.md (new file)

@@ -0,0 +1,39 @@
# Layers
A kernel can provide layers in addition to kernel functions. For Torch
kernels, layers are subclasses of [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
## Getting a layer
Layers are exposed through the `layers` attribute of a kernel,
so they can be used directly after loading a kernel with `get_kernel`.
For example:
```python
import kernels
import torch

activation = kernels.get_kernel("kernels-community/activation", revision="layers")

layer = activation.layers.SiluAndMul()
out = layer(torch.randn((64, 64), device="cuda:0"))
```
## Using a kernel layer as a replacement for an existing layer
An existing layer in a library can be Kernel Hub-enabled with the
`use_hub_kernel` decorator. The decorator replaces the existing layer
when the kernel layer can be loaded successfully; otherwise the original
layer is kept (see the fallback sketch after the example below).
For example:
```python
import torch
from torch import nn
from torch.nn import functional as F

from kernels import use_hub_kernel


@use_hub_kernel(
    "kernels-community/activation",
    layer_name="SiluAndMul",
    revision="layers",
)
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```
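By default, the original layer is kept as a fallback when the kernel layer
cannot be loaded. To fail loudly instead, pass `fallback_on_error=False`.
A minimal sketch, assuming the same `kernels-community/activation` kernel
as above:
```python
import torch
from torch import nn
from torch.nn import functional as F

from kernels import use_hub_kernel


# With fallback_on_error=False, a loading failure raises an exception
# instead of silently keeping the pure-PyTorch definition below.
@use_hub_kernel(
    "kernels-community/activation",
    layer_name="SiluAndMul",
    revision="layers",
    fallback_on_error=False,
)
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```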


@@ -1,3 +1,15 @@
-from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
+from kernels.decorators import use_hub_kernel
+from kernels.utils import (
+    get_kernel,
+    get_locked_kernel,
+    install_kernel,
+    load_kernel,
+)
 
-__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
+__all__ = [
+    "get_kernel",
+    "get_locked_kernel",
+    "load_kernel",
+    "install_kernel",
+    "use_hub_kernel",
+]

src/kernels/decorators.py (new file)

@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING, Type

from .utils import get_kernel

if TYPE_CHECKING:
    from torch import nn


def use_hub_kernel(
    repo_id: str,
    *,
    layer_name: str,
    revision: str = "main",
    fallback_on_error: bool = True,
):
    """
    Replace a layer with a layer from the Kernel Hub.

    When `fallback_on_error` is True, the original layer will be used if
    the kernel's layer could not be loaded.
    """

    def decorator(cls):
        try:
            return _get_kernel_layer(
                repo_id=repo_id, layer_name=layer_name, revision=revision
            )
        except Exception as e:
            if not fallback_on_error:
                raise e
            return cls

    return decorator


def _get_kernel_layer(
    *, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
    """Get a layer from a kernel."""
    # Imported here so that torch is only required when a kernel layer is
    # actually resolved.
    from torch import nn

    kernel = get_kernel(repo_id, revision=revision)

    if getattr(kernel, "layers", None) is None:
        raise ValueError(
            f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
        )

    layer = getattr(kernel.layers, layer_name, None)
    if layer is None:
        raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
    if not issubclass(layer, nn.Module):
        raise TypeError(f"Layer `{layer_name}` is not a Torch layer.")

    return layer

tests/test_layer.py (new file)

@@ -0,0 +1,48 @@
import re

import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import use_hub_kernel


class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


def test_activation_layer():
    @use_hub_kernel(
        "kernels-community/activation",
        layer_name="SiluAndMul",
        revision="layers",
        fallback_on_error=False,
    )
    class SiluAndMulWithKernel(SiluAndMul):
        pass

    torch.random.manual_seed(0)
    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device="cuda")
    Y = silu_and_mul(X)

    # Verify that the Hub kernel was loaded.
    assert SiluAndMulWithKernel.__name__ == "SiluAndMul"
    assert re.match(r"activation.*layers", SiluAndMulWithKernel.__module__)

    silu_and_mul_with_kernel = SiluAndMulWithKernel()
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)


def test_layer_fallback_works():
    @use_hub_kernel("kernels-community/non-existing", layer_name="SiluAndMul")
    class SiluAndMulWithKernelFallback(SiluAndMul):
        pass

    # Check that we don't raise an exception for a non-existing kernel.
    SiluAndMulWithKernelFallback()