Compare commits

...

26 Commits

Author SHA1 Message Date
a988871e9e Add support for stateful layers 2025-09-19 10:14:43 +00:00
055a953552 Document the to-wheel subcommand (#149)
* Document the `to-wheel` subcommand

* Capitalization
2025-09-17 17:02:41 +02:00
692d5ad458 Fix some spelling errors to check docs CI is working (#120) 2025-09-17 13:44:09 +02:00
2139df57f4 rm link (#148) 2025-09-17 12:46:49 +02:00
8f9a77bb6a Describe the get_kernel/LayerRepository (#147)
This was already in the API documentation, but describe this in the
guides as well (since we want people to use versions).
2025-09-16 16:06:40 +02:00
6c00194680 Improve errors for layer validation (#145)
* Improve errors for layer validation

Include the repo and layer name as well as the name of the class
that is being compared to (when applicable).

* Remove upload xfail

* Only enable tests that require a token with `--token`
2025-09-16 14:40:54 +02:00
d6b51eefb7 [feat] add an uploading utility (#138)
* add an uploading utility.

* format

* remove stale files.

* black format

* sorted imports.

* up

* up

* add a test

* propagate.

* remove duplicate imports.

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* up

* up

* up

* command to format all files at once would be nice.

* up

* up

* up

* Use token for upload test

* assign env better.

* docs

* polish

* up

* xfail the test for now.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-09-16 08:56:54 +02:00
d383fdd4b4 Add support for XPU layer repostories (#142)
This change adds support for XPU layer repositories, e.g.:

```
kernel_mapping = {
    "LigerRMSNorm": {
        "xpu": LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
}

Co-authored-by: YangKai0616 <kai.yang@intel.com>
2025-09-11 15:51:02 +02:00
07e5e8481a Set version to 0.10.1.dev0 (#140)
* Set version to 0.10.1.dev0

* Add `__version__` attribute to top-level module

This is needed for doc generation.
2025-09-10 09:08:02 +02:00
88f55d4728 XPU: look up kernel by framework version (#139)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-09-09 13:10:11 +02:00
e801ebf332 Set version to v0.10.0.dev0 (#137) 2025-09-05 10:48:41 +02:00
0ae07f05fc Remove default for mode argument of kernelize (#136) 2025-08-29 17:44:20 +02:00
7611021100 cpu is not (yet) a supported device type (#132)
Fixes #131.
2025-08-25 16:25:58 +02:00
767e7ccf13 fix: add get local tests (#134)
* fix: add tests for get local kernel

* fix: update test and add path example comments

* fix: run black linter
2025-08-21 13:01:48 -04:00
1caa4c1393 feat: improve get local kernel importing (#129)
* feat: improve get local kernel importing

* fix: adjust for linter
2025-08-08 10:22:29 -04:00
da701bf58a Small markup fixes of the local kernel repo example (#127) 2025-08-06 08:02:28 +02:00
703664ed31 Set version to 0.9.0.dev0 (#126) 2025-08-01 16:37:30 +02:00
a8a6564fa7 Add ROCm device discovery (#122)
* Add ROCm device discovery

* Ruff

* Address review comments

* Ruff

* Reorg torch import

* Remove redundant import

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* Address review comments

* Validat device type

* Clean diff

* black

* Sync test with repo changes

* black again

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-08-01 16:09:45 +02:00
c89e0fa9b9 Nix: go back to hf-nix main (#125) 2025-08-01 15:56:02 +02:00
176a601178 Run black check (#124) 2025-08-01 15:42:38 +02:00
cfa0c76ddc Add LocalLayerRepository to load from a local repo (#123) 2025-08-01 14:03:11 +02:00
bcc29915f9 Log when using fallback layer (#121) 2025-07-31 17:18:00 +02:00
6fbff7a9cb Add doc build to CI (#119)
* Add doc build to CI

* Trigger doc build

* No path scoping
2025-07-29 16:01:05 +02:00
f7490bd0a9 Test examples in docstrings using mktestdocs (#118)
Also adjust examples so that they are correct.
2025-07-28 17:31:34 +02:00
8069e3bf0c Update documentation for compatibility with doc-builder (#117) 2025-07-24 16:21:54 +02:00
c540d1e1d6 Fix typo in layers documentation (#116) 2025-07-23 17:13:14 +02:00
34 changed files with 1639 additions and 296 deletions

View File

@ -0,0 +1,17 @@
name: Build documentation
on:
push:
branches:
- main
- doc-builder*
- v*-release
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@ -0,0 +1,15 @@
name: Build PR Documentation
on: pull_request
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: kernels

View File

@ -8,3 +8,24 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Run ruff - name: Run ruff
uses: astral-sh/ruff-action@v3 uses: astral-sh/ruff-action@v3
black:
name: Run black check
runs-on: ubuntu-latest
env:
UV_PYTHON_PREFERENCE: only-managed
steps:
- uses: actions/checkout@v4
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v5
with:
python-version: 3.12
- name: Install black
run: uv pip install black
- name: Check formatting
run: |
uv run black --check src
uv run black --check tests

View File

@ -51,7 +51,10 @@ jobs:
run: uv run mypy src/kernels run: uv run mypy src/kernels
- name: Run tests - name: Run tests
run: uv run pytest tests env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
uv run pytest tests
- name: Check kernel conversion - name: Check kernel conversion
run: | run: |

View File

@ -0,0 +1,16 @@
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

View File

@ -56,10 +56,12 @@ the Hub.
## 📚 Documentation ## 📚 Documentation
- [Using layers](docs/layers.md) - [Introduction](docs/source/index.md)
- [Locking kernel/layer versions](docs/locking.md) - [Installation](docs/source/installation.md)
- [Environment variables](docs/env.md) - [Basic usage](docs/source/basic-usage.md)
- [Using kernels in a Docker container](docs/docker.md) - [Using layers](docs/source/layers.md)
- [Kernel requirements](docs/kernel-requirements.md) - [Locking kernel/layer versions](docs/source/locking.md)
- [Frequently Asked Questions](docs/faq.md) - [Environment variables](docs/source/env.md)
- [Kernel requirements](docs/source/kernel-requirements.md)
- [Frequently Asked Questions](docs/source/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

View File

@ -1,8 +0,0 @@
# Using kernels in a Docker container
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

30
docs/source/_toctree.yml Normal file
View File

@ -0,0 +1,30 @@
- sections:
- local: index
title: Introduction
- local: installation
title: Installation
title: Getting started
- sections:
- local: basic-usage
title: Basic Usage
- local: layers
title: Using Layers
- local: locking
title: Locking Kernel Versions
- local: env
title: Environment Variables
- local: faq
title: FAQ
title: Usage Guide
- sections:
- local: api/kernels
title: Kernels
- local: api/layers
title: Layers
- local: cli
title: Kernels CLI
title: API Reference
- sections:
- local: kernel-requirements
title: Kernel Requirements
title: Developer Guide

View File

@ -0,0 +1,21 @@
# Kernels API Reference
## Main Functions
### get_kernel
[[autodoc]] kernels.get_kernel
### has_kernel
[[autodoc]] kernels.has_kernel
## Loading locked kernels
### load_kernel
[[autodoc]] kernels.load_kernel
### get_locked_kernel
[[autodoc]] kernels.get_locked_kernel

41
docs/source/api/layers.md Normal file
View File

@ -0,0 +1,41 @@
# Layers API Reference
## Making layers kernel-aware
### use_kernel_forward_from_hub
[[autodoc]] kernels.use_kernel_forward_from_hub
### replace_kernel_forward_from_hub
[[autodoc]] kernels.replace_kernel_forward_from_hub
## Registering kernel mappings
### use_kernel_mapping
[[autodoc]] kernels.use_kernel_mapping
### register_kernel_mapping
[[autodoc]] kernels.register_kernel_mapping
## Kernelizing a model
### kernelize
[[autodoc]] kernels.kernelize
## Classes
### Device
[[autodoc]] kernels.Device
### Mode
[[autodoc]] kernels.Mode
### LayerRepository
[[autodoc]] kernels.LayerRepository

View File

@ -0,0 +1,50 @@
# Basic Usage
## Loading Kernels
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
```python
import torch
from kernels import get_kernel
# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")
# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)
print(y)
```
### Using version bounds
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version to download using Python version specifiers:
```python
import torch
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
```
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
is strongly recommended to specify a version bound, since a kernel author
might push incompatible changes to the `main` branch.
## Checking Kernel Availability
You can check if a specific kernel is available for your environment:
```python
from kernels import has_kernel
# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```

41
docs/source/cli.md Normal file
View File

@ -0,0 +1,41 @@
# Kernels CLI Reference
## Main Functions
### kernels to-wheel
We strongly recommend downloading kernels from the Hub using the `kernels`
package, since this comes with large [benefits](index.md) over using Python
wheels. That said, some projects may require deployment of kernels as
wheels. The `kernels` utility provides a simple solution to this. You can
convert any Hub kernel into a set of wheels with the `to-wheel` command:
```bash
$ kernels to-wheel drbh/img2grey 1.1.2
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
```
### kernels upload
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
your kernel builds to the Hub.
**Notes**:
- This will take care of creating a repository on the Hub with the `repo_id` provided.
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
being uploaded, it will attempt to delete the files existing under it.
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.

View File

@ -2,9 +2,9 @@
## Why is the kernelization step needed? ## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by In earlier versions of `kernels`, a layer's `forward` method was replaced
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
new `forward` would dispatch to a kernel based on the device type, The new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching. on data-dependent branching.

20
docs/source/index.md Normal file
View File

@ -0,0 +1,20 @@
# Kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
same Python process.
- **Compatible**: kernels must support all recent versions of Python and
the different PyTorch build configurations (various CUDA versions
and C++ ABIs). Furthermore, older C library versions must be supported.
You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.

View File

@ -0,0 +1,16 @@
# Installation
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
```bash
pip install kernels
```
# Using kernels in a Docker container
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

View File

@ -34,6 +34,8 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md) the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve the version constraints. to resolve the version constraints.
We recommend using [semver](https://semver.org/) to version kernels.
## Native Python module ## Native Python module
Kernels will typically contain a native Python module with precompiled Kernels will typically contain a native Python module with precompiled
@ -50,13 +52,12 @@ have dynamic library dependencies outside:
for compatibility with Python 3.9 and later. for compatibility with Python 3.9 and later.
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based). - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
This means that the extension **must not** use symbols versions higher than: This means that the extension **must not** use symbols versions higher than:
- GLIBC 2.28 - GLIBC 2.28
- GLIBCXX 3.4.24 - GLIBCXX 3.4.24
- CXXABI 1.3.11 - CXXABI 1.3.11
- GCC 7.0.0 - GCC 7.0.0
These requirement can be checked with the ABI checker (see below). These requirements can be checked with the ABI checker (see below).
### macOS ### macOS

View File

@ -5,7 +5,7 @@ the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for device type. This makes it possible to provide more performant kernels for
existing layers. existing layers.
See [Kernel requirements](kernel-requirements.md) for more information the See [Kernel requirements](kernel-requirements.md) for more information on the
requirements of Hub layers. requirements of Hub layers.
## Making a layer extensible with kernels from the hub ## Making a layer extensible with kernels from the hub
@ -84,12 +84,6 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
``` ```
When the `mode` argument is not specified,
`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
aligns most closely with pure PyTorch layers which also support training
and `torch.compile`. However, to select the most performant kernels, it
is often good to make the mode specific as possible.
### Kernel device ### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and Kernels can be registered per device type. For instance, separate `cuda` and
@ -107,7 +101,7 @@ model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively, kernel does not support backward passes or `torch.compile` respectively,
`kernenize` will fall back to the original, non-kernelized, layer. You `kernelize` will fall back to the original, non-kernelized, layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`: can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python ```python
@ -117,7 +111,7 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=
This can be useful if you want to guarantee that Hub kernels are used. This can be useful if you want to guarantee that Hub kernels are used.
### Inspecting kernels which kernels are used ### Inspecting which kernels are used
The kernels that are used are logged at the `INFO` level by `kernelize`. The kernels that are used are logged at the `INFO` level by `kernelize`.
See the [Python logging](https://docs.python.org/3/library/logging.html) See the [Python logging](https://docs.python.org/3/library/logging.html)
@ -135,6 +129,10 @@ kernel_layer_mapping = {
"cuda": LayerRepository( "cuda": LayerRepository(
repo_id="kernels-community/activation", repo_id="kernels-community/activation",
layer_name="SiluAndMul", layer_name="SiluAndMul",
),
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
) )
} }
} }
@ -153,12 +151,39 @@ used with the `use_kernel_mapping` context manager:
```python ```python
with use_kernel_mapping(kernel_layer_mapping): with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied. # Use the layer for which the mapping is applied.
model = kernelize(model) model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
``` ```
This ensures that the mapping is not active anymore outside the This ensures that the mapping is not active anymore outside the
`with`-scope. `with`-scope.
### Using version bounds
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version of the kernel to download using Python version
specifiers:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
version=">=0.0.4,<0.1.0",
),
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
version=">=0.0.4,<0.1.0",
)
}
}
```
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
least 4. It is strongly recommended to specify a version bound, since a
kernel author might push incompatible changes to the `main` branch.
### Registering kernels for specific modes ### Registering kernels for specific modes
You might want to register two different kernels for a particular layer, You might want to register two different kernels for a particular layer,
@ -261,7 +286,6 @@ Capabilities behave as follows:
an existing kernel, the new kernel will replace the old kernel. an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel - When there are multiple kernels that support a capability, the kernel
with the smaller capability interval will be used. E.g. given: with the smaller capability interval will be used. E.g. given:
- `KernelA` with `min_capability=80` and `max_capability=89`; - `KernelA` with `min_capability=80` and `max_capability=89`;
- `KernelB` with `min_capability=75` and `max_capability=89`; - `KernelB` with `min_capability=75` and `max_capability=89`;
- `kernelize` runs on a system with capability 8.6. - `kernelize` runs on a system with capability 8.6.
@ -270,3 +294,30 @@ Capabilities behave as follows:
than 75..89. The motivation is that kernels with smaller ranges than 75..89. The motivation is that kernels with smaller ranges
tend to be more optimized for a specific set of GPUs. **This behavior tend to be more optimized for a specific set of GPUs. **This behavior
might still change in the future.** might still change in the future.**
### Registering kernels for specific ROCm capabilities
Registering kernels for the ROCm architecture follows the exact same
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
a kernel to a range of ROCm capabilities.
### Loading from a local repository for testing
The `LocalLayerRepository` class is provided to load a repository from
a local directory. For example:
```python
with use_kernel_mapping(
{
"SiluAndMul": {
"cuda": LocalLayerRepository(
repo_path="/home/daniel/kernels/activation",
package_name="activation",
layer_name="SiluAndMul",
)
}
},
inherit_mapping=False,
):
kernelize(linear, mode=Mode.INFERENCE)
```

18
flake.lock generated
View File

@ -58,11 +58,11 @@
"nixpkgs": "nixpkgs" "nixpkgs": "nixpkgs"
}, },
"locked": { "locked": {
"lastModified": 1750775451, "lastModified": 1754038838,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=", "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
"owner": "huggingface", "owner": "huggingface",
"repo": "hf-nix", "repo": "hf-nix",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7", "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -73,17 +73,17 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1747820358, "lastModified": 1752785354,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
"owner": "danieldk", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "d3c1681180717528068082103bf323147de6ab0b", "rev": "d38025438a6ee456758dc03188ca6873a415463b",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "danieldk", "owner": "nixos",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
"type": "github" "type": "github"
} }
}, },

View File

@ -26,6 +26,10 @@
formatter = pkgs.nixfmt-tree; formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec { devShells = with pkgs; rec {
default = mkShell { default = mkShell {
nativeBuildInputs = [
# For hf-doc-builder.
nodejs
];
buildInputs = buildInputs =
[ [
black black
@ -36,6 +40,7 @@
++ (with python3.pkgs; [ ++ (with python3.pkgs; [
docutils docutils
huggingface-hub huggingface-hub
mktestdocs
pytest pytest
pytest-benchmark pytest-benchmark
pyyaml pyyaml

View File

@ -1,6 +1,6 @@
[project] [project]
name = "kernels" name = "kernels"
version = "0.8.1.dev0" version = "0.10.1.dev0"
description = "Download compute kernels" description = "Download compute kernels"
authors = [ authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" }, { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -24,16 +24,20 @@ build-backend = "setuptools.build_meta"
[dependency-groups] [dependency-groups]
dev = [ dev = [
"mypy >= 1.15.0", "mktestdocs>=0.2.5",
"pytest >=8", "mypy>=1.15.0",
"pytest>=8",
# Whatever version is compatible with pytest. # Whatever version is compatible with pytest.
"pytest-benchmark", "pytest-benchmark",
"torch >=2.5", "torch>=2.5",
"types-pyyaml" "types-pyyaml"
] ]
[project.optional-dependencies] [project.optional-dependencies]
torch = ["torch"] torch = ["torch"]
docs = [
"hf-doc-builder",
]
[project.scripts] [project.scripts]
kernels = "kernels.cli:main" kernels = "kernels.cli:main"

View File

@ -1,4 +1,7 @@
[pytest] [pytest]
markers = markers =
cuda_only: marks tests that should only hosts with CUDA GPUs
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS darwin_only: marks tests that should only run on macOS
linux_only: marks tests that should only run on Linux xpu_only: marks tests that should only run on hosts with Intel XPUs
token: enable tests that require a write token

View File

@ -1,7 +1,13 @@
import importlib.metadata
__version__ = importlib.metadata.version("kernels")
from kernels.layer import ( from kernels.layer import (
CUDAProperties, CUDAProperties,
Device, Device,
LayerRepository, LayerRepository,
LocalLayerRepository,
LockedLayerRepository,
Mode, Mode,
kernelize, kernelize,
register_kernel_mapping, register_kernel_mapping,
@ -19,9 +25,12 @@ from kernels.utils import (
) )
__all__ = [ __all__ = [
"__version__",
"CUDAProperties", "CUDAProperties",
"Device", "Device",
"LayerRepository", "LayerRepository",
"LocalLayerRepository",
"LockedLayerRepository",
"Mode", "Mode",
"get_kernel", "get_kernel",
"get_local_kernel", "get_local_kernel",

View File

@ -4,6 +4,8 @@ import json
import sys import sys
from pathlib import Path from pathlib import Path
from huggingface_hub import create_repo, upload_folder
from kernels.compat import tomllib from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants from kernels.utils import install_kernel, install_kernel_all_variants
@ -31,6 +33,24 @@ def main():
) )
download_parser.set_defaults(func=download_kernels) download_parser.set_defaults(func=download_kernels)
upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
upload_parser.add_argument(
"kernel_dir",
type=Path,
help="Directory of the kernel build",
)
upload_parser.add_argument(
"--repo_id",
type=str,
help="Repository ID to use to upload to the Hugging Face Hub",
)
upload_parser.add_argument(
"--private",
action="store_true",
help="If the repository should be private.",
)
upload_parser.set_defaults(func=upload_kernels)
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument( lock_parser.add_argument(
"project_dir", "project_dir",
@ -153,6 +173,33 @@ def lock_kernels(args):
json.dump(all_locks, f, cls=_JSONEncoder, indent=2) json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
def upload_kernels(args):
kernel_dir = Path(args.kernel_dir).resolve()
build_dir = kernel_dir / "build"
if not kernel_dir.is_dir():
raise ValueError(f"{kernel_dir} is not a directory")
if not build_dir.is_dir():
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
repo_id = create_repo(
repo_id=args.repo_id, private=args.private, exist_ok=True
).repo_id
delete_patterns: set[str] = set()
for build_variant in build_dir.iterdir():
if build_variant.is_dir():
delete_patterns.add(f"{build_variant.name}/**")
upload_folder(
repo_id=repo_id,
folder_path=build_dir,
path_in_repo="build",
delete_patterns=list(delete_patterns),
commit_message="Build uploaded using `kernels`.",
)
print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
class _JSONEncoder(json.JSONEncoder): class _JSONEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
if dataclasses.is_dataclass(o): if dataclasses.is_dataclass(o):

File diff suppressed because it is too large Load Diff

View File

@ -46,8 +46,9 @@ def build_variant() -> str:
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
compute_framework = "metal" compute_framework = "metal"
elif hasattr(torch, "xpu") and torch.xpu.is_available(): elif torch.version.xpu is not None:
compute_framework = "xpu" version = torch.version.xpu
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
else: else:
raise AssertionError( raise AssertionError(
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled." "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
@ -98,7 +99,20 @@ def install_kernel(
""" """
Download a kernel for the current environment to the cache. Download a kernel for the current environment to the cache.
The output path is validated againt `hash` when set. The output path is validated against the hashes in `variant_locks` when provided.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`):
The specific revision (branch, tag, or commit) to download.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
variant_locks (`Dict[str, VariantLock]`, *optional*):
Optional dictionary of variant locks for validation.
Returns:
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
""" """
package_name = package_name_from_repo_id(repo_id) package_name = package_name_from_repo_id(repo_id)
variant = build_variant() variant = build_variant()
@ -190,23 +204,31 @@ def get_kernel(
) -> ModuleType: ) -> ModuleType:
""" """
Load a kernel from the kernel hub. Load a kernel from the kernel hub.
This function downloads a kernel to the local Hugging Face Hub cache
directory (if it was not downloaded before) and then loads the kernel. This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
and then loads the kernel.
Args: Args:
repo_id (`str`): The Hub repository containing the kernel. repo_id (`str`):
revision (`str`, *optional*, defaults to `"main"`): The specific The Hub repository containing the kernel.
revision (branch, tag, or commit) to download. revision (`str`, *optional*, defaults to `"main"`):
Cannot be used together with `version`. The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This version (`str`, *optional*):
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`. Cannot be used together with `revision`.
Returns: Returns:
`ModuleType`: The imported kernel module. `ModuleType`: The imported kernel module.
Example: Example:
```python ```python
import torch
from kernels import get_kernel from kernels import get_kernel
kernel = get_kernel("username/my-kernel")
result = kernel.kernel_function(input_data) activation = get_kernel("kernels-community/activation")
x = torch.randn(10, 20, device="cuda")
out = torch.empty_like(x)
result = activation.silu_and_mul(out, x)
``` ```
""" """
revision = select_revision_or_version(repo_id, revision, version) revision = select_revision_or_version(repo_id, revision, version)
@ -217,28 +239,53 @@ def get_kernel(
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
""" """
Import a kernel from a local kernel repository path. Import a kernel from a local kernel repository path.
Args:
repo_path (`Path`):
The local path to the kernel repository.
package_name (`str`):
The name of the package to import from the repository.
Returns:
`ModuleType`: The imported kernel module.
""" """
package_name, package_path = _load_kernel_from_path(repo_path, package_name) variant = build_variant()
return import_from_path(package_name, package_path / package_name / "__init__.py") universal_variant = universal_build_variant()
# Presume we were given the top level path of the kernel repository.
for base_path in [repo_path, repo_path / "build"]:
# Prefer the universal variant if it exists.
for v in [universal_variant, variant]:
package_path = base_path / v / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
# If we didn't find the package in the repo we may have a explicit
# package path.
package_path = repo_path / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
def has_kernel( def has_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> bool: ) -> bool:
""" """
Check whether a kernel build exists for the current environment Check whether a kernel build exists for the current environment (Torch version and compute framework).
(Torch version and compute framework).
Args: Args:
repo_id (`str`): The Hub repository containing the kernel. repo_id (`str`):
revision (`str`, *optional*, defaults to `"main"`): The specific The Hub repository containing the kernel.
revision (branch, tag, or commit) to download. revision (`str`, *optional*, defaults to `"main"`):
Cannot be used together with `version`. The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This version (`str`, *optional*):
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`. Cannot be used together with `revision`.
Returns: Returns:
`bool`: `true` if a kernel is avaialble for the current environment. `bool`: `True` if a kernel is available for the current environment.
""" """
revision = select_revision_or_version(repo_id, revision, version) revision = select_revision_or_version(repo_id, revision, version)
@ -264,8 +311,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
""" """
Get a pre-downloaded, locked kernel. Get a pre-downloaded, locked kernel.
If `lockfile` is not specified, the lockfile will be loaded from the If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
caller's package metadata.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
lockfile (`Path`, *optional*):
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
Returns:
`ModuleType`: The imported kernel module.
""" """
if lockfile is None: if lockfile is None:
locked_sha = _get_caller_locked_kernel(repo_id) locked_sha = _get_caller_locked_kernel(repo_id)
@ -310,7 +365,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""Get a kernel using a lock file.""" """
Get a kernel using a lock file.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
Returns:
`ModuleType`: The imported kernel module.
"""
locked_sha = _get_caller_locked_kernel(repo_id) locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None: if locked_sha is None:

View File

@ -1,10 +1,41 @@
import sys import sys
import pytest import pytest
import torch
has_cuda = (
hasattr(torch.version, "cuda")
and torch.version.cuda is not None
and torch.cuda.device_count() > 0
)
has_rocm = (
hasattr(torch.version, "hip")
and torch.version.hip is not None
and torch.cuda.device_count() > 0
)
has_xpu = (
hasattr(torch.version, "xpu")
and torch.version.xpu is not None
and torch.xpu.device_count() > 0
)
def pytest_addoption(parser):
parser.addoption(
"--token",
action="store_true",
help="run tests that require a token with write permissions",
)
def pytest_runtest_setup(item): def pytest_runtest_setup(item):
if "linux_only" in item.keywords and not sys.platform.startswith("linux"): if "cuda_only" in item.keywords and not has_cuda:
pytest.skip("skipping Linux-only test on non-Linux platform") pytest.skip("skipping CUDA-only test on host without CUDA")
if "rocm_only" in item.keywords and not has_rocm:
pytest.skip("skipping ROCm-only test on host without ROCm")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"): if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform") pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")
if "token" in item.keywords and not item.config.getoption("--token"):
pytest.skip("need --token option to run this test")

View File

@ -10,10 +10,16 @@ def kernel():
@pytest.fixture @pytest.fixture
def local_kernel(): def local_kernel_path():
package_name, path = install_kernel("kernels-community/activation", "main") package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent # Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path. # is the kernel repository path.
return package_name, path
@pytest.fixture
def local_kernel(local_kernel_path):
package_name, path = local_kernel_path
return get_local_kernel(path.parent.parent, package_name) return get_local_kernel(path.parent.parent, package_name)
@ -34,7 +40,7 @@ def device():
return "cuda" return "cuda"
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_gelu_fast(kernel, device): def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x) y = torch.empty_like(x)
@ -50,7 +56,7 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected) assert torch.allclose(y, expected)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_local_kernel(local_kernel, device): def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x) y = torch.empty_like(x)
@ -66,6 +72,39 @@ def test_local_kernel(local_kernel, device):
assert torch.allclose(y, expected) assert torch.allclose(y, expected)
@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
package_name, path = local_kernel_path
# Top-level repo path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
kernel = get_local_kernel(path.parent.parent, package_name)
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
# Build directory path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
kernel = get_local_kernel(path.parent.parent / "build", package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
# Explicit package path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
kernel = get_local_kernel(path, package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only @pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype): def test_relu_metal(metal_kernel, dtype):
@ -74,7 +113,7 @@ def test_relu_metal(metal_kernel, dtype):
assert torch.allclose(y, torch.relu(x)) assert torch.allclose(y, torch.relu(x))
@pytest.mark.linux_only @pytest.mark.cuda_only
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel_exists", "kernel_exists",
[ [
@ -110,7 +149,7 @@ def test_version():
) )
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_universal_kernel(universal_kernel): def test_universal_kernel(universal_kernel):
torch.manual_seed(0) torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@ -16,21 +16,21 @@ def device():
return "cuda" return "cuda"
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_gelu_small(kernel, device, benchmark): def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device) x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x) y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x) benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_gelu_medium(kernel, device, benchmark): def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device) x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x) y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x) benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_gelu_large(kernel, device, benchmark): def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device) x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x) y = torch.empty_like(x)

49
tests/test_doctest.py Normal file
View File

@ -0,0 +1,49 @@
import inspect
import pytest
from mktestdocs import check_docstring, get_codeblock_members
import kernels
def all_public_functions():
function_list = inspect.getmembers(kernels, inspect.isfunction)
return [func for _, func in function_list]
def all_public_classes():
class_list = inspect.getmembers(kernels, inspect.isclass)
return [cls for _, cls in class_list]
def all_public_class_members():
members = get_codeblock_members(*all_public_classes())
return members
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"func",
all_public_functions(),
ids=lambda d: d.__name__,
)
def test_func_docstring(func):
check_docstring(obj=func)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"cls",
all_public_classes(),
ids=lambda d: d.__name__,
)
def test_class_docstring(cls):
check_docstring(obj=cls)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"member", all_public_class_members(), ids=lambda d: d.__qualname__
)
def test_member_docstring(member):
check_docstring(member)

View File

@ -27,7 +27,7 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_load_locked(): def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking" project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly. # Also validates that hashing works correctly.

View File

@ -0,0 +1,88 @@
import logging
import os
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List
import pytest
from huggingface_hub import model_info
from kernels.cli import upload_kernels
REPO_ID = "kernels-test/kernels-upload-test"
PY_CONTENT = """\
#!/usr/bin/env python3
def main():
print("Hello from torch-universal!")
if __name__ == "__main__":
main()
"""
@dataclass
class UploadArgs:
kernel_dir: None
repo_id: None
private: False
def next_filename(path: Path) -> Path:
"""
Given a path like foo_2050.py, return foo_2051.py.
"""
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
if not m:
raise ValueError(
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
)
prefix, number, suffix = m.groups()
new_number = str(int(number) + 1).zfill(len(number))
return path.with_name(f"{prefix}{new_number}{suffix}")
def get_filename_to_change(repo_filenames):
for f in repo_filenames:
if "foo" in f and f.endswith(".py"):
filename_to_change = os.path.basename(f)
break
assert filename_to_change
return filename_to_change
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
try:
repo_info = model_info(repo_id=repo_id, files_metadata=True)
repo_siblings = repo_info.siblings
if repo_siblings is not None:
return [f.rfilename for f in repo_siblings]
else:
raise ValueError("No repo siblings found.")
except Exception as e:
logging.error(f"Error connecting to the Hub: {e}.")
@pytest.mark.token
def test_kernel_upload_deletes_as_expected():
repo_filenames = get_filenames_from_a_repo(REPO_ID)
filename_to_change = get_filename_to_change(repo_filenames)
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
changed_filename = next_filename(Path(filename_to_change))
script_path = build_dir / changed_filename
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
repo_filenames = get_filenames_from_a_repo(REPO_ID)
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
assert not any(
str(filename_to_change) in k for k in repo_filenames
), f"{repo_filenames=}"

View File

@ -5,21 +5,24 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from torch.testing import assert_close
from kernels import ( from kernels import (
CUDAProperties,
Device, Device,
LayerRepository, LayerRepository,
LocalLayerRepository,
Mode, Mode,
kernelize, kernelize,
register_kernel_mapping, register_kernel_mapping,
use_kernel_forward_from_hub, use_kernel_forward_from_hub,
use_kernel_mapping,
) )
from kernels.layer import ( from kernels.layer import (
_KERNEL_MAPPING, _KERNEL_MAPPING,
CUDAProperties,
_validate_layer, _validate_layer,
use_kernel_mapping,
) )
from kernels.utils import install_kernel
kernel_layer_mapping = { kernel_layer_mapping = {
"SiluAndMul": { "SiluAndMul": {
@ -32,7 +35,11 @@ kernel_layer_mapping = {
"cuda": LayerRepository( "cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test", repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul", layer_name="SiluAndMul",
) ),
"rocm": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
),
}, },
"SiluAndMulStringDevice": { "SiluAndMulStringDevice": {
"cuda": LayerRepository( "cuda": LayerRepository(
@ -40,11 +47,37 @@ kernel_layer_mapping = {
layer_name="SiluAndMul", layer_name="SiluAndMul",
) )
}, },
"LigerRMSNorm": {
"xpu": LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm", # Triton
)
},
} }
register_kernel_mapping(kernel_layer_mapping) register_kernel_mapping(kernel_layer_mapping)
class RMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
super().__init__()
# Used to check that we called hub kernel.
self.n_calls = 0
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, x: torch.Tensor):
self.n_calls += 1
var = x.pow(2).mean(-1, keepdim=True)
x_norm = x * torch.rsqrt(var + self.variance_epsilon)
return x_norm * self.weight
@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNormWithKernel(RMSNorm):
pass
class SiluAndMul(nn.Module): class SiluAndMul(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -84,6 +117,16 @@ class TorchLinearWithCounter(nn.Linear):
return super().forward(input) return super().forward(input)
@pytest.fixture
def device():
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
pytest.skip("No CUDA or XPU")
def test_arg_kinds(): def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind") @use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module): class ArgKind(nn.Module):
@ -102,29 +145,99 @@ def test_arg_kinds():
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5) assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.linux_only @pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_hub_forward(cls):
def test_hub_forward(cls, device):
torch.random.manual_seed(0) torch.random.manual_seed(0)
silu_and_mul = SiluAndMul() silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device=device) X = torch.randn((32, 64), device="cuda")
Y = silu_and_mul(X) Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE) silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X) Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y) torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1 assert silu_and_mul.n_calls == 1
if device == "cuda": assert silu_and_mul_with_kernel.n_calls == 0
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
@pytest.mark.linux_only @pytest.mark.rocm_only
def test_hub_forward_rocm():
torch.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64))
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(
SiluAndMulNoCompileKernel(), device="rocm", mode=Mode.INFERENCE
)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
# Should use kernel (n_calls == 0) if ROCm kernel is available, otherwise fallback (n_calls == 1)
# The exact behavior depends on whether the test kernel exists for ROCm
assert silu_and_mul_with_kernel.n_calls in [0, 1]
@pytest.mark.xpu_only
def test_hub_forward_xpu():
torch.manual_seed(0)
hidden_size = 1024
weight = torch.ones(hidden_size, device="xpu")
rms_norm = RMSNorm(weight).to("xpu")
X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
Y = rms_norm(X)
rms_norm_with_kernel = kernelize(
RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
)
Y_kernel = rms_norm_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert rms_norm.n_calls == 1
assert rms_norm_with_kernel.n_calls == 0
@pytest.mark.skipif(
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
reason="Skip on xpu devices",
)
def test_rocm_kernel_mapping():
"""Test that ROCm shorthand device mapping works correctly."""
kernel_layer_mapping = {
"SiluAndMul": {
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
# Test that the mapping is processed correctly
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
mapping = _KERNEL_MAPPING.get()
# Verify the mapping exists
assert "SiluAndMul" in mapping
assert "rocm" in mapping["SiluAndMul"]
# Verify the repository is correctly stored
rocm_repos = mapping["SiluAndMul"]["rocm"]
assert rocm_repos is not None
assert (
rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation"
)
assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul"
@pytest.mark.cuda_only
def test_capability(): def test_capability():
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping( with use_kernel_mapping(
@ -183,7 +296,74 @@ def test_layer_fallback_works():
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE) kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
@pytest.mark.linux_only def test_local_layer_repo(device):
# Fetch a kernel to the local cache.
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
linear = TorchLinearWithCounter(32, 32).to(device)
with use_kernel_mapping(
{
"Linear": {
device: LocalLayerRepository(
# install_kernel will give the fully-resolved path.
repo_path=path.parent.parent,
package_name=package_name,
layer_name="LinearBackward",
)
}
},
inherit_mapping=False,
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device=device)
linear(X)
assert linear.n_calls == 0
def test_stateful_layer(device):
@use_kernel_forward_from_hub("ReluWithHiddenSize")
class ReluWithHiddenSize(nn.Module):
hidden_size: int
def __init__(self, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x)
model = ReluWithHiddenSize(hidden_size=64).to(device)
x = torch.randn((32, 64), device=device)
y_ref = model(x)
with use_kernel_mapping(
{
"ReluWithHiddenSize": {
"cuda": LayerRepository(
repo_id="kernels-test/state-test",
layer_name="StatefulReLU",
),
"xpu": LayerRepository(
repo_id="kernels-test/state-test",
layer_name="StatefulReLU",
),
}
},
inherit_mapping=False,
):
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
y = model(x)
assert_close(y, y_ref)
model = torch.compile(model, fullgraph=True)
y = model(x)
assert_close(y, y_ref)
@pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel]) @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device): def test_torch_compile_layer_without_fallback(cls, device):
@ -214,7 +394,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y) torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only @pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel]) @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device): def test_torch_compile_layer_with_fallback(cls, device):
@ -237,12 +417,16 @@ def test_torch_compile_layer_with_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y) torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_mapping_contexts(): def test_mapping_contexts():
# Make sure we start from scratch.
register_kernel_mapping(kernel_layer_mapping, inherit_mapping=False)
assert set(_KERNEL_MAPPING.get().keys()) == { assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
} }
extra_mapping1 = { extra_mapping1 = {
@ -260,6 +444,7 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel", "TestKernel",
} }
@ -278,10 +463,13 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel", "TestKernel",
} }
assert ( assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
.repos[Mode.FALLBACK]
._repo_id
== "kernels-community/non-existing" == "kernels-community/non-existing"
) )
@ -289,10 +477,11 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel", "TestKernel",
} }
assert ( assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
== "kernels-community/activation" == "kernels-community/activation"
) )
@ -301,7 +490,9 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
} }
assert ( assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
.repos[Mode.FALLBACK]
._repo_id
== "kernels-community/non-existing" == "kernels-community/non-existing"
) )
@ -309,10 +500,11 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel", "TestKernel",
} }
assert ( assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
== "kernels-community/activation" == "kernels-community/activation"
) )
@ -320,6 +512,7 @@ def test_mapping_contexts():
"SiluAndMul", "SiluAndMul",
"SiluAndMulStringDevice", "SiluAndMulStringDevice",
"SiluAndMulNoCompile", "SiluAndMulNoCompile",
"LigerRMSNorm",
} }
@ -329,29 +522,46 @@ def test_validate_kernel_layer():
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.foo = 42 self.foo = 42
with pytest.raises(TypeError, match="not override"): def stub_repo(layer):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul) return LayerRepository(
repo_id="kernels-test/nonexisting", layer_name=layer.__name__
)
with pytest.raises(
TypeError,
match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))
class BadLayer2(nn.Module): class BadLayer2(nn.Module):
foo: int = 42 foo: int = 42
with pytest.raises(TypeError, match="not contain additional members"): with pytest.raises(
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul) TypeError,
match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
):
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))
class BadLayer3(nn.Module): class BadLayer3(nn.Module):
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ... def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different number of arguments"): with pytest.raises(
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul) TypeError,
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
):
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))
class BadLayer4(nn.Module): class BadLayer4(nn.Module):
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ... def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different kind of arguments"): with pytest.raises(
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul) TypeError,
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_invalid_mode_for_mapping_rejected(): def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -371,7 +581,7 @@ def test_invalid_mode_for_mapping_rejected():
kernelize(linear, mode=Mode.TRAINING) kernelize(linear, mode=Mode.TRAINING)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_kernel_modes(): def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -400,11 +610,6 @@ def test_kernel_modes():
linear(X) linear(X)
assert linear.n_calls == 0 assert linear.n_calls == 0
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel # Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer. # layer is registered, we fall back to the original layer.
with use_kernel_mapping( with use_kernel_mapping(
@ -434,12 +639,6 @@ def test_kernel_modes():
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original. # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
assert linear.n_calls == 1 assert linear.n_calls == 1
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback. # Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping( with use_kernel_mapping(
{ {
@ -461,23 +660,17 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda") X = torch.randn(10, 32, device="cuda")
linear(X) linear(X)
# Falls back to TRAINING. # Falls back to TRAINING.
assert linear.n_calls == 2 assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING) kernelize(linear, mode=Mode.TRAINING)
linear(X) linear(X)
# Falls back to the TRAINING kernel. # Falls back to the TRAINING kernel.
assert linear.n_calls == 2 assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X) linear(X)
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel. # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
assert linear.n_calls == 2 assert linear.n_calls == 1
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences. # Case 4: register a kernel with two preferences.
with use_kernel_mapping( with use_kernel_mapping(
@ -497,25 +690,20 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda") X = torch.randn(10, 32, device="cuda")
linear(X) linear(X)
# Falls back to the TRAINING | TORCH_COMPILE kernel. # Falls back to the TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2 assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING) kernelize(linear, mode=Mode.TRAINING)
linear(X) linear(X)
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel. # TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2 assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X) linear(X)
# Uses TRAINING | TORCH_COMPILE kernel. # Uses TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2 assert linear.n_calls == 1
kernelize(linear)
linear(X)
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
assert linear.n_calls == 2
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_fallback_used_when_training(): def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -580,7 +768,7 @@ def test_invalid_mode_rejected():
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE) kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_kernel_modes_inference(): def test_kernel_modes_inference():
"""Test inference-specific fallback scenarios.""" """Test inference-specific fallback scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -677,7 +865,7 @@ def test_kernel_modes_inference():
assert linear.n_calls == 4 assert linear.n_calls == 4
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_kernel_modes_mixed(): def test_kernel_modes_mixed():
"""Test mixed training and inference kernel scenarios.""" """Test mixed training and inference kernel scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -767,7 +955,7 @@ def test_kernel_modes_mixed():
assert linear.n_calls == 2 assert linear.n_calls == 2
@pytest.mark.linux_only @pytest.mark.cuda_only
def test_kernel_modes_cross_fallback(): def test_kernel_modes_cross_fallback():
"""Test cross-mode fallback scenarios from inference to training modes.""" """Test cross-mode fallback scenarios from inference to training modes."""
linear = TorchLinearWithCounter(32, 32).to("cuda") linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -861,7 +1049,7 @@ def test_kernel_modes_cross_fallback():
assert linear.n_calls == 2 assert linear.n_calls == 2
def test_layer_versions(): def test_layer_versions(device):
@use_kernel_forward_from_hub("Version") @use_kernel_forward_from_hub("Version")
class Version(nn.Module): class Version(nn.Module):
def forward(self) -> str: def forward(self) -> str:
@ -872,20 +1060,20 @@ def test_layer_versions():
with use_kernel_mapping( with use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
) )
} }
} }
): ):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE) version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0" assert version() == "0.2.0"
with use_kernel_mapping( with use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
version="<1.0.0", version="<1.0.0",
@ -893,13 +1081,13 @@ def test_layer_versions():
} }
} }
): ):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE) version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0" assert version() == "0.2.0"
with use_kernel_mapping( with use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
version="<0.2.0", version="<0.2.0",
@ -907,13 +1095,13 @@ def test_layer_versions():
} }
} }
): ):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE) version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1" assert version() == "0.1.1"
with use_kernel_mapping( with use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
version=">0.1.0,<0.2.0", version=">0.1.0,<0.2.0",
@ -921,13 +1109,13 @@ def test_layer_versions():
} }
} }
): ):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE) version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1" assert version() == "0.1.1"
with use_kernel_mapping( with use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
version=">0.2.0", version=">0.2.0",
@ -936,13 +1124,13 @@ def test_layer_versions():
} }
): ):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE) kernelize(version, device=device, mode=Mode.INFERENCE)
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping( use_kernel_mapping(
{ {
"Version": { "Version": {
Device(type="cuda"): LayerRepository( Device(type=device): LayerRepository(
repo_id="kernels-test/versions", repo_id="kernels-test/versions",
layer_name="Version", layer_name="Version",
revision="v0.1.0", revision="v0.1.0",