Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-22 05:48:52 +08:00)

Compare commits: test-torch...add-compat (1 commit)

Commit: cc66a53025
README.md (12 changed lines)

@@ -1,16 +1,5 @@
# kernels

<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
<p align="center">
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>

</p>
</div>

<hr/>

The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel

@@ -58,7 +47,6 @@ the Hub.

- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
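A minimal sketch of the dynamic loading described above, assuming the `kernels-community/activation` repository and the `gelu_fast` kernel that appear in the tests further down, plus a CUDA-capable Torch build:

```python
import torch

from kernels import get_kernel

# Downloads (and caches) a build variant matching the local Torch/CUDA environment.
activation = get_kernel("kernels-community/activation")

x = torch.arange(1, 10, dtype=torch.float16, device="cuda").view(3, 3)
y = torch.empty_like(x)

# The returned module exposes the kernel's functions; gelu_fast writes into `y`.
activation.gelu_fast(y, x)
print(y)
```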

docs/env.md (file deleted; 10 lines removed)

@@ -1,10 +0,0 @@
# Environment variables

## `KERNELS_CACHE`

The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.

## `DISABLE_KERNEL_MAPPING`

Disables kernel mappings for [`layers`](layers.md).
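A small sketch of how these variables are consumed, based on the module-level lookups in `kernels/layer.py` and `kernels/utils.py` shown later in this diff; the cache path below is only illustrative:

```python
import os

# Use a project-local cache instead of the default huggingface_hub cache
# (illustrative path, not a project convention).
os.environ["KERNELS_CACHE"] = "/tmp/kernels-cache"

# Any non-zero value disables layer kernel mappings, so mapped layers keep
# their original forward().
os.environ["DISABLE_KERNEL_MAPPING"] = "1"

# Both variables are read at import time, so set them before importing kernels.
import kernels  # noqa: E402
```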

docs/kernel-requirements.md

@@ -1,11 +1,8 @@
# Kernel requirements

Kernels on the Hub must fulfill the requirements outlined on this page. By
ensuring kernels are compliant, they can be used on a wide range of Linux
systems and Torch builds.

Kernels on the Hub must fulfill the requirements outlined on this page.
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build compliant kernels.
to build conforming kernels.

## Directory layout

@@ -13,21 +10,34 @@ A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`.
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
recommended build variants are:

Each variant directory must contain a single directory with the same name
- `torch25-cxx11-cu118-x86_64-linux`
- `torch25-cxx11-cu121-x86_64-linux`
- `torch25-cxx11-cu124-x86_64-linux`
- `torch25-cxx98-cu118-x86_64-linux`
- `torch25-cxx98-cu121-x86_64-linux`
- `torch25-cxx98-cu124-x86_64-linux`
- `torch26-cxx11-cu118-x86_64-linux`
- `torch26-cxx11-cu124-x86_64-linux`
- `torch26-cxx11-cu126-x86_64-linux`
- `torch26-cxx98-cu118-x86_64-linux`
- `torch26-cxx98-cu124-x86_64-linux`
- `torch26-cxx98-cu126-x86_64-linux`

This list will be updated as new PyTorch versions are released. Kernels
that are in pure Python (e.g. Triton kernels) only need to provide a
single build variant:

- `torch-universal`

Each variant directory should contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have directories like
`build/<variant>/activation`. This directory
must be a Python package with an `__init__.py` file.
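To make the layout concrete, a small sketch of how a loader could resolve the `__init__.py` of one build variant; the helper below is illustrative (the real resolution lives in `kernels/utils.py` further down), and the variant string is one of the recommended ones above:

```python
from pathlib import Path


def variant_init_path(repo_root: Path, repo_id: str, variant: str) -> Path:
    """Illustrative helper: path of the package entry point for one build variant."""
    # The package directory is the repository name with '-' replaced by '_'.
    package_name = repo_id.split("/")[-1].replace("-", "_")
    return repo_root / "build" / variant / package_name / "__init__.py"


# Prints: build/torch26-cxx98-cu118-x86_64-linux/activation/__init__.py
print(
    variant_init_path(
        Path("."), "kernels-community/activation", "torch26-cxx98-cu118-x86_64-linux"
    )
)
```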

## Build variants

A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
must be available for that combination.

## Versioning

Kernels are versioned on the Hub using Git tags. Version tags must be of

@@ -109,17 +119,10 @@ requirements:
- The `forward` method has a signature that is compatible with the
  `forward` method that it is extending.

The only exception to the _no class variables rule_ is addition of a
`has_backward` class variable. This variable is used to indicate whether
the layer has a backward pass implemented (assumed `True` when absent).

This is an example of a pure layer:

```python
class SiluAndMul(nn.Module):
    # This layer does not implement backward.
    has_backward: bool = False

    def forward(self, x: torch.Tensor):
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)

docs/locking.md

@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
"kernels-community/activation" = ">=0.0.1"
```

Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
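A minimal sketch of the loading step referenced above, assuming the `kernels-community/activation` dependency from the `pyproject.toml` snippet; `get_locked_kernel` is exported from `kernels/__init__.py` as shown later in this diff:

```python
from kernels import get_locked_kernel

# Loads the revision pinned in kernels.lock instead of whatever "main" points to.
activation = get_locked_kernel("kernels-community/activation")
```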

@@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.

## Pre-downloading locked kernels

Locked kernels can be pre-downloaded by running `kernels download .` in your
Locked kernels can be pre-downloaded by running `kernel download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.
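A hedged sketch of consuming a pre-downloaded, locked kernel, based on the `load_kernel` signature and docstring visible in `kernels/utils.py` below:

```python
from kernels import load_kernel

# Expects the locked kernel to already be present in the local Hub cache
# (for example after running `kernels download .`).
activation = load_kernel("kernels-community/activation")
```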

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.4.4"
version = "0.4.2"
description = "Download compute kernels"
authors = [
    { name = "OlivierDehaene", email = "olivier@huggingface.co" },

kernels/__init__.py

@@ -9,7 +9,6 @@ from kernels.layer import (
from kernels.utils import (
    get_kernel,
    get_locked_kernel,
    has_kernel,
    install_kernel,
    load_kernel,
)

@@ -17,7 +16,6 @@ from kernels.utils import (
__all__ = [
    "get_kernel",
    "get_locked_kernel",
    "has_kernel",
    "load_kernel",
    "install_kernel",
    "use_kernel_forward_from_hub",

kernels/cli.py

@@ -4,6 +4,8 @@ import json
import sys
from pathlib import Path

from huggingface_hub import hf_hub_download

from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants

@@ -36,6 +38,23 @@ def main():
    )
    lock_parser.set_defaults(func=lock_kernels)

    # Add a new compatibility command
    compat_parser = subparsers.add_parser(
        "compatibility", help="Show kernel build compatibility"
    )
    compat_parser.add_argument(
        "repo_id",
        type=str,
        help="The repository ID of the kernel (e.g., 'kernels-community/activation')",
    )
    compat_parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="The revision of the kernel (default: main)",
    )
    compat_parser.set_defaults(func=check_compatibility)

    args = parser.parse_args()
    args.func(args)

@@ -91,6 +110,53 @@ def lock_kernels(args):
        json.dump(all_locks, f, cls=_JSONEncoder, indent=2)


def check_compatibility(args):
    """Check build compatibility for a kernel by reading its build.toml file."""
    try:
        # Download only the build.toml file from the repository
        build_toml_path = hf_hub_download(
            repo_id=args.repo_id,
            filename="build.toml",
            revision=args.revision,
        )
    except Exception:
        print(
            f"Error: Could not find build.toml in repository {args.repo_id}.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Parse the build.toml file
    try:
        with open(build_toml_path, "rb") as f:
            content = f.read().decode("utf-8")

        # Simple check for compatibility without full parsing
        is_universal = "language" in content and "python" in content
        has_cuda = "cuda-capabilities" in content
        has_rocm = "rocm-archs" in content

    except Exception as e:
        print(f"Error reading build.toml: {str(e)}", file=sys.stderr)
        sys.exit(1)

    # Print the compatibility
    print(f"Kernel: {args.repo_id}")
    print("Compatibility: ", end="")

    if is_universal:
        print("universal")
    else:
        compatibilities = []
        if has_cuda:
            compatibilities.append("cuda")
        if has_rocm:
            compatibilities.append("rocm")
        print(", ".join(compatibilities) if compatibilities else "unknown")

    return 0


class _JSONEncoder(json.JSONEncoder):
    def default(self, o):
        if dataclasses.is_dataclass(o):
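For reference, a hedged usage sketch for the `compatibility` subcommand added above, assuming the package installs a `kernels` console script (the docs in this diff already invoke `kernels lock .` and `kernels download .`); the repository is the one from the argument help text:

```python
import subprocess

# Equivalent to: kernels compatibility kernels-community/activation --revision main
# Output has the form "Kernel: <repo_id>" followed by "Compatibility: " and one of
# "universal", a comma-separated list of "cuda"/"rocm", or "unknown",
# per check_compatibility above.
subprocess.run(
    ["kernels", "compatibility", "kernels-community/activation", "--revision", "main"],
    check=True,
)
```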

kernels/layer.py

@@ -1,18 +1,15 @@
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union
from typing import TYPE_CHECKING, Callable, Dict, Union

from .utils import get_kernel

if TYPE_CHECKING:
    from torch import nn

_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))


@dataclass(frozen=True)
class Device:

@@ -131,13 +128,9 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool

    fallback_forward = cls.forward

    cached_layer: Dict[LayerRepository, nn.Module] = {}
    cached_forward: Dict[LayerRepository, Callable] = {}

    def forward(self, x, *args, **kwargs):
        if _DISABLE_KERNEL_MAPPING:
            return fallback_forward(self, x, *args, **kwargs)

        needs_backward = self.training
        kernel = _KERNEL_MAPPING.get().get(layer_name)
        if kernel is None:
            warnings.warn(

@@ -163,11 +156,9 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
            return fallback_forward(self, x, *args, **kwargs)

        # Short-circuit if we already loaded the layer.
        layer = cached_layer.get(repo, None)
        if layer is not None:
            if needs_backward and not getattr(layer, "has_backward", True):
                return fallback_forward(self, x, *args, **kwargs)
            return layer.forward(self, x, *args, **kwargs)
        layer_forward = cached_forward.get(repo, None)
        if layer_forward is not None:
            return layer_forward(self, x, *args, **kwargs)

        layer = _get_kernel_layer(
            repo_id=repo.repo_id,

@@ -183,11 +174,10 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
        finally:
            cls.forward = orig_forward

        cached_layer[repo] = layer
        layer_forward = layer.forward
        cached_forward[repo] = layer_forward

        if needs_backward and not getattr(layer, "has_backward", True):
            return fallback_forward(self, x, *args, **kwargs)
        return layer.forward(self, x, *args, **kwargs)
        return layer_forward(self, x, *args, **kwargs)

    cls.forward = forward

@@ -244,8 +234,7 @@ def _validate_layer(*, check_cls, cls):
    # ... or predefined member variables.
    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
    cls_members = {name for name, _ in inspect.getmembers(cls)}
    difference = cls_members - torch_module_members
    if difference != set() and difference != {"has_backward"}:
    if cls_members - torch_module_members != set():
        raise TypeError("Layer must not contain additional members.")

    # Check whether the forward signatures are similar.

kernels/utils.py

@@ -4,7 +4,6 @@ import importlib
import importlib.metadata
import inspect
import json
import logging
import os
import platform
import sys

@@ -13,25 +12,12 @@ from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple

from huggingface_hub import file_exists, snapshot_download
from huggingface_hub import snapshot_download
from packaging.version import parse

from kernels.lockfile import KernelLock, VariantLock


def _get_cache_dir() -> Optional[str]:
    """Returns the kernels cache directory."""
    cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
    if cache_dir is not None:
        logging.warning(
            "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
        )
        return cache_dir

    return os.environ.get("KERNELS_CACHE", None)


CACHE_DIR: Optional[str] = _get_cache_dir()
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)


def build_variant() -> str:

@@ -161,29 +147,6 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
    return import_from_path(package_name, package_path / package_name / "__init__.py")


def has_kernel(repo_id: str, revision: str = "main") -> bool:
    """
    Check whether a kernel build exists for the current environment
    (Torch version and compute framework).
    """
    package_name = package_name_from_repo_id(repo_id)
    variant = build_variant()
    universal_variant = universal_build_variant()

    if file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{universal_variant}/{package_name}/__init__.py",
    ):
        return True

    return file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{variant}/{package_name}/__init__.py",
    )


def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
    """
    Get a pre-downloaded, locked kernel.

@@ -1,7 +1,7 @@
import pytest
import torch

from kernels import get_kernel, has_kernel
from kernels import get_kernel


@pytest.fixture

@@ -21,16 +21,11 @@ def device():
    return "cuda"


@pytest.mark.parametrize("torch_compile", [False, True])
def test_gelu_fast(kernel, device, torch_compile):
def test_gelu_fast(kernel, device):
    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
    y = torch.empty_like(x)

    op = kernel.gelu_fast
    if torch_compile:
        op = torch.compile(op)

    op(y, x)
    kernel.gelu_fast(y, x)

    expected = torch.tensor(
        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],

@@ -41,22 +36,6 @@ def test_gelu_fast(kernel, device, torch_compile):
    assert torch.allclose(y, expected)


@pytest.mark.parametrize(
    "kernel_exists",
    [
        ("kernels-community/activation", "main", True),
        ("kernels-community/triton-layer-norm", "main", True),
        # Repo only contains Torch 2.4 kernels (and we don't
        # support/test against this version).
        ("kernels-test/only-torch-2.4", "main", False),
        ("google-bert/bert-base-uncased", "87565a309", False),
    ],
)
def test_has_kernel(kernel_exists):
    repo_id, revision, kernel = kernel_exists
    assert has_kernel(repo_id, revision=revision) == kernel


def test_universal_kernel(universal_kernel):
    torch.manual_seed(0)
    A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

@@ -73,8 +73,7 @@ def test_arg_kinds():

@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("torch_compile", [False, True])
def test_hub_forward(cls, device, torch_compile):
def test_hub_forward(cls, device):
    torch.random.manual_seed(0)

    silu_and_mul = SiluAndMul()

@@ -82,8 +81,6 @@ def test_hub_forward(cls, device, torch_compile):
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = cls()
    if torch_compile:
        silu_and_mul_with_kernel = torch.compile(silu_and_mul_with_kernel)
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

@@ -206,75 +203,3 @@ def test_validate_kernel_layer():

    with pytest.raises(TypeError, match="different kind of arguments"):
        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


def test_fallback_used_when_training():
    @use_kernel_forward_from_hub("Linear")
    class TorchLinear(nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Used to check that we called hub kernel.
            self.n_calls = 0

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.n_calls += 1
            return super().forward(input)

    linear = TorchLinear(32, 32).to("cuda")

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearImplicitBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearNoBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 1

        linear.eval()
        linear(X)
        assert linear.n_calls == 1