Compare commits


5 Commits

Author SHA1 Message Date
03a8662f7f locking docs: fix command name (kernel -> kernels) 2025-04-14 13:19:24 +02:00
cf530c283a Set version to 0.4.4 (#73) 2025-04-11 10:23:26 +02:00
437f910336 Add has_kernel function (#69)
* Add `has_kernel` function

This function checks whether a kernel build exists for the current
environment (Torch version and compute framework).

* Test kernel repo that only contains Torch 2.4
2025-04-11 10:12:37 +02:00
6f1a6067c8 feat: add logo and shields (#72) 2025-04-11 10:07:24 +02:00
1d14abcef0 Do not use kernels without backward when training (#68)
* Do not use kernels without backward when training

* Update repo for backwards marker test
2025-04-11 10:05:57 +02:00
9 changed files with 150 additions and 14 deletions

View File

@@ -1,5 +1,16 @@
# kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
<p align="center">
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
</p>
</div>
<hr/>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel

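To make the dynamic loading described in the README concrete, here is a minimal usage sketch of pulling a kernel from the Hub at runtime. It assumes a CUDA device and uses the `kernels-community/activation` repository that appears elsewhere in this diff; the `gelu_fast(out, x)` call signature is an assumption for illustration.

```python
import torch

from kernels import get_kernel

# Downloads (or reuses from the local Hub cache) a build that matches the
# installed Torch version and compute framework.
activation = get_kernel("kernels-community/activation")

x = torch.randn(4, 16, device="cuda")
y = torch.empty_like(x)
# Assumed op: writes the GELU activation of `x` into `y`.
activation.gelu_fast(y, x)
```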
View File

@@ -119,10 +119,17 @@ requirements:
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
The only exception to the _no class variables rule_ is the addition of a
`has_backward` class variable. This variable indicates whether
the layer implements a backward pass (assumed `True` when absent).
This is an example of a pure layer:
```python
class SiluAndMul(nn.Module):
    # This layer does not implement backward.
    has_backward: bool = False

    def forward(self, x: torch.Tensor):
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)

View File

@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
"kernels-community/activation" = ">=0.0.1"
```
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
@@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernel download .` in your
Locked kernels can be pre-downloaded by running `kernels download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.
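
As a sketch of how the locked workflow above is consumed from Python (assuming `kernels.lock` has been generated and registered with `kernels` as described):

```python
from kernels import get_locked_kernel

# Loads the kernel at the revision pinned in kernels.lock rather than
# whatever `main` currently points to.
activation = get_locked_kernel("kernels-community/activation")
```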

View File

@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.4.3"
version = "0.4.4"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },

View File

@@ -9,6 +9,7 @@ from kernels.layer import (
from kernels.utils import (
    get_kernel,
    get_locked_kernel,
    has_kernel,
    install_kernel,
    load_kernel,
)
@@ -16,6 +17,7 @@ from kernels.utils import (
__all__ = [
    "get_kernel",
    "get_locked_kernel",
    "has_kernel",
    "load_kernel",
    "install_kernel",
    "use_kernel_forward_from_hub",

View File

@@ -4,7 +4,7 @@ import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Dict, Union
from typing import TYPE_CHECKING, Dict, Union
from .utils import get_kernel
@@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
    fallback_forward = cls.forward

    cached_forward: Dict[LayerRepository, Callable] = {}
    cached_layer: Dict[LayerRepository, nn.Module] = {}

    def forward(self, x, *args, **kwargs):
        if _DISABLE_KERNEL_MAPPING:
            return fallback_forward(self, x, *args, **kwargs)

        needs_backward = self.training

        kernel = _KERNEL_MAPPING.get().get(layer_name)
        if kernel is None:
            warnings.warn(
@@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
            return fallback_forward(self, x, *args, **kwargs)

        # Short-circuit if we already loaded the layer.
        layer_forward = cached_forward.get(repo, None)
        if layer_forward is not None:
            return layer_forward(self, x, *args, **kwargs)
        layer = cached_layer.get(repo, None)
        if layer is not None:
            if needs_backward and not getattr(layer, "has_backward", True):
                return fallback_forward(self, x, *args, **kwargs)
            return layer.forward(self, x, *args, **kwargs)

        layer = _get_kernel_layer(
            repo_id=repo.repo_id,
@@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
        finally:
            cls.forward = orig_forward

        layer_forward = layer.forward
        cached_forward[repo] = layer_forward
        cached_layer[repo] = layer

        return layer_forward(self, x, *args, **kwargs)
        if needs_backward and not getattr(layer, "has_backward", True):
            return fallback_forward(self, x, *args, **kwargs)
        return layer.forward(self, x, *args, **kwargs)

    cls.forward = forward
@@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
    # ... or predefined member variables.
    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
    cls_members = {name for name, _ in inspect.getmembers(cls)}

    if cls_members - torch_module_members != set():
    difference = cls_members - torch_module_members
    if difference != set() and difference != {"has_backward"}:
        raise TypeError("Layer must not contain additional members.")

    # Check whether the forward signatures are similar.

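The layer-dispatch change above is easiest to see from the caller's side. The sketch below decorates a local layer, maps it to a Hub layer for CUDA, and relies on the new behaviour that a Hub layer without `has_backward = True` is skipped in training mode. The top-level import path and the `SiluAndMul` layer name in `kernels-community/activation` are assumptions; the tests further down in this diff use the same helpers.

```python
import torch
from torch import nn

from kernels import (  # assumed top-level re-exports
    Device,
    LayerRepository,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]


mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

layer = SiluAndMul().to("cuda")
x = torch.randn(8, 64, device="cuda")

with use_kernel_mapping(mapping):
    layer.eval()
    y = layer(x)  # eval: the Hub kernel's forward is used
    layer.train()
    y = layer(x)  # train: falls back to the original forward unless the
                  # Hub layer declares `has_backward = True`
```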
View File

@@ -13,7 +13,7 @@ from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from huggingface_hub import snapshot_download
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels.lockfile import KernelLock, VariantLock
@@ -161,6 +161,29 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
    return import_from_path(package_name, package_path / package_name / "__init__.py")


def has_kernel(repo_id: str, revision: str = "main") -> bool:
    """
    Check whether a kernel build exists for the current environment
    (Torch version and compute framework).
    """
    package_name = package_name_from_repo_id(repo_id)
    variant = build_variant()
    universal_variant = universal_build_variant()

    if file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{universal_variant}/{package_name}/__init__.py",
    ):
        return True

    return file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{variant}/{package_name}/__init__.py",
    )


def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
    """
    Get a pre-downloaded, locked kernel.

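Since `has_kernel` only checks whether a matching build variant exists on the Hub, it can be used to guard `get_kernel` and fall back gracefully; a minimal sketch:

```python
from kernels import get_kernel, has_kernel

repo_id = "kernels-community/activation"

if has_kernel(repo_id):
    # A build variant exists for this Torch version and compute framework.
    activation = get_kernel(repo_id)
else:
    # Hypothetical fallback: use a plain PyTorch implementation instead.
    activation = None
```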
View File

@@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel
from kernels import get_kernel, has_kernel
@pytest.fixture
@@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
    assert torch.allclose(y, expected)


@pytest.mark.parametrize(
    "kernel_exists",
    [
        ("kernels-community/activation", "main", True),
        ("kernels-community/triton-layer-norm", "main", True),
        # Repo only contains Torch 2.4 kernels (and we don't
        # support/test against this version).
        ("kernels-test/only-torch-2.4", "main", False),
        ("google-bert/bert-base-uncased", "87565a309", False),
    ],
)
def test_has_kernel(kernel_exists):
    repo_id, revision, kernel = kernel_exists
    assert has_kernel(repo_id, revision=revision) == kernel


def test_universal_kernel(universal_kernel):
    torch.manual_seed(0)
    A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@@ -203,3 +203,75 @@ def test_validate_kernel_layer():
    with pytest.raises(TypeError, match="different kind of arguments"):
        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


def test_fallback_used_when_training():
    @use_kernel_forward_from_hub("Linear")
    class TorchLinear(nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Used to check that we called hub kernel.
            self.n_calls = 0

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.n_calls += 1
            return super().forward(input)

    linear = TorchLinear(32, 32).to("cuda")

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearImplicitBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0
        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0
        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearNoBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 1
        linear.eval()
        linear(X)
        assert linear.n_calls == 1