mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 05:30:30 +08:00
Compare commits
5 Commits
v0.4.3
...
fix-comman
Author | SHA1 | Date | |
---|---|---|---|
03a8662f7f | |||
cf530c283a | |||
437f910336 | |||
6f1a6067c8 | |||
1d14abcef0 |
11
README.md
11
README.md
@ -1,5 +1,16 @@
|
||||
# kernels
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
||||
|
||||
</p>
|
||||
</div>
|
||||
<hr/>
|
||||
|
||||
The Kernel Hub allows Python libraries and applications to load compute
|
||||
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||
|
@ -119,10 +119,17 @@ requirements:
|
||||
- The `forward` method has a signature that is compatible with the
|
||||
`forward` method that it is extending.
|
||||
|
||||
The only exception to the _no class variables rule_ is addition of a
|
||||
`has_backward` class variable. This variable is used to indicate whether
|
||||
the layer has a backward pass implemented (`True` when absent).
|
||||
|
||||
This is an example of a pure layer:
|
||||
|
||||
```python
|
||||
class SiluAndMul(nn.Module):
|
||||
# This layer does not implement backward.
|
||||
has_backward: bool = False
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
|
@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
|
||||
"kernels-community/activation" = ">=0.0.1"
|
||||
```
|
||||
|
||||
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
the locked revisions. The locked revision will be used when loading a kernel with
|
||||
`get_locked_kernel`:
|
||||
|
||||
@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `kernel download .` in your
|
||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
||||
project directory. This will download the kernels to your local Hugging Face
|
||||
Hub cache.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.4.3"
|
||||
version = "0.4.4"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
|
@ -9,6 +9,7 @@ from kernels.layer import (
|
||||
from kernels.utils import (
|
||||
get_kernel,
|
||||
get_locked_kernel,
|
||||
has_kernel,
|
||||
install_kernel,
|
||||
load_kernel,
|
||||
)
|
||||
@ -16,6 +17,7 @@ from kernels.utils import (
|
||||
__all__ = [
|
||||
"get_kernel",
|
||||
"get_locked_kernel",
|
||||
"has_kernel",
|
||||
"load_kernel",
|
||||
"install_kernel",
|
||||
"use_kernel_forward_from_hub",
|
||||
|
@ -4,7 +4,7 @@ import warnings
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Union
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
from .utils import get_kernel
|
||||
|
||||
@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
|
||||
fallback_forward = cls.forward
|
||||
|
||||
cached_forward: Dict[LayerRepository, Callable] = {}
|
||||
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if _DISABLE_KERNEL_MAPPING:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
needs_backward = self.training
|
||||
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||
if kernel is None:
|
||||
warnings.warn(
|
||||
@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
# Short-circuit if we already loaded the layer.
|
||||
layer_forward = cached_forward.get(repo, None)
|
||||
if layer_forward is not None:
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
layer = cached_layer.get(repo, None)
|
||||
if layer is not None:
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
finally:
|
||||
cls.forward = orig_forward
|
||||
|
||||
layer_forward = layer.forward
|
||||
cached_forward[repo] = layer_forward
|
||||
cached_layer[repo] = layer
|
||||
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
cls.forward = forward
|
||||
|
||||
@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
|
||||
# ... or predefined member variables.
|
||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||
if cls_members - torch_module_members != set():
|
||||
difference = cls_members - torch_module_members
|
||||
if difference != set() and difference != {"has_backward"}:
|
||||
raise TypeError("Layer must not contain additional members.")
|
||||
|
||||
# Check whether the forward signatures are similar.
|
||||
|
@ -13,7 +13,7 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
@ -161,6 +161,29 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
(Torch version and compute framework).
|
||||
"""
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
if file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{universal_variant}/{package_name}/__init__.py",
|
||||
):
|
||||
return True
|
||||
|
||||
return file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{variant}/{package_name}/__init__.py",
|
||||
)
|
||||
|
||||
|
||||
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
"""
|
||||
Get a pre-downloaded, locked kernel.
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel
|
||||
from kernels import get_kernel, has_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_exists",
|
||||
[
|
||||
("kernels-community/activation", "main", True),
|
||||
("kernels-community/triton-layer-norm", "main", True),
|
||||
# Repo only contains Torch 2.4 kernels (and we don't
|
||||
# support/test against this version).
|
||||
("kernels-test/only-torch-2.4", "main", False),
|
||||
("google-bert/bert-base-uncased", "87565a309", False),
|
||||
],
|
||||
)
|
||||
def test_has_kernel(kernel_exists):
|
||||
repo_id, revision, kernel = kernel_exists
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||
|
@ -203,3 +203,75 @@ def test_validate_kernel_layer():
|
||||
|
||||
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||
|
||||
|
||||
def test_fallback_used_when_training():
|
||||
@use_kernel_forward_from_hub("Linear")
|
||||
class TorchLinear(nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Used to check that we called hub kernel.
|
||||
self.n_calls = 0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
self.n_calls += 1
|
||||
return super().forward(input)
|
||||
|
||||
linear = TorchLinear(32, 32).to("cuda")
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearImplicitBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearNoBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
||||
|
Reference in New Issue
Block a user