Compare commits

..

1 Commit

Author SHA1 Message Date
b7b5f40143 Set version to 0.7.0 2025-07-07 13:09:01 +00:00
37 changed files with 445 additions and 2706 deletions

View File

@ -1,17 +0,0 @@
name: Build documentation
on:
push:
branches:
- main
- doc-builder*
- v*-release
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@ -1,15 +0,0 @@
name: Build PR Documentation
on: pull_request
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: kernels

View File

@ -8,24 +8,3 @@ jobs:
- uses: actions/checkout@v4
- name: Run ruff
uses: astral-sh/ruff-action@v3
black:
name: Run black check
runs-on: ubuntu-latest
env:
UV_PYTHON_PREFERENCE: only-managed
steps:
- uses: actions/checkout@v4
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v5
with:
python-version: 3.12
- name: Install black
run: uv pip install black
- name: Check formatting
run: |
uv run black --check src
uv run black --check tests

View File

@ -1,16 +0,0 @@
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

View File

@ -56,13 +56,10 @@ the Hub.
## 📚 Documentation
- [Introduction](docs/source/index.md)
- [Installation](docs/source/installation.md)
- [Basic usage](docs/source/basic-usage.md)
- [Using layers](docs/source/layers.md)
- [Locking kernel/layer versions](docs/source/locking.md)
- [Environment variables](docs/source/env.md)
- [Using kernels in a Docker container](docs/source/docker.md)
- [Kernel requirements](docs/source/kernel-requirements.md)
- [Frequently Asked Questions](docs/source/faq.md)
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

8
docs/docker.md Normal file
View File

@ -0,0 +1,8 @@
# Using kernels in a Docker container
Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

View File

@ -2,9 +2,9 @@
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` method was replaced
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
The new `forward` would dispatch to a kernel based on the device type,
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.

197
docs/layers.md Normal file
View File

@ -0,0 +1,197 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information about the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
```
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, you need to add this to the mode. You can use the
set union (`|`) operator to add `TORCH_COMPILE` to the mode:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
model = kernelize(model)
```
This ensures that the mapping is not active anymore outside the
`with`-scope.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
Kernels are matched exactly on the mode. For instance, in the example above,
no kernel layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. To register a
kernel that is used when the mode does not match any of the modes in the
mapping, use the special `Mode.DEFAULT` mode. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
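As a minimal sketch (assuming the mapping above and an already-constructed `model` with an extensible `SiluAndMul` layer):
```python
with use_kernel_mapping(kernel_layer_mapping):
    # Mode.TRAINING has no exact entry in the mapping above, so the
    # Mode.DEFAULT kernel (kernels-community/activation) is used.
    model = kernelize(model, mode=Mode.TRAINING)
```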

View File

@ -1,4 +1,4 @@
# Locking kernel/layer versions
# Locking kernel versions
Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
@ -26,24 +26,6 @@ activation = get_locked_kernel("kernels-community/activation")
**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.
## Locked kernel layers
Locking is also supported for kernel layers. To use locked layers, register them
with the `LockedLayerRepository` class:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LockedLayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
register_kernel_mapping(kernel_layer_mapping)
```
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernels download .` in your

View File

@ -1,28 +0,0 @@
- sections:
- local: index
title: Introduction
- local: installation
title: Installation
title: Getting started
- sections:
- local: basic-usage
title: Basic Usage
- local: layers
title: Using Layers
- local: locking
title: Locking Kernel Versions
- local: env
title: Environment Variables
- local: faq
title: FAQ
title: Usage Guide
- sections:
- local: api/kernels
title: Kernels
- local: api/layers
title: Layers
title: API Reference
- sections:
- local: kernel-requirements
title: Kernel Requirements
title: Developer Guide

View File

@ -1,21 +0,0 @@
# Kernels API Reference
## Main Functions
### get_kernel
[[autodoc]] kernels.get_kernel
### has_kernel
[[autodoc]] kernels.has_kernel
## Loading locked kernels
### load_kernel
[[autodoc]] kernels.load_kernel
### get_locked_kernel
[[autodoc]] kernels.get_locked_kernel

View File

@ -1,41 +0,0 @@
# Layers API Reference
## Making layers kernel-aware
### use_kernel_forward_from_hub
[[autodoc]] kernels.use_kernel_forward_from_hub
### replace_kernel_forward_from_hub
[[autodoc]] kernels.replace_kernel_forward_from_hub
## Registering kernel mappings
### use_kernel_mapping
[[autodoc]] kernels.use_kernel_mapping
### register_kernel_mapping
[[autodoc]] kernels.register_kernel_mapping
## Kernelizing a model
### kernelize
[[autodoc]] kernels.kernelize
## Classes
### Device
[[autodoc]] kernels.Device
### Mode
[[autodoc]] kernels.Mode
### LayerRepository
[[autodoc]] kernels.LayerRepository

View File

@ -1,34 +0,0 @@
# Basic Usage
## Loading Kernels
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
```python
import torch
from kernels import get_kernel
# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")
# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)
print(y)
```
## Checking Kernel Availability
You can check if a specific kernel is available for your environment:
```python
from kernels import has_kernel
# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```

View File

@ -1,20 +0,0 @@
# Kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
same Python process.
- **Compatible**: kernels must support all recent versions of Python and
the different PyTorch build configurations (various CUDA versions
and C++ ABIs). Furthermore, older C library versions must be supported.
You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.

View File

@ -1,16 +0,0 @@
# Installation
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
```bash
pip install kernels
```
# Using kernels in a Docker container
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

View File

@ -1,296 +0,0 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information about the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
```
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `kernelize` function modifies the model in-place; the model itself is
returned as a convenience. The `mode` specifies that the model will be used
in inference. Similarly, you can ask `kernelize` to prepare the model for
training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
A model that is kernelized for training can also be used for inference, but
not the other way around. If you want to change the mode of the kernelized
model, you can just run `kernelize` on the model again with the new mode.
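A minimal sketch of switching modes (assuming `model` is an already-constructed model with extensible layers):
```python
model = kernelize(model, mode=Mode.TRAINING)
# ... train the model ...

# Switch the same model to inference by kernelizing it again with the new mode.
model = kernelize(model, mode=Mode.INFERENCE)
```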
If you want to compile a model with `torch.compile`, this should be indicated
in the mode as well. You can do this by combining `Mode.INFERENCE` or
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
```python
model = MyModel(...)
# Inference
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
# Training
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
### Inspecting which kernels are used
The kernels that are used are logged at the `INFO` level by `kernelize`.
See the [Python logging](https://docs.python.org/3/library/logging.html)
documentation for information on how to configure logging.
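A minimal sketch using the standard library `logging` module (assuming an already-constructed `model`):
```python
import logging

from kernels import Mode, kernelize

# Emit INFO-level log records so that kernelize reports which Hub kernels it selects.
logging.basicConfig(level=logging.INFO)
model = kernelize(model, mode=Mode.INFERENCE)
```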
## Registering a hub kernel for a layer
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
This ensures that the mapping is not active anymore outside the
`with`-scope.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
The `kernelize` function will attempt to use the following registered
kernels for a given mode:
- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
is also used when a kernel is registered without a mode, as described in the
previous section.
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.FALLBACK: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
since the other kernels do not support `torch.compile`.
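As a sketch of this resolution (assuming the mapping above is registered and `model` is already constructed):
```python
with use_kernel_mapping(kernel_layer_mapping):
    # Neither INFERENCE | TORCH_COMPILE nor TRAINING | TORCH_COMPILE is
    # registered above, so the Mode.FALLBACK kernel is selected.
    model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```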
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
you need to register the layer for a `Device` with type `cuda` and
set the supported range of CUDA capabilities using `CUDAProperties`:
```python
kernel_layer_mapping = {
"SiluAndMul": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=89
),
): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Device(
type="cuda",
properties=CUDAProperties(
min_capability=90, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-community/activation-hopper",
layer_name="SiluAndMul",
),
}
}
```
Capabilities behave as follows:
- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
with the smaller capability interval will be used. E.g. given:
- `KernelA` with `min_capability=80` and `max_capability=89`;
- `KernelB` with `min_capability=75` and `max_capability=89`;
- `kernelize` runs on a system with capability 8.6.
Then `KernelA` will be used because the interval 80..89 is smaller
than 75..89. The motivation is that kernels with smaller ranges
tend to be more optimized for a specific set of GPUs. **This behavior
might still change in the future.**
### Registering kernels for specific ROCm capabilities
Registering kernels for the ROCm architecture follows the exact same
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
a kernel to a range of ROCm capabilities.
### Loading from a local repository for testing
The `LocalLayerRepository` class is provided to load a repository from
a local directory. For example:
```python
with use_kernel_mapping(
{
"SiluAndMul": {
"cuda": LocalLayerRepository(
repo_path="/home/daniel/kernels/activation",
package_name="activation",
layer_name="SiluAndMul",
)
}
},
inherit_mapping=False,
):
kernelize(linear, mode=Mode.INFERENCE)
```

18
flake.lock generated
View File

@ -58,11 +58,11 @@
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1754038838,
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
@ -73,17 +73,17 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1752785354,
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
"owner": "nixos",
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "nixos",
"owner": "danieldk",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
"type": "github"
}
},

View File

@ -26,10 +26,6 @@
formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec {
default = mkShell {
nativeBuildInputs = [
# For hf-doc-builder.
nodejs
];
buildInputs =
[
black
@ -40,7 +36,6 @@
++ (with python3.pkgs; [
docutils
huggingface-hub
mktestdocs
pytest
pytest-benchmark
pyyaml

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.10.1"
version = "0.7.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -24,20 +24,16 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mktestdocs>=0.2.5",
"mypy>=1.15.0",
"pytest>=8",
"mypy >= 1.15.0",
"pytest >=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",
"torch>=2.5",
"torch >=2.5",
"types-pyyaml"
]
[project.optional-dependencies]
torch = ["torch"]
docs = [
"hf-doc-builder",
]
[project.scripts]
kernels = "kernels.cli:main"

View File

@ -1,5 +1,4 @@
[pytest]
markers =
cuda_only: marks tests that should only run on hosts with CUDA GPUs
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
linux_only: marks tests that should only run on Linux

View File

@ -1,13 +1,6 @@
import importlib.metadata
__version__ = importlib.metadata.version("kernels")
from kernels.layer import (
CUDAProperties,
Device,
LayerRepository,
LocalLayerRepository,
LockedLayerRepository,
Mode,
kernelize,
register_kernel_mapping,
@ -25,12 +18,8 @@ from kernels.utils import (
)
__all__ = [
"__version__",
"CUDAProperties",
"Device",
"LayerRepository",
"LocalLayerRepository",
"LockedLayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",

View File

@ -1,200 +0,0 @@
# AVL-balanced interval trees. We could use the intervaltree
# packages, but it seems unmaintained and does not have type
# annotations.
from typing import Generic, List, Optional, Tuple, TypeVar
T = TypeVar("T")
class _Node(Generic[T]):
"""A node in the interval tree."""
def __init__(self, start: int, end: int, data: T):
self.start: int = start
self.end: int = end
self.data: T = data
self.max_end: int = end
self.left: Optional["_Node[T]"] = None
self.right: Optional["_Node[T]"] = None
self.height: int = 1
def __repr__(self) -> str:
return f"Node({self.start}, {self.end})"
class IntervalTree(Generic[T]):
"""A data structure to hold and query (unique) intervals."""
root: Optional[_Node[T]]
def __init__(self):
self.root = None
def insert(self, start: int, end: int, data: T) -> None:
"""
Inserts a new interval into the tree.
Args:
start: The starting point of the interval.
end: The ending point of the interval.
data: The data associated with this interval.
"""
self.root = self._insert(self.root, start, end, data)
def _get_height(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return node.height
def _get_balance(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return self._get_height(node.left) - self._get_height(node.right)
def _update_node_attributes(self, node: _Node[T]) -> None:
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
node.max_end = node.end
if node.left:
node.max_end = max(node.max_end, node.left.max_end)
if node.right:
node.max_end = max(node.max_end, node.right.max_end)
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
"""Performs a right rotation."""
x = y.left
assert x is not None
T2 = x.right
x.right = y
y.left = T2
self._update_node_attributes(y)
self._update_node_attributes(x)
return x
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
"""Performs a left rotation."""
y = x.right
assert y is not None
T2 = y.left
y.left = x
x.right = T2
self._update_node_attributes(x)
self._update_node_attributes(y)
return y
def _insert(
self, node: Optional[_Node[T]], start: int, end: int, data: T
) -> _Node[T]:
"""Recursive helper to insert a new node and balance the tree."""
if not node:
return _Node(start, end, data)
# Replace the data if the interval already exists.
if start == node.start and end == node.end:
node.data = data
return node
if start < node.start:
node.left = self._insert(node.left, start, end, data)
else:
node.right = self._insert(node.right, start, end, data)
self._update_node_attributes(node)
balance = self._get_balance(node)
# Left Left Case
if balance > 1 and node.left and start < node.left.start:
return self._right_rotate(node)
# Right Right Case
if balance < -1 and node.right and start >= node.right.start:
return self._left_rotate(node)
# Left Right Case
if balance > 1 and node.left and start >= node.left.start:
node.left = self._left_rotate(node.left)
return self._right_rotate(node)
# Right Left Case
if balance < -1 and node.right and start < node.right.start:
node.right = self._right_rotate(node.right)
return self._left_rotate(node)
return node
def search(self, point: int) -> List[T]:
"""
Searches for all intervals that contain the given point.
Args:
point: The point to search for.
Returns:
A list of data items from all matching intervals.
"""
results: List[T] = []
self._search(self.root, point, results)
return results
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
"""Recursive helper to find all overlapping intervals."""
if node is None or point > node.max_end:
return
if node.left:
self._search(node.left, point, results)
if node.start <= point <= node.end:
results.append(node.data)
if point >= node.start and node.right:
self._search(node.right, point, results)
def find_smallest_interval(self, point: int) -> Optional[T]:
"""
Finds the item with the most specific (smallest) range for a given point.
Args:
point: The capability to look up.
Returns:
The data of the best-matching item, or None if no match is found.
"""
matches: List[Tuple[int, int, T]] = []
self._find_with_intervals(self.root, point, matches)
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# is just to ensure that we can compare against a trivial
# implementation in tests.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def _find_with_intervals(
self,
node: Optional[_Node[T]],
point: int,
results: List[Tuple[int, int, T]],
) -> None:
"""A modified search that collects interval ranges along with data."""
if node is None or point > node.max_end:
return
if node.left:
self._find_with_intervals(node.left, point, results)
if node.start <= point <= node.end:
results.append((node.start, node.end, node.data))
if point >= node.start and node.right:
self._find_with_intervals(node.right, point, results)
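For illustration, a short usage sketch of this interval tree (the stored strings are invented placeholders):
```python
from kernels._interval_tree import IntervalTree

# Register two overlapping capability ranges; find_smallest_interval prefers
# the narrower interval when several intervals contain the query point.
tree = IntervalTree[str]()
tree.insert(75, 89, "activation-generic")   # capabilities 7.5 .. 8.9
tree.insert(80, 89, "activation-ampere")    # capabilities 8.0 .. 8.9
assert tree.find_smallest_interval(86) == "activation-ampere"
```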

View File

@ -1,52 +0,0 @@
from typing import Dict, Optional
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
"""
Get the locks for a kernel with the given version spec.
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
return versions[accepted_versions[-1]]
def select_revision_or_version(
repo_id: str, revision: Optional[str], version: Optional[str]
) -> str:
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")
elif revision is None and version is None:
revision = "main"
elif version is not None:
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
assert revision is not None
return revision
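For illustration, a hedged usage sketch of the helper above (the repository and version specifier are placeholders; a network call to the Hub is assumed):
```python
from kernels._versions import select_revision_or_version

# Resolve a version specifier to a concrete revision (the matching tag's commit).
# Passing both `revision` and `version` raises a ValueError.
revision = select_revision_or_version(
    "kernels-community/activation", revision=None, version=">=0.0.1"
)
```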

File diff suppressed because it is too large

View File

@ -4,8 +4,10 @@ from pathlib import Path
from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from kernels._versions import resolve_version_spec_as_ref
from kernels.compat import tomllib
@ -29,6 +31,20 @@ class KernelLock:
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
"""
Get the locks for a kernel with the given version spec.
@ -36,7 +52,16 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
tag_for_newest = versions[accepted_versions[-1]]
r = HfApi().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True

View File

@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Tuple
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels._versions import select_revision_or_version
from kernels.lockfile import KernelLock, VariantLock
@ -46,12 +45,9 @@ def build_variant() -> str:
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
elif torch.version.xpu is not None:
version = torch.version.xpu
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
else:
raise AssertionError(
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
)
torch_version = parse(torch.__version__)
@ -99,20 +95,7 @@ def install_kernel(
"""
Download a kernel for the current environment to the cache.
The output path is validated against the hashes in `variant_locks` when provided.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`):
The specific revision (branch, tag, or commit) to download.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
variant_locks (`Dict[str, VariantLock]`, *optional*):
Optional dictionary of variant locks for validation.
Returns:
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
The output path is validated against `hash` when set.
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
@ -199,39 +182,13 @@ def install_kernel_all_variants(
return repo_path / "build"
def get_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> ModuleType:
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Load a kernel from the kernel hub.
Download and import a kernel from the Hugging Face Hub.
This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
and then loads the kernel.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`):
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*):
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
Returns:
`ModuleType`: The imported kernel module.
Example:
```python
import torch
from kernels import get_kernel
activation = get_kernel("kernels-community/activation")
x = torch.randn(10, 20, device="cuda")
out = torch.empty_like(x)
result = activation.silu_and_mul(out, x)
```
The kernel is downloaded from the repository `repo_id` at
branch/commit/tag `revision`.
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
@ -239,56 +196,16 @@ def get_kernel(
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
Args:
repo_path (`Path`):
The local path to the kernel repository.
package_name (`str`):
The name of the package to import from the repository.
Returns:
`ModuleType`: The imported kernel module.
"""
variant = build_variant()
universal_variant = universal_build_variant()
# Presume we were given the top level path of the kernel repository.
for base_path in [repo_path, repo_path / "build"]:
# Prefer the universal variant if it exists.
for v in [universal_variant, variant]:
package_path = base_path / v / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
# If we didn't find the package in the repo we may have an explicit
# package path.
package_path = repo_path / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> bool:
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment (Torch version and compute framework).
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`):
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*):
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
Returns:
`bool`: `True` if a kernel is available for the current environment.
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
@ -311,16 +228,8 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
"""
Get a pre-downloaded, locked kernel.
If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
lockfile (`Path`, *optional*):
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
Returns:
`ModuleType`: The imported kernel module.
If `lockfile` is not specified, the lockfile will be loaded from the
caller's package metadata.
"""
if lockfile is None:
locked_sha = _get_caller_locked_kernel(repo_id)
@ -365,18 +274,7 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""
Get a kernel using a lock file.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
Returns:
`ModuleType`: The imported kernel module.
"""
"""Get a kernel using a lock file."""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:

View File

@ -1,24 +1,10 @@
import sys
import pytest
import torch
has_cuda = (
hasattr(torch.version, "cuda")
and torch.version.cuda is not None
and torch.cuda.device_count() > 0
)
has_rocm = (
hasattr(torch.version, "hip")
and torch.version.hip is not None
and torch.cuda.device_count() > 0
)
def pytest_runtest_setup(item):
if "cuda_only" in item.keywords and not has_cuda:
pytest.skip("skipping CUDA-only test on host without CUDA")
if "rocm_only" in item.keywords and not has_rocm:
pytest.skip("skipping ROCm-only test on host without ROCm")
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
pytest.skip("skipping Linux-only test on non-Linux platform")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")

View File

@ -1,12 +0,0 @@
[
{
"repo_id": "kernels-test/versions",
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
"variants": {
"torch-universal": {
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
"hash_type": "git_lfs_concat"
}
}
}
]

View File

@ -1,2 +0,0 @@
[tool.kernels.dependencies]
"kernels-test/versions" = ">=0.1.0,<0.2.0"

View File

@ -10,16 +10,10 @@ def kernel():
@pytest.fixture
def local_kernel_path():
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return package_name, path
@pytest.fixture
def local_kernel(local_kernel_path):
package_name, path = local_kernel_path
return get_local_kernel(path.parent.parent, package_name)
@ -40,7 +34,7 @@ def device():
return "cuda"
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -56,7 +50,7 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -72,39 +66,6 @@ def test_local_kernel(local_kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
package_name, path = local_kernel_path
# Top-level repo path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
kernel = get_local_kernel(path.parent.parent, package_name)
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
# Build directory path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
kernel = get_local_kernel(path.parent.parent / "build", package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
# Explicit package path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
kernel = get_local_kernel(path, package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
@ -113,7 +74,7 @@ def test_relu_metal(metal_kernel, dtype):
assert torch.allclose(y, torch.relu(x))
@pytest.mark.cuda_only
@pytest.mark.linux_only
@pytest.mark.parametrize(
"kernel_exists",
[
@ -130,26 +91,7 @@ def test_has_kernel(kernel_exists):
assert has_kernel(repo_id, revision=revision) == kernel
def test_version():
kernel = get_kernel("kernels-test/versions")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
assert kernel.version() == "0.1.1"
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
assert kernel.version() == "0.1.1"
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
get_kernel("kernels-test/versions", version=">0.2.0")
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
kernel = get_kernel(
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
)
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@ -16,21 +16,21 @@ def device():
return "cuda"
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

View File

@ -1,49 +0,0 @@
import inspect
import pytest
from mktestdocs import check_docstring, get_codeblock_members
import kernels
def all_public_functions():
function_list = inspect.getmembers(kernels, inspect.isfunction)
return [func for _, func in function_list]
def all_public_classes():
class_list = inspect.getmembers(kernels, inspect.isclass)
return [cls for _, cls in class_list]
def all_public_class_members():
members = get_codeblock_members(*all_public_classes())
return members
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"func",
all_public_functions(),
ids=lambda d: d.__name__,
)
def test_func_docstring(func):
check_docstring(obj=func)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"cls",
all_public_classes(),
ids=lambda d: d.__name__,
)
def test_class_docstring(cls):
check_docstring(obj=cls)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"member", all_public_class_members(), ids=lambda d: d.__qualname__
)
def test_member_docstring(member):
check_docstring(member)

View File

@ -1,230 +0,0 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar
import pytest
from kernels._interval_tree import IntervalTree, _Node
T = TypeVar("T")
class SimpleIntervalStore(Generic[T]):
"""A simple O(n) implementation that stores intervals in a list."""
def __init__(self):
self.intervals: List[Tuple[int, int, T]] = []
def insert(self, start: int, end: int, data: T) -> None:
"""Insert an interval into the store."""
# Replace data if the interval already exists.
for i, (existing_start, existing_end, existing_data) in enumerate(
self.intervals
):
if existing_start == start and existing_end == end:
self.intervals[i] = (start, end, data)
return
self.intervals.append((start, end, data))
def find_smallest_interval(self, point: int) -> Optional[T]:
"""Find the best match using linear search."""
matches = []
for start, end, data in self.intervals:
if start <= point <= end:
matches.append((start, end, data))
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# mirrors the ordering in the interval tree.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def is_balanced(tree: IntervalTree[T]) -> bool:
"""Check if the AVL tree is properly balanced."""
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
if node is None:
return True, 0
# Left and right subtrees should be balanced.
left_balanced, left_height = check_balance(node.left)
if not left_balanced:
return False, -1
right_balanced, right_height = check_balance(node.right)
if not right_balanced:
return False, -1
# The difference in height should not exceed 1.
if abs(left_height - right_height) > 1:
return False, -1
# Check if the height is correct.
expected_height = 1 + max(left_height, right_height)
if node.height != expected_height:
return False, -1
return True, expected_height
balanced, _ = check_balance(tree.root)
return balanced
@pytest.fixture
def populated_tree() -> IntervalTree[str]:
"""Provides a pre-populated IntervalTree for testing."""
tree = IntervalTree[str]()
kernels = [
(80, 89, "Kernel_A_General_80_89"),
(86, 89, "Kernel_B_Ampere_86_89"),
(80, 86, "Kernel_C_Older_Ampere_80_86"),
(70, 75, "Kernel_D_Volta_70_75"),
(86, 87, "Kernel_E_Specific_86_87"),
]
for start, end, name in kernels:
tree.insert(start, end, name)
return tree
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
# Check that the smallest interval is selected when there are
# multiple matching intervals.
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
def test_find_single_match(populated_tree):
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
def test_no_match_outside_all_ranges(populated_tree):
# Check that no interval is found when the value is out of range
# (too small/too large).
assert populated_tree.find_smallest_interval(65) is None
assert populated_tree.find_smallest_interval(95) is None
def test_no_match_in_gap_between_ranges(populated_tree):
# Check that no interval is found when the value is between two
# intervals.
assert populated_tree.find_smallest_interval(78) is None
def test_boundary_conditions_start_and_end(populated_tree):
# Test exact upper/lower bounds of intervals.
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
def test_empty_tree():
# Searching in an empty tree should return None.
empty_tree = IntervalTree[str]()
assert empty_tree.find_smallest_interval(100) is None
def test_multiple_equally_specific_matches():
# Check that we pick the match in a stable way when there are
# multiple matching intervals with the same size.
tree = IntervalTree[str]()
str1 = "First_Narrow_Kernel"
str2 = "Second_Narrow_Kernel"
tree.insert(10, 20, "Wide_Kernel")
tree.insert(12, 17, str1)
tree.insert(14, 19, str2)
if id(str1) < id(str2):
assert tree.find_smallest_interval(15) == str1
else:
assert tree.find_smallest_interval(15) == str2
def test_property_based_interval_tree():
# Quick-check property-based testing:
#
# - Verify that the tree is balanced after each insertion.
# - Verify the query against a simple list-based implementation.
random.seed(42) # For reproducible tests
test_points = list(range(0, 101))
for _ in range(5):
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
intervals = []
for i in range(100):
start = random.randint(0, 90)
end = random.randint(start, 100)
data = f"interval_{i}_s{start}_e{end}"
intervals.append((start, end, data))
for i, (start, end, data) in enumerate(intervals):
tree.insert(start, end, data)
simple.insert(start, end, data)
# Check that tree is still balanced
assert is_balanced(
tree
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
for point in test_points:
tree_result = tree.find_smallest_interval(point)
simple_result = simple.find_smallest_interval(point)
assert tree_result == simple_result, (
f"Mismatch for point {point} after inserting {i+1} intervals. "
f"Tree: {tree_result}, Simple: {simple_result}. "
f"Last inserted: ({start}, {end})"
)
def test_property_based_edge_cases():
random.seed(123)
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
# Single-point intervals.
for i in range(10):
point = random.randint(0, 100)
data = f"single_point_{i}_{point}"
tree.insert(point, point, data)
simple.insert(point, point, data)
assert is_balanced(
tree
), f"Tree unbalanced after inserting single point {point}"
# Test the exact point and neighbors
for test_point in [point - 1, point, point + 1]:
if 0 <= test_point <= 100:
tree_result = tree.find_smallest_interval(test_point)
simple_result = simple.find_smallest_interval(test_point)
assert tree_result == simple_result
def test_unique_intervals_override():
"""Test that inserting an interval with the same start/end overrides the previous value."""
tree = IntervalTree[str]()
tree.insert(10, 20, "original_value")
assert tree.find_smallest_interval(15) == "original_value"
tree.insert(10, 20, "new_value")
assert tree.find_smallest_interval(15) == "new_value"
tree.insert(10, 25, "different_interval")
results = tree.search(15)
assert "new_value" in results
assert "different_interval" in results
assert len(results) == 2
tree.insert(10, 20, "final_value")
assert tree.find_smallest_interval(15) == "final_value"
assert is_balanced(tree)

View File

@ -2,17 +2,9 @@ from dataclasses import dataclass
from pathlib import Path
import pytest
import torch.nn as nn
from kernels import load_kernel
from kernels.cli import download_kernels
from kernels.layer import (
LockedLayerRepository,
Mode,
kernelize,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
# Mock download arguments class.
@ -27,34 +19,9 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
def test_layer_locked():
project_dir = Path(__file__).parent / "layer_locking"
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
"cuda": LockedLayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
lockfile=project_dir / "kernels.lock",
)
},
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"

View File

@ -1,4 +1,3 @@
import sys
from contextlib import nullcontext
import pytest
@ -7,21 +6,14 @@ import torch.nn as nn
from torch.nn import functional as F
from kernels import (
CUDAProperties,
Device,
LayerRepository,
LocalLayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
from kernels.layer import (
_KERNEL_MAPPING,
_validate_layer,
)
from kernels.utils import install_kernel
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
kernel_layer_mapping = {
"SiluAndMul": {
@ -34,11 +26,7 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
),
"rocm": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
),
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
@ -108,120 +96,26 @@ def test_arg_kinds():
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.cuda_only
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
def test_hub_forward(cls):
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
torch.random.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device="cuda")
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
assert silu_and_mul_with_kernel.n_calls == 0
@pytest.mark.rocm_only
def test_hub_forward_rocm():
torch.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64))
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(
SiluAndMulNoCompileKernel(), device="rocm", mode=Mode.INFERENCE
)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
# Should use kernel (n_calls == 0) if ROCm kernel is available, otherwise fallback (n_calls == 1)
# The exact behavior depends on whether the test kernel exists for ROCm
assert silu_and_mul_with_kernel.n_calls in [0, 1]
def test_rocm_kernel_mapping():
"""Test that ROCm shorthand device mapping works correctly."""
kernel_layer_mapping = {
"SiluAndMul": {
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
# Test that the mapping is processed correctly
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
mapping = _KERNEL_MAPPING.get()
# Verify the mapping exists
assert "SiluAndMul" in mapping
assert "rocm" in mapping["SiluAndMul"]
# Verify the repository is correctly stored
rocm_repos = mapping["SiluAndMul"]["rocm"]
assert rocm_repos is not None
assert (
rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation"
)
assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul"
@pytest.mark.cuda_only
def test_capability():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we called out to the kernel.
assert linear.n_calls == 0
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=sys.maxsize, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
    # Check that we didn't call out to the kernel because there
    # is no kernel with a matching capability.
assert linear.n_calls == 1
if device == "cuda":
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
def test_layer_fallback_works():
@ -234,33 +128,7 @@ def test_layer_fallback_works():
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
def test_local_layer_repo():
# Fetch a kernel to the local cache.
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
"cuda": LocalLayerRepository(
# install_kernel will give the fully-resolved path.
repo_path=path.parent.parent,
package_name=package_name,
layer_name="LinearBackward",
)
}
},
inherit_mapping=False,
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@pytest.mark.cuda_only
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
@ -291,7 +159,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.cuda_only
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
@ -314,11 +182,7 @@ def test_torch_compile_layer_with_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.cuda_only
def test_mapping_contexts():
# Make sure we start from scratch.
register_kernel_mapping(kernel_layer_mapping, inherit_mapping=False)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
@ -361,9 +225,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
.repos[Mode.FALLBACK]
._repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -374,7 +238,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -383,9 +249,9 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
.repos[Mode.FALLBACK]
._repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -396,7 +262,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -435,7 +303,6 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
@pytest.mark.cuda_only
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -455,7 +322,6 @@ def test_invalid_mode_for_mapping_rejected():
kernelize(linear, mode=Mode.TRAINING)
@pytest.mark.cuda_only
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -501,24 +367,24 @@ def test_kernel_modes():
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Training has a kernel, so fallback.
assert linear.n_calls == 0
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
assert linear.n_calls == 1
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.FALLBACK: LayerRepository(
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
@ -533,18 +399,18 @@ def test_kernel_modes():
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Falls back to TRAINING.
assert linear.n_calls == 1
# Uses the base kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Falls back to the TRAINING kernel.
assert linear.n_calls == 1
# Uses the training kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
assert linear.n_calls == 1
# Uses the base kernel.
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
@ -563,21 +429,21 @@ def test_kernel_modes():
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Falls back to the TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 1
# No inference kernel, so fallback.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 1
# No training kernel, so fallback.
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 1
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4
@pytest.mark.cuda_only
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -631,385 +497,12 @@ def test_invalid_mode_rejected():
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.FALLBACK | Mode.TORCH_COMPILE
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.FALLBACK)
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
@pytest.mark.cuda_only
def test_kernel_modes_inference():
"""Test inference-specific fallback scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: register a kernel just for inference
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 2
# Case 2: register a kernel just for inference + torch.compile
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 2
kernelize(linear, mode=Mode.INFERENCE)
linear(X)
# INFERENCE falls back to INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 3
# Case 3: register both inference kernels
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses exact INFERENCE kernel
assert linear.n_calls == 3
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# Uses exact INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 4
@pytest.mark.cuda_only
def test_kernel_modes_mixed():
"""Test mixed training and inference kernel scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: register both base inference and training kernels
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original
assert linear.n_calls == 2
# Case 2: register all four kernel modes
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses exact INFERENCE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses exact TRAINING kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# Uses exact INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses exact TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 2
@pytest.mark.cuda_only
def test_kernel_modes_cross_fallback():
"""Test cross-mode fallback scenarios from inference to training modes."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: Only training kernel registered - inference should fall back to training
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# INFERENCE falls back to TRAINING kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING uses the kernel directly
assert linear.n_calls == 0
# Case 2: Only training + torch.compile kernel registered
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# INFERENCE falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE uses the kernel directly
assert linear.n_calls == 0
# Case 3: Test that training modes don't fall back to inference modes
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.TRAINING)
X = torch.randn(10, 32, device="cuda")
linear(X)
# TRAINING should NOT fall back to inference kernels, use original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
assert linear.n_calls == 2
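
Read together, the comments in the three cases above pin down the fallback order: inference modes may fall back to the corresponding training kernels, a mode without `TORCH_COMPILE` may fall back to its `| TORCH_COMPILE` variant but not the other way around, and training modes never fall back to inference-only kernels. A minimal sketch of the practical consequence, assuming a hypothetical `KernelLinear` layer decorated for the hub (the repository is the test fixture used above): registering a single `TRAINING | TORCH_COMPILE` kernel is enough to cover every mode exercised here.

```python
import torch.nn as nn

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


# Hypothetical hub-aware layer; the tests use a similar decorated counter layer.
@use_kernel_forward_from_hub("Linear")
class KernelLinear(nn.Linear):
    pass


linear = KernelLinear(32, 32, device="cuda")

with use_kernel_mapping(
    {
        "Linear": {
            "cuda": {
                # The most restrictive mode: every other mode exercised above
                # may fall back to this kernel.
                Mode.TRAINING
                | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    }
):
    kernelize(linear, mode=Mode.INFERENCE)  # resolves to the kernel above
    kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)  # likewise
```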
def test_layer_versions():
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<1.0.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.1.0,<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.2.0",
)
}
}
):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE)
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
revision="v0.1.0",
version="<1.0.0",
)
}
}
)
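
The `version` arguments above read as Python-style version specifiers (ranges such as `>0.1.0,<0.2.0`), and the final case shows that `version` and `revision` are mutually exclusive. A minimal sketch of pinning the test fixture layer to a release range, reusing the fixtures from the test above:

```python
import torch.nn as nn

from kernels import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
    def forward(self) -> str:
        return "0.0.0"


with use_kernel_mapping(
    {
        "Version": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-test/versions",
                layer_name="Version",
                # Any release after 0.1.0 but before 1.0.0; pass either a
                # version specifier or a revision, never both.
                version=">0.1.0,<1.0.0",
            )
        }
    }
):
    version = kernelize(Version(), device="cuda", mode=Mode.INFERENCE)
```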