Compare commits


33 Commits

Author SHA1 Message Date
0429131630 Set version to 0.8.1 2025-07-23 14:43:31 +02:00
967ac581b8 Set version to 0.8.1.dev0 (#115) 2025-07-23 14:42:24 +02:00
81088d44e8 Add support for project-wide locking of layers (#114)
This change adds `LockedLayerRepository` as an alternative to
`LayerRepository`. `LockedLayerRepository` allows for locking all kernel
layers that are used at the project level. Example usage:

```python
with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",
                layer_name="SomeLayer",
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```

This requires that the project has a `pyproject.toml` with kernel
version specifications and `kernel.lock` with the locked kernels.
2025-07-23 09:37:05 +02:00
4a04c005e3 Add version support to LayerRepository (#113)
* Add version support to `LayerRepository`

* Remove some docs that do not apply

* Removed unused member variable
2025-07-22 17:02:39 +02:00
6d3c6daf20 Triton-based kernels can also run on XPU (#112)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-07-22 10:03:34 +02:00
071900fd69 get_kernel: allow Python-style version specifiers (#111)
Use Python-style version specifiers to resolve to tags. E.g., given
the presence of the tags `v0.1.0`, `v0.1.1`, and `v0.2.0`,

get_kernel("my/kernel", version=">=0.1.0,<0.2.0")

would resolve to `v0.1.1`.
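
For illustration, a minimal sketch of the new argument (repository and tags as in the example above):

```python
from kernels import get_kernel

# Resolves to the newest matching tag, v0.1.1 in the example above.
kernel = get_kernel("my/kernel", version=">=0.1.0,<0.2.0")
```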
2025-07-21 17:18:35 +02:00
2d2c6b14e0 Set version to 0.8.0.dev0 (#110) 2025-07-15 18:45:03 +02:00
03edc573b1 Log kernel layer selection (#109) 2025-07-15 18:38:17 +02:00
c841a6c90d Improve mode handling (#108)
* Set `kernelize` default mode to `Mode.TRAINING | Mode.TORCH_COMPILE`

Also update docs and tests.

* Rename `Mode.DEFAULT` to `Mode.FALLBACK`

* More fine-grained fallbacks

For instance, INFERENCE can fall back to INFERENCE | TORCH_COMPILE,
TRAINING, TRAINING | TORCH_COMPILE, and FALLBACK.

* Update documentation for mode fallback

* Mention that you can rerun `kernelize` to change the mode
2025-07-15 16:10:43 +02:00
c7a343f195 Support registering layers with a range of CUDA capabilities (#106)
* Add interval tree implementation

* Support registering layers with a range of CUDA capabilities

This change adds support for registering layers for ranges
of CUDA capabilities. This makes it possible to use newer, faster
kernels for new GPUs, while falling back to another implementation
on older GPUs.

* Add docs for registering kernels with CUDA capabilities

* Fix typing errors
2025-07-14 16:59:21 +02:00
8d838f947d Fix macOS tests by marking some CUDA-only tests (#105) 2025-07-10 12:24:25 +02:00
b87e6fadbe Set version to 0.7.0.dev0 (#104) 2025-07-07 14:56:43 +02:00
fc935d9874 Support registering inference/training-specific layers (#103)
* Support registering inference/training-specific layers

This change makes it possible to register kernels specialized for
inference, training, and/or `torch.compile`. To do so, the mapping
notation is extended to support registering specialized kernels
for a specific 'mode'. For instance, the following mapping,

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.DEFAULT: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```

uses `kernels-community/activation` by default, but will switch to
using `kernels-community/activation-training-optimized` if a model
is kernelized for training and `torch.compile`.

To make it easier to add more modes in the future and to unify the
`register_kernel_mapping` and `kernelize` signatures, the `training`
and `needs_torch_compile` arguments of `kernelize` are replaced by
a single `mode` argument:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

* Documentation fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Add note on when the fallback is used

* Tighten up some Mode checks

* Fix ruff check

* Attempt to fix mypy errors

* More typing fixes

* Ignore Python < 3.11 type check SNAFU

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-07-04 19:57:14 +02:00
3622e1f8dd Add get_local_kernel function (#102)
This function loads a kernel from a local repository (e.g. the output
of kernel-builder), which can be handy for testing.
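
A minimal usage sketch (the path is hypothetical; the second argument is the kernel's package name):

```python
from pathlib import Path

from kernels import get_local_kernel

# Load a kernel from a local checkout, e.g. a kernel-builder output directory.
activation = get_local_kernel(Path("build/activation"), "activation")
```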
2025-07-01 13:58:47 +02:00
a7f3b2e8ed Set version to 0.6.2.dev0 (#100) 2025-06-25 09:48:09 +02:00
a6ab5d83ba Make the flake work on Darwin (#98) 2025-06-24 20:35:21 +02:00
4f9f1abfb9 darwin: fix variant CPU for aarch64 (#97) 2025-06-24 20:35:07 +02:00
f94b7780a6 CI: main triton-layer-norm has docs, branch is gone (#99) 2025-06-24 16:40:36 +02:00
bd28883775 Set version to 0.6.1.dev1 (#96) 2025-06-20 11:43:26 +02:00
498429e322 Add README generation for layers (#94) 2025-06-20 10:16:50 +02:00
09c991af4b Add macOS requirements (#95) 2025-06-16 17:20:47 +02:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
848c6db87b Add support for Metal builds (#89)
* Add support for Metal builds

* Add Metal test, gate tests by OS where necessary
2025-05-30 15:54:28 +02:00
fabb8c52d1 Add generate-readme subcommand for generating a README (#88)
* Add `generate-readme` subcommand for generating a README

This README includes all the top-level functions with docs (if
docstrings are available).

* CI: attempt README generation

* Add PyYAML dependencies

* Typing fixes
2025-05-21 15:43:53 +02:00
d66260dd83 kernels: add the to-wheel subcommand (#84)
* kernels: add the `to-wheel` subcommand

This subcommand accepts a kernel repo and version as arguments:

    kernels to-wheel kernels-community/activation 0.0.3

Wheels will then be generated for every build variant.

* CI: check kernel -> wheel conversion

* No typing for wheel.wheelfile
2025-05-08 17:30:06 +02:00
daac8078fc CI: fix some stubs (#83) 2025-05-07 14:43:57 +02:00
fcb9a80ce6 Set version to 0.5.0 (#82) 2025-05-06 11:45:26 +02:00
c25bb32e6e Add publishing workflow (#81) 2025-05-06 09:29:08 +00:00
2036892762 Allow layers to opt in to torch.compile (#79)
* Allow layers to opt in to `torch.compile`

This change allows a layer to set the `can_torch_compile` class
variable to indicate that the layer is compatible with `torch.compile`.
When enabled, the layer does not fall back to the original
implementation when `torch.compile` is used.

* Comment fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-06 09:36:33 +02:00
30 changed files with 3702 additions and 220 deletions

120
.github/workflows/publish.yml vendored Normal file

@ -0,0 +1,120 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI

on: push

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/kernels
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

  github-release:
    name: >-
      Sign the Python 🐍 distribution 📦 with Sigstore
      and upload them to GitHub Release
    needs:
      - publish-to-pypi
    runs-on: ubuntu-latest
    permissions:
      contents: write # IMPORTANT: mandatory for making GitHub Releases
      id-token: write # IMPORTANT: mandatory for sigstore
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Sign the dists with Sigstore
        uses: sigstore/gh-action-sigstore-python@v3.0.0
        with:
          inputs: >-
            ./dist/*.tar.gz
            ./dist/*.whl
      - name: Create GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: >-
          gh release create
          "$GITHUB_REF_NAME"
          --repo "$GITHUB_REPOSITORY"
          --notes ""
      - name: Upload artifact signatures to GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        # Upload to GitHub Release using the `gh` CLI.
        # `dist/` contains the built packages, and the
        # sigstore-produced signatures and certificates.
        run: >-
          gh release upload
          "$GITHUB_REF_NAME" dist/**
          --repo "$GITHUB_REPOSITORY"

  publish-to-testpypi:
    name: Publish Python 🐍 distribution 📦 to TestPyPI
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: testpypi
      url: https://test.pypi.org/p/kernels
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/
          skip-existing: true # Only upload when the version is unique.


@ -24,7 +24,7 @@ jobs:
      max-parallel: 4
      matrix:
        python-version: ["3.10", "3.12"]
        torch-version: ["2.5.1", "2.6.0"]
        torch-version: ["2.6.0", "2.7.0"]

    env:
      UV_PYTHON_PREFERENCE: only-managed
@ -53,6 +53,18 @@ jobs:
      - name: Run tests
        run: uv run pytest tests

      - name: Check kernel conversion
        run: |
          uv pip install wheel
          uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
          uv pip install triton_layer_norm-0.0.1*.whl
          uv run python -c "import triton_layer_norm"

      - name: Check README generation
        # For now, just checks that generation doesn't fail.
        run: |
          uv run kernels generate-readme kernels-community/triton-layer-norm

      - name: Import check without torch
        run: |
          uv pip uninstall torch


@ -57,8 +57,9 @@ the Hub.
## 📚 Documentation
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Locking kernel/layer versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

13
docs/faq.md Normal file

@ -0,0 +1,13 @@
# FAQ
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
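
A minimal sketch of the resulting two-step flow (assuming a CUDA device and a matching kernel mapping):

```python
import torch
from torch import nn
import torch.nn.functional as F

from kernels import Mode, kernelize, use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]

model = nn.Sequential(SiluAndMul())
# All dispatch decisions happen here, once...
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
# ...so the forward pass contains no data-dependent branching afterwards.
output = model(torch.randn(1, 64, device="cuda"))
```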


@ -37,8 +37,14 @@ to resolve the version constraints.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
@ -50,12 +56,18 @@ requirements:
- CXXABI 1.3.11
- GCC 7.0.0
These requirements can be checked with the ABI checker (see below).
- No dynamic library dependencies outside:
  - Torch;
  - CUDA/ROCm libraries installed as dependencies of Torch.

### macOS

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).

The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
@ -109,9 +121,12 @@ requirements:
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
The only exception to the _no class variables rule_ is addition of a
`has_backward` class variable. This variable is used to indicate whether
the layer has a backward pass implemented (`True` when absent).
There are two exceptions to the _no class variables rule_:
1. The `has_backward` variable can be used to indicate whether the layer has
a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
supports `torch.compile` (`False` when absent).
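
For instance, a layer might declare both (the values here are illustrative):

```python
import torch
import torch.nn.functional as F

class SiluAndMul(torch.nn.Module):
    has_backward: bool = False      # No backward pass implemented.
    can_torch_compile: bool = True  # Safe to use under torch.compile.

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```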
This is an example of a pure layer:


@ -23,33 +23,111 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `kernelize` function modifies the model in-place; the model itself is
returned as a convenience. The `mode` specifies that the model will be used
in inference. Similarly, you can ask `kernelize` to prepare the model for
training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
A model that is kernelized for training can also be used for inference, but
not the other way around. If you want to change the mode of the kernelized
model, you can just run `kernelize` on the model again with the new mode.
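
For example:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
# Later: re-kernelize the same model for inference.
model = kernelize(model, mode=Mode.INFERENCE)
```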
If you want to compile a model with `torch.compile`, this should be indicated
in the mode as well. You can do this by combining `Mode.INFERENCE` or
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
```python
model = MyModel(...)
# Inference
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
# Training
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
When the `mode` argument is not specified,
`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
aligns most closely with pure PyTorch layers which also support training
and `torch.compile`. However, to select the most performant kernels, it
is often good to make the mode as specific as possible.
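
For example:

```python
model = MyModel(...)
# Default: Mode.TRAINING | Mode.TORCH_COMPILE.
model = kernelize(model)
# More specific, which may select faster inference-only kernels:
model = kernelize(model, mode=Mode.INFERENCE)
```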
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
### Inspecting which kernels are used
The kernels that are used are logged at the `INFO` level by `kernelize`.
See the [Python logging](https://docs.python.org/3/library/logging.html)
documentation for information on how to configure logging.
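
For example, a minimal way to surface these logs:

```python
import logging

logging.basicConfig(level=logging.INFO)

model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)  # Kernel selections are logged at INFO.
```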
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
@ -57,11 +135,14 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
@ -72,8 +153,120 @@ used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
    # Use the layer for which the mapping is applied.
    ...
    model = kernelize(model)
```
This ensures that the mapping is not active anymore outside the
`with`-scope.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```
The `kernelize` function will attempt to use the following registered
kernels for a given mode:
- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
is also used when a kernel is registered without a mode, as described in the
previous section.
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.FALLBACK: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
since the other kernels do not support `torch.compile`.
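
For example, with the mapping above:

```python
model = MyModel(...)
# Both modes resolve to the Mode.FALLBACK kernel registered above:
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```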
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
you need to register the layer for a `Device` with type `cuda` and
set the supported range of CUDA capabilities using `CUDAProperties`:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=75, max_capability=89
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=90, max_capability=sys.maxsize
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation-hopper",
            layer_name="SiluAndMul",
        ),
    }
}
```
Capabilities behave as follows:
- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
  with the smaller capability interval will be used. E.g. given:
  - `KernelA` with `min_capability=80` and `max_capability=89`;
  - `KernelB` with `min_capability=75` and `max_capability=89`;
  - `kernelize` runs on a system with capability 8.6.

  Then `KernelA` will be used because the interval 80..89 is smaller
  than 75..89. The motivation is that kernels with smaller ranges
  tend to be more optimized for a specific set of GPUs. **This behavior
  might still change in the future.**


@ -1,4 +1,4 @@
# Locking kernel versions
# Locking kernel/layer versions
Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")
**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.
## Locked kernel layers
Locking is also supported for kernel layers. To use locked layers, register them
with the `LockedLayerRepository` class:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LockedLayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

register_kernel_mapping(kernel_layer_mapping)
```
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernels download .` in your

55
flake.lock generated

@ -51,18 +51,38 @@
"type": "github"
}
},
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
@ -70,11 +90,11 @@
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"tgi-nix",
"hf-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
]
}
},
"systems": {
@ -106,27 +126,6 @@
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",


@ -1,7 +1,7 @@
{
  inputs = {
    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    hf-nix.url = "github:huggingface/hf-nix";
    nixpkgs.follows = "hf-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs =
@ -9,21 +9,21 @@
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
      hf-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
          inherit (tgi-nix.lib) config;
          config = hf-nix.lib.config system;
          overlays = [
            tgi-nix.overlays.default
            hf-nix.overlays.default
          ];
        };
      in
      {
        formatter = pkgs.nixfmt-rfc-style;
        formatter = pkgs.nixfmt-tree;
        devShells = with pkgs; rec {
          default = mkShell {
            buildInputs =
@ -34,10 +34,13 @@
                ruff
              ]
              ++ (with python3.pkgs; [
                docutils
                huggingface-hub
                pytest
                pytest-benchmark
                pyyaml
                torch
                types-pyyaml
                venvShellHook
              ]);


@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.4.4"
version = "0.8.1"
description = "Download compute kernels"
authors = [
    { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -14,6 +14,7 @@ requires-python = ">= 3.9"
dependencies = [
    "huggingface_hub>=0.26.0,<1.0",
    "packaging>=20.0",
    "pyyaml>=6",
    "tomli>=2.0; python_version<'3.11'",
]
@ -23,11 +24,12 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
    "mypy == 1.14.1",
    "mypy >= 1.15.0",
    "pytest >=8",
    # Whatever version is compatible with pytest.
    "pytest-benchmark",
    "torch >=2.5",
    "types-pyyaml"
]
[project.optional-dependencies]

4
pytest.ini Normal file

@ -0,0 +1,4 @@
[pytest]
markers =
    darwin_only: marks tests that should only run on macOS
    linux_only: marks tests that should only run on Linux
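
Tests can then be gated on the platform with the corresponding marker, e.g.:

```python
import pytest

@pytest.mark.linux_only
def test_linux_specific_kernel():
    ...
```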


@ -1,6 +1,9 @@
from kernels.layer import (
    CUDAProperties,
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    replace_kernel_forward_from_hub,
    use_kernel_forward_from_hub,
@ -8,6 +11,7 @@ from kernels.layer import (
)
from kernels.utils import (
    get_kernel,
    get_local_kernel,
    get_locked_kernel,
    has_kernel,
    install_kernel,
@ -15,15 +19,19 @@ from kernels.utils import (
)

__all__ = [
    "CUDAProperties",
    "Device",
    "LayerRepository",
    "Mode",
    "get_kernel",
    "get_local_kernel",
    "get_locked_kernel",
    "has_kernel",
    "load_kernel",
    "install_kernel",
    "use_kernel_forward_from_hub",
    "use_kernel_mapping",
    "kernelize",
    "load_kernel",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "LayerRepository",
    "Device",
    "use_kernel_forward_from_hub",
    "use_kernel_mapping",
]


@ -0,0 +1,200 @@
# AVL-balanced interval trees. We could use the intervaltree
# package, but it seems unmaintained and does not have type
# annotations.

from typing import Generic, List, Optional, Tuple, TypeVar

T = TypeVar("T")


class _Node(Generic[T]):
    """A node in the interval tree."""

    def __init__(self, start: int, end: int, data: T):
        self.start: int = start
        self.end: int = end
        self.data: T = data
        self.max_end: int = end
        self.left: Optional["_Node[T]"] = None
        self.right: Optional["_Node[T]"] = None
        self.height: int = 1

    def __repr__(self) -> str:
        return f"Node({self.start}, {self.end})"


class IntervalTree(Generic[T]):
    """A data structure to hold and query (unique) intervals."""

    root: Optional[_Node[T]]

    def __init__(self):
        self.root = None

    def insert(self, start: int, end: int, data: T) -> None:
        """
        Inserts a new interval into the tree.

        Args:
            start: The starting point of the interval.
            end: The ending point of the interval.
            data: The data associated with this interval.
        """
        self.root = self._insert(self.root, start, end, data)

    def _get_height(self, node: Optional[_Node[T]]) -> int:
        if not node:
            return 0
        return node.height

    def _get_balance(self, node: Optional[_Node[T]]) -> int:
        if not node:
            return 0
        return self._get_height(node.left) - self._get_height(node.right)

    def _update_node_attributes(self, node: _Node[T]) -> None:
        node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
        node.max_end = node.end
        if node.left:
            node.max_end = max(node.max_end, node.left.max_end)
        if node.right:
            node.max_end = max(node.max_end, node.right.max_end)

    def _right_rotate(self, y: _Node[T]) -> _Node[T]:
        """Performs a right rotation."""
        x = y.left
        assert x is not None
        T2 = x.right
        x.right = y
        y.left = T2
        self._update_node_attributes(y)
        self._update_node_attributes(x)
        return x

    def _left_rotate(self, x: _Node[T]) -> _Node[T]:
        """Performs a left rotation."""
        y = x.right
        assert y is not None
        T2 = y.left
        y.left = x
        x.right = T2
        self._update_node_attributes(x)
        self._update_node_attributes(y)
        return y

    def _insert(
        self, node: Optional[_Node[T]], start: int, end: int, data: T
    ) -> _Node[T]:
        """Recursive helper to insert a new node and balance the tree."""
        if not node:
            return _Node(start, end, data)

        # Replace the data if the interval already exists.
        if start == node.start and end == node.end:
            node.data = data
            return node

        if start < node.start:
            node.left = self._insert(node.left, start, end, data)
        else:
            node.right = self._insert(node.right, start, end, data)

        self._update_node_attributes(node)

        balance = self._get_balance(node)

        # Left Left Case
        if balance > 1 and node.left and start < node.left.start:
            return self._right_rotate(node)

        # Right Right Case
        if balance < -1 and node.right and start >= node.right.start:
            return self._left_rotate(node)

        # Left Right Case
        if balance > 1 and node.left and start >= node.left.start:
            node.left = self._left_rotate(node.left)
            return self._right_rotate(node)

        # Right Left Case
        if balance < -1 and node.right and start < node.right.start:
            node.right = self._right_rotate(node.right)
            return self._left_rotate(node)

        return node

    def search(self, point: int) -> List[T]:
        """
        Searches for all intervals that contain the given point.

        Args:
            point: The point to search for.

        Returns:
            A list of data items from all matching intervals.
        """
        results: List[T] = []
        self._search(self.root, point, results)
        return results

    def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
        """Recursive helper to find all overlapping intervals."""
        if node is None or point > node.max_end:
            return
        if node.left:
            self._search(node.left, point, results)
        if node.start <= point <= node.end:
            results.append(node.data)
        if point >= node.start and node.right:
            self._search(node.right, point, results)

    def find_smallest_interval(self, point: int) -> Optional[T]:
        """
        Finds the item with the most specific (smallest) range for a given point.

        Args:
            point: The capability to look up.

        Returns:
            The data of the best-matching item, or None if no match is found.
        """
        matches: List[Tuple[int, int, T]] = []
        self._find_with_intervals(self.root, point, matches)

        if not matches:
            return None

        # Return the smallest interval, sort by memory location when
        # there are multiple matches with the same interval size. This
        # is just to ensure that we can compare against a trivial
        # implementation in tests.
        best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
        return best_match[2]

    def _find_with_intervals(
        self,
        node: Optional[_Node[T]],
        point: int,
        results: List[Tuple[int, int, T]],
    ) -> None:
        """A modified search that collects interval ranges along with data."""
        if node is None or point > node.max_end:
            return
        if node.left:
            self._find_with_intervals(node.left, point, results)
        if node.start <= point <= node.end:
            results.append((node.start, node.end, node.data))
        if point >= node.start and node.right:
            self._find_with_intervals(node.right, point, results)
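
A small usage sketch of the tree above, mirroring the capability-selection rule from the layer docs (values are illustrative):

```python
tree: IntervalTree[str] = IntervalTree()
tree.insert(75, 89, "kernel-b")  # Capabilities 7.5 up to and including 8.9.
tree.insert(80, 89, "kernel-a")  # Narrower interval.

# Both intervals contain capability 8.6...
assert sorted(tree.search(86)) == ["kernel-a", "kernel-b"]
# ...but the narrower interval wins.
assert tree.find_smallest_interval(86) == "kernel-a"
```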


@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
import re
# Re pattern to catch things inside ` ` in :obj:`thing`.
_re_obj = re.compile(r":obj:`([^`]+)`")
# Re pattern to catch things inside ` ` in :math:`thing`.
_re_math = re.compile(r":math:`([^`]+)`")
# Re pattern to catch things between single backquotes.
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
def convert_rst_formatting(text):
"""
Convert rst syntax for formatting to markdown in a given text.
"""
# Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
# (to not be caught by the italic conversion).
text = _re_func_class.sub(r"[``\1``]", text)
# Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
# (to not be caught by the italic conversion).
text = _re_obj.sub(r"``\1``", text)
# Remove :math: markers.
text = _re_math.sub(r"\\\\(\1\\\\)", text)
# Convert content in single backquotes to italic.
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
# Convert content in double backquotes to single backquotes.
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
# Remove remaining ::
text = re.sub(r"::\n", "", text)
# Remove new lines inside blocks in backsticks as they will be kept.
lines = text.split("\n")
in_code = False
text = None
for line in lines:
if in_code:
splits = line.split("`")
in_code = len(splits) > 1 and len(splits) % 2 == 1
if len(splits) == 1:
# Some forgotten lone backstick
text += "\n" + line
else:
text += " " + line.lstrip()
else:
if text is not None:
text += "\n" + line
else:
text = line
splits = line.split("`")
in_code = len(splits) % 2 == 0
return text
# Re pattern to catch description and url in links of the form `description <url>`_.
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :doc:`reference`.
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :ref:`reference`.
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
def convert_rst_links(text, page_info):
"""
Convert the rst links in text to markdown.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
no_prefix = page_info.get("no_prefix", False)
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
# Links of the form :doc:`page`
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
# Links of the form :doc:`text <page>`
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
if "page" in page_info and not no_prefix:
page = str(page_info["page"])
if page.endswith(".html"):
page = page[:-5]
prefix = f"{prefix}{page}"
else:
prefix = ""
# Refs of the form :ref:`page`
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
# Refs of the form :ref:`text <page>`
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
# Links with a prefix
# TODO: when it exists, use the API to deal with prefix links properly.
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
# Other links
text = _re_links.sub(r"[\1](\2)", text)
# Relative links or Transformers links need to remove the .html
if (
"(https://https://huggingface.co/" in text
or re.search(r"\(\.+/", text) is not None
):
text = text.replace(".html", "")
return text
# Re pattern that catches examples blocks of the form `Example::`.
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
# Re pattern that catches rst blocks of the form `.. block_name::`.
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
def is_empty_line(line):
return len(line) == 0 or line.isspace()
def find_indent(line):
"""
Returns the number of spaces that start a line indent.
"""
search = re.search(r"^(\s*)(?:\S|$)", line)
if search is None:
return 0
return len(search.groups()[0])
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
def convert_special_chars(text):
"""
Converts { and < that have special meanings in MDX.
"""
text = text.replace("{", "&amp;lcub;")
# We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
text = re.sub(
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
) # html void elements with no closing counterpart
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
while _re_lt_html.search(text):
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&amp;lt;\2", text)
text = text.replace("LTHTML", "<")
return text
def parse_options(block_content):
"""
Parses the option in some rst block content.
"""
block_lines = block_content.split("\n")
block_indent = find_indent(block_lines[0])
current_option = None
result = {}
for line in block_lines:
if _re_rst_option.search(line) is not None:
current_option, value = _re_rst_option.search(line).groups()
result[current_option] = value.lstrip()
elif find_indent(line) > block_indent:
result[current_option] += " " + line.lstrip()
return result
def apply_min_indent(text, min_indent):
"""
Make sure all lines in a text are have a minimum indentation.
Args:
text (`str`): The text to treat.
min_indent (`int`): The minimal indentation.
Returns:
`str`: The processed text.
"""
lines = text.split("\n")
idx = 0
while idx < len(lines):
if is_empty_line(lines[idx]):
idx += 1
continue
indent = find_indent(lines[idx])
if indent < min_indent:
while idx < len(lines) and (
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
):
if not is_empty_line(lines[idx]):
lines[idx] = " " * (min_indent - indent) + lines[idx]
idx += 1
else:
idx += 1
return "\n".join(lines)
def convert_rst_blocks(text, page_info):
"""
Converts rst special blocks (examples, notes) into MDX.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
lines = text.split("\n")
idx = 0
new_lines = []
while idx < len(lines):
block_type = None
block_info = None
if _re_block.search(lines[idx]) is not None:
block_type = _re_block.search(lines[idx]).groups()[0]
if _re_block_info.search(lines[idx]) is not None:
block_info = _re_block_info.search(lines[idx]).groups()[0]
elif _re_example.search(lines[idx]) is not None:
block_type = "code-block-example"
block_info = "python"
example_name = _re_example.search(lines[idx]).groups()[0]
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
elif lines[idx].strip() == "..":
block_type = "comment"
elif lines[idx].strip() == "::":
block_type = "code-block"
if block_type is not None:
block_indent = find_indent(lines[idx])
# Find the next nonempty line
idx += 1
while idx < len(lines) and is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it (or has already)
example_indent = (
find_indent(lines[idx]) if idx < len(lines) else block_indent
)
if example_indent == block_indent:
block_content = ""
else:
block_lines = []
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= example_indent
):
block_lines.append(lines[idx][example_indent:])
idx += 1
block_content = "\n".join(block_lines)
if block_type in ["code", "code-block"]:
prefix = "```" if block_info is None else f"```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
elif block_type == "code-block-example":
prefix = f"<example>```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
elif block_type == "note":
new_lines.append(
apply_min_indent(
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
)
)
elif block_type == "warning":
new_lines.append(
apply_min_indent(
"<Tip warning={true}>\n\n"
+ f"{block_content.strip()}\n\n</Tip>\n",
block_indent,
)
)
elif block_type == "raw":
new_lines.append(block_content.strip() + "\n")
elif block_type == "math":
new_lines.append(f"$${block_content.strip()}$$\n")
elif block_type == "comment":
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
elif block_type == "autofunction":
if block_info is not None:
new_lines.append(f"[[autodoc]] {block_info}\n")
elif block_type == "autoclass":
if block_info is not None:
block = f"[[autodoc]] {block_info}\n"
options = parse_options(block_content)
if "special-members" in options:
special_members = options["special-members"].split(", ")
for special_member in special_members:
block += f" - {special_member}\n"
if "members" in options:
members = options["members"]
if len(members) == 0:
block += " - all\n"
else:
for member in members.split(", "):
block += f" - {member}\n"
new_lines.append(block)
elif block_type == "image":
options = parse_options(block_content)
target = options.pop("target", None)
if block_info is not None:
options["src"] = block_info
else:
if target is None:
raise ValueError("Image source not defined.")
options["src"] = target
# Adapt path
options["src"] = options["src"].replace(
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
)
html_code = " ".join(
[f'{key}="{value}"' for key, value in options.items()]
)
new_lines.append(f"<img {html_code}/>\n")
else:
new_lines.append(
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
)
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
def split_return_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
return line, ""
return None, line
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
def split_raise_line(line):
"""
Split the raise line with format `SomeError some doc`.
"""
splits_on_colon = line.strip().split(" ")
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
if error_type and error_type[-1] == ":":
error_type = error_type[:-1]
return error_type, doc
def split_arg_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
return line, ""
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
class InvalidRstDocstringError(ValueError):
pass
_re_parameters = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
def parse_rst_docstring(docstring):
"""
Parses a docstring written in rst, in particular the list of arguments and the return type.
"""
lines = docstring.split("\n")
idx = 0
while idx < len(lines):
# Parameters section
if _re_args.search(lines[idx]) is not None:
# Title of the section.
lines[idx] = "<parameters>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
# Returns or Raises block.
param_indent = find_indent(lines[idx])
while (
idx < len(lines)
and find_indent(lines[idx]) == param_indent
and _re_returns.search(lines[idx]) is None
):
intro, doc = split_arg_line(lines[idx])
# Line starting with a > after indent indicate a "section title" in the parameters.
if intro.lstrip().startswith(">"):
lines[idx] = intro.lstrip()
else:
lines[idx] = (
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
):
idx += 1
lines.insert(idx, "</parameters>\n")
idx += 1
# Returns section
elif _re_returns.search(lines[idx]) is not None:
# tag is either `return` or `yield`
tag = _re_returns.match(lines[idx]).group(1).lower()
# Title of the section.
lines[idx] = f"<{tag}s>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it.
return_indent = find_indent(lines[idx])
raised_errors = []
# The line may contain the return type.
if tag in ["return", "yield"]:
return_type, return_description = split_return_line(lines[idx])
lines[idx] = return_description
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= return_indent
):
idx += 1
else:
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
return_type, return_description = split_raise_line(lines[idx])
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
lines[idx] = "- " + raised_error + " -- " + return_description
md_link = _re_md_link.match(raised_error)
if md_link:
raised_error = md_link[1]
raised_error = re.sub(
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
)
if raised_error not in raised_errors:
raised_errors.append(raised_error)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) > return_indent
):
idx += 1
lines.insert(idx, f"</{tag}s>\n")
idx += 1
# Return block finished, we insert the return type if one was specified
if tag in ["return", "yield"] and return_type is not None:
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
elif len(raised_errors) > 0:
# raised errors
lines[
idx - 1
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
else:
idx += 1
result = "\n".join(lines)
# combine multiple <parameters> blocks into one block
if result.count("<parameters>") > 1:
parameters_blocks = _re_parameters.findall(result)
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
parameters_str = "\n".join(parameters_blocks)
result = _re_parameters.sub("", result)
result += f"\n<parameters>{parameters_str}</parameters>\n"
return result
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
def remove_indent(text):
"""
Remove indents in text, except the one linked to lists (or sublists).
"""
lines = text.split("\n")
# List of indents to remember for nested lists
current_indents = []
# List of new indents to remember for nested lists
new_indents = []
is_inside_code = False
code_indent = 0
for idx, line in enumerate(lines):
# Line is an item in a list.
if _re_list.search(line) is not None:
indent = find_indent(line)
# Is it a new list / new level of nestedness?
if len(current_indents) == 0 or indent > current_indents[-1]:
current_indents.append(indent)
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Otherwise it's an existing level of list (current one, or previous one)
else:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] != indent:
level -= 1
current_indents = current_indents[: level + 1]
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Line is an autodoc, we keep the indent for the list just after if there is one.
elif _re_autodoc.search(line) is not None:
indent = find_indent(line)
current_indents = [indent]
new_indents = [4]
lines[idx] = line.strip()
# Deal with empty lines separately
elif is_empty_line(line):
lines[idx] = ""
# Code blocks
elif line.lstrip().startswith("```"):
is_inside_code = not is_inside_code
if is_inside_code:
code_indent = find_indent(line)
lines[idx] = line[code_indent:]
elif is_inside_code:
lines[idx] = line[code_indent:]
else:
indent = find_indent(line)
if len(current_indents) > 0 and indent > current_indents[-1]:
lines[idx] = " " * new_indents[-1] + line[indent:]
elif len(current_indents) > 0:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] > indent:
level -= 1
current_indents = current_indents[: level + 1]
if level >= 0:
if current_indents[level] < indent:
new_indents = new_indents[: level + 1]
else:
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indents.append(new_indent)
else:
new_indents = []
lines[idx] = line[indent:]
else:
lines[idx] = line[indent:]
return "\n".join(lines)
def base_rst_to_mdx(text, page_info, unindent=True):
"""
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
"""
text = convert_rst_links(text, page_info)
text = convert_special_chars(text)
text = convert_rst_blocks(text, page_info)
# Convert * in lists to - to avoid the formatting conversion treat them as bold.
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
text = convert_rst_formatting(text)
return remove_indent(text) if unindent else text
def convert_rst_docstring_to_mdx(docstring, page_info):
"""
Convert a docstring written in rst to mdx.
"""
text = parse_rst_docstring(docstring)
return base_rst_to_mdx(text, page_info)
def process_titles(lines):
"""Converts rst titles to markdown titles."""
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
title_levels = {}
new_lines = []
for line in lines:
if (
len(new_lines) > 0
and len(line) >= len(new_lines[-1])
and len(set(line)) == 1
and line[0] in title_chars
and line != "::"
):
char = line[0]
level = title_levels.get(char, len(title_levels) + 1)
if level not in title_levels:
title_levels[char] = level
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
else:
new_lines.append(line)
return new_lines
# Matches lines with a pattern of a table new line in rst.
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a table new line in rst, with a first column empty.
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a first table line in rst.
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
# Re pattern that catches anchors of the type .. reference:
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
def split_pt_tf_code_blocks(text):
"""
Split PyTorch and TensorFlow specific block codes.
"""
lines = text.split("\n")
new_lines = []
idx = 0
while idx < len(lines):
if lines[idx].startswith("```"):
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
is_pytorch = False
is_tensorflow = False
idx += 1
while idx < len(lines) and lines[idx].strip() != "```":
if "## PYTORCH CODE" in lines[idx]:
is_pytorch = True
is_tensorflow = False
elif "## TENSORFLOW CODE" in lines[idx]:
is_tensorflow = True
is_pytorch = False
elif is_pytorch:
code_lines["pytorch"].append(lines[idx])
elif is_tensorflow:
code_lines["tensorflow"].append(lines[idx])
else:
code_lines["common"].append(lines[idx])
idx += 1
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
block_lines = ["<frameworkcontent>", "<pt>"]
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
block_lines.extend(["```", "</pt>", "<tf>"])
block_lines.extend(
code_lines["common"].copy() + code_lines["tensorflow"]
)
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
new_lines.extend(block_lines)
else:
block_lines = code_lines["common"] + ["```"]
new_lines.extend(block_lines)
idx += 1
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
"""
Convert a document written in rst to mdx.
"""
lines = rst_text.split("\n")
lines = process_titles(lines)
if add_imports:
new_lines = [
'<script lang="ts">',
' import Tip from "$lib/Tip.svelte";',
' import Youtube from "$lib/Youtube.svelte";',
' import Docstring from "$lib/Docstring.svelte";',
' import CodeBlock from "$lib/CodeBlock.svelte";',
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
' import Markdown from "$lib/Markdown.svelte";',
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
' import Added from "$lib/Added.svelte";',
' import Changed from "$lib/Changed.svelte";',
' import Deprecated from "$lib/Deprecated.svelte";',
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
' import PipelineTag from "$lib/PipelineTag.svelte";',
" ",
' export let fw: "pt" | "tf"',
"</script>",
"<svelte:head>",
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
"</svelte:head>",
"",
]
else:
new_lines = []
for line in lines:
if _re_ignore_line_table.search(line) is not None:
continue
elif _re_ignore_line_table1.search(line) is not None:
continue
elif _re_sep_line_table.search(line) is not None:
line = line.replace("=", "-").replace("+", "|")
elif _re_anchor_section.search(line) is not None:
anchor_name = _re_anchor_section.search(line).groups()[0]
line = f"<a id='{anchor_name}'></a>"
new_lines.append(line)
text = "\n".join(new_lines)
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))

src/kernels/_versions.py (new file, 52 lines)

@ -0,0 +1,52 @@
from typing import Dict, Optional
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
"""
Resolve a version specifier to the Git ref of the newest matching version tag.
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
return versions[accepted_versions[-1]]
def select_revision_or_version(
repo_id: str, revision: Optional[str], version: Optional[str]
) -> str:
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")
elif revision is None and version is None:
revision = "main"
elif version is not None:
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
assert revision is not None
return revision
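A short sketch of how resolution behaves in practice (this performs network requests against the Hub; the tag layout assumed here matches the `kernels-test/versions` fixture used in the tests further below):

```python
from kernels._versions import resolve_version_spec_as_ref, select_revision_or_version

# With tags v0.1.0, v0.1.1, and v0.2.0 present, the newest tag that
# satisfies the specifier is chosen:
ref = resolve_version_spec_as_ref("kernels-test/versions", ">=0.1.0,<0.2.0")
print(ref.name)  # v0.1.1

# An explicit revision is passed through unchanged; with neither a
# revision nor a version, "main" is used.
assert select_revision_or_version("kernels-test/versions", None, None) == "main"
```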


@ -8,6 +8,9 @@ from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
from .doc import generate_readme_for_kernel
from .wheel import build_variant_to_wheel
def main():
parser = argparse.ArgumentParser(
@ -36,6 +39,47 @@ def main():
)
lock_parser.set_defaults(func=lock_kernels)
to_wheel_parser = subparsers.add_parser(
"to-wheel", help="Convert a kernel to a wheel file"
)
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
to_wheel_parser.add_argument(
"--python-version",
type=str,
default="3.9",
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
)
to_wheel_parser.add_argument(
"--manylinux-version",
type=str,
default="2.28",
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
)
to_wheel_parser.set_defaults(func=kernels_to_wheel)
# Add generate-readme subcommand parser
generate_readme_parser = subparsers.add_parser(
"generate-readme",
help="Generate README snippets for a kernel's public functions",
)
generate_readme_parser.add_argument(
"repo_id",
type=str,
help="The kernel repo ID (e.g., kernels-community/activation)",
)
generate_readme_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
generate_readme_parser.set_defaults(
func=lambda args: generate_readme_for_kernel(
repo_id=args.repo_id, revision=args.revision
)
)
args = parser.parse_args()
args.func(args)
@ -77,6 +121,24 @@ def download_kernels(args):
sys.exit(1)
def kernels_to_wheel(args):
variants_path = install_kernel_all_variants(
repo_id=args.repo_id, revision=f"v{args.version}"
)
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
continue
wheel_path = build_variant_to_wheel(
manylinux_version=args.manylinux_version,
python_version=args.python_version,
repo_id=args.repo_id,
version=args.version,
variant_path=variant_path,
wheel_dir=Path("."),
)
print(f"☸️ {wheel_path.name}", file=sys.stderr)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)

src/kernels/doc.py (new file, 242 lines)

@ -0,0 +1,242 @@
import inspect
import re
import sys
from types import ModuleType
import yaml
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
from .utils import get_kernel
_RE_PARAMETERS = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
_RE_RETURNTYPE = re.compile(
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
else:
if "kernel" not in metadata["tags"]:
metadata["tags"].append("kernel")
print("---")
print(yaml.dump(metadata), end="")
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
print(f"# {title.strip()}")
print(
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
if not found_functions:
print("\nNo public top-level functions.")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
print("\n## Layers")
# Track if we found any classes
found_classes = False
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")


@ -1,65 +1,327 @@
from __future__ import annotations
import functools
import inspect
import logging
import os
import sys
import warnings
from abc import ABC, abstractmethod
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union
from dataclasses import dataclass
from enum import Flag, auto
from functools import lru_cache
from pathlib import Path
from types import MethodType
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Protocol,
Tuple,
Type,
Union,
)
from .utils import get_kernel
from ._interval_tree import IntervalTree
from ._versions import select_revision_or_version
from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
class Mode(Flag):
"""
Kernelize mode
The `Mode` flag is used by `kernelize` to select kernels for the given
mode. Mappings can be registered for specific modes.
* `INFERENCE`: The kernel is used for inference.
* `TRAINING`: The kernel is used for training.
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
* `FALLBACK`: In a kernel mapping, this kernel is used when no other mode
matches.
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
should be used for layers that are used for inference *with* `torch.compile`.
"""
_NONE = 0
FALLBACK = auto()
TRAINING = auto()
INFERENCE = auto()
TORCH_COMPILE = auto()
def __or__(self, other: Mode) -> Mode:
union = super().__or__(other)
if Mode.INFERENCE in union and Mode.TRAINING in union:
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
if Mode.FALLBACK in union and union != Mode.FALLBACK:
raise ValueError("Mode.FALLBACK cannot be combined with other modes.")
return union
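A quick illustration of the combination rules enforced by `__or__` above:

```python
# Kernelize for inference under torch.compile:
mode = Mode.INFERENCE | Mode.TORCH_COMPILE

# Invalid combinations are rejected eagerly:
try:
    Mode.INFERENCE | Mode.TRAINING
except ValueError as err:
    print(err)  # Mode.INFERENCE and Mode.TRAINING are mutually exclusive.
```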
@dataclass(frozen=True)
class Device:
type: str
properties: Optional[CUDAProperties] = None
# In the future we might add compute capabilities, etc.
def __post_init__(self):
if self.properties is not None and isinstance(self.properties, CUDAProperties):
if self.type != "cuda":
raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
def create_repo(self) -> _DeviceRepos:
"""Create an appropriate repository set for this device type."""
if self.type == "cuda":
return _CUDARepos()
elif self.type == "mps":
return _MPSRepos()
else:
raise ValueError(f"Unknown device type: {self.type}")
def __eq__(self, other):
return isinstance(other, Device) and self.type == other.type
if not isinstance(other, Device):
return NotImplemented
return self.type == other.type and self.properties == other.properties
def __hash__(self):
return hash(self.type)
return hash((self.type, self.properties))
@dataclass(frozen=True)
class CUDAProperties:
min_capability: int
max_capability: int
def __eq__(self, other):
if not isinstance(other, CUDAProperties):
return NotImplemented
return (
self.min_capability == other.min_capability
and self.max_capability == other.max_capability
)
def __hash__(self):
return hash((self.min_capability, self.max_capability))
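For instance, a mapping can pair a capability-restricted repository with a catch-all entry for the same device type (a sketch; the Hopper repository name is hypothetical):

```python
register_kernel_mapping(
    {
        "SiluAndMul": {
            # Only used on devices with capability 9.0-9.9 (e.g. H100):
            Device(
                type="cuda",
                properties=CUDAProperties(min_capability=90, max_capability=99),
            ): LayerRepository(
                repo_id="kernels-community/activation-hopper",  # hypothetical
                layer_name="SiluAndMul",
            ),
            # Catch-all for all other CUDA devices:
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
        }
    }
)
```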
class LayerRepositoryProtocol(Protocol):
@property
def layer_name(self) -> str: ...
@property
def repo_id(self) -> str: ...
@property
def revision(self) -> str: ...
@dataclass
class LayerRepository:
"""
Repository and name of a layer.
"""
layer_name: str = field(
metadata={"help": "The name of the layer in the kernel repository."}
)
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
revision: str = field(
default="main", metadata={"help": "The revision of the layer."}
)
def __init__(
self,
repo_id: str,
*,
layer_name: str,
revision: Optional[str] = None,
version: Optional[str] = None,
):
"""
Construct a layer repository.
Args:
repo_id (`str`): The Hub repository containing the layer.
revision (`str`, *optional*, defaults to `"main"`): The specific
revision (branch, tag, or commit) to download.
Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
"""
if revision is not None and version is not None:
raise ValueError(
"Either a revision or a version must be specified, not both."
)
self.repo_id = repo_id
self.layer_name = layer_name
# We are going to resolve these lazily, since we do not want
# to do a network request for every registered LayerRepository.
self._revision = revision
self._version = version
@property
@functools.lru_cache()
def revision(self) -> str:
return select_revision_or_version(
repo_id=self.repo_id, revision=self._revision, version=self._version
)
def __eq__(self, other):
return (
isinstance(other, LayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
and self.revision == other.revision
and self._revision == other._revision
and self._version == other._version
)
def __hash__(self):
return hash((self.layer_name, self.repo_id, self.revision))
return hash((self.layer_name, self.repo_id, self._revision, self._version))
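As a sketch, pinning a layer to a version range rather than a fixed revision looks like this (the specifier is illustrative; resolution happens lazily, so construction does not hit the network):

```python
repo = LayerRepository(
    repo_id="kernels-test/versions",
    layer_name="Version",
    version=">=0.1.0,<0.2.0",
)
# The tag is only resolved (and cached) on first access:
print(repo.revision)
```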
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
class LockedLayerRepository:
"""
Repository and name of a layer.
In contrast to `LayerRepository`, this class uses repositories that
are locked inside a project.
"""
def __init__(
self,
repo_id: str,
*,
lockfile: Optional[Path] = None,
layer_name: str,
):
"""
Construct a layer repository.
Args:
repo_id (`str`): The Hub repository containing the layer.
"""
self.repo_id = repo_id
self.lockfile = lockfile
self.layer_name = layer_name
@property
@functools.lru_cache()
def revision(self) -> str:
if self.lockfile is None:
locked_sha = _get_caller_locked_kernel(self.repo_id)
else:
with open(self.lockfile, "r") as f:
locked_sha = _get_locked_kernel(self.repo_id, f.read())
if locked_sha is None:
raise ValueError(f"Kernel `{self.repo_id}` is not locked")
return locked_sha
def __eq__(self, other):
return (
isinstance(other, LockedLayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
)
def __hash__(self):
return hash((self.layer_name, self.repo_id))
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
class _DeviceRepos(ABC):
"""
Device-specific kernel layer repositories.
"""
@property
@abstractmethod
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...
@abstractmethod
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
"""
Insert a repository for a specific device and mode.
"""
...
class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]
def __init__(self):
super().__init__()
self._repos = {}
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
return self._repos
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
if device.type != "mps":
raise ValueError(f"Device type must be 'mps', got {device.type}")
self._repos = repos
class _CUDARepos(_DeviceRepos):
_repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]
def __init__(self):
super().__init__()
self.repos_by_capability = IntervalTree()
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
capability = _find_capability()
return self.repos_by_capability.find_smallest_interval(capability)
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
assert device.properties is None or isinstance(
device.properties, CUDAProperties
)
min_capability = (
0 if device.properties is None else device.properties.min_capability
)
max_capability = (
sys.maxsize
if device.properties is None
else device.properties.max_capability
)
self.repos_by_capability.insert(min_capability, max_capability, repos)
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
"_KERNEL_MAPPING", default={}
)
def use_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
mapping: Dict[
str,
Dict[
Union[Device, str],
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
],
],
*,
inherit_mapping: bool = True,
):
@ -87,12 +349,20 @@ def use_kernel_mapping(
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
mapping: Dict[
str,
Dict[
Union[Device, str],
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
],
],
):
"""
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
Exemple usage:
Allows one to register a mapping between a layer name and the corresponding
kernel(s) to use, depending on the device. This should be used in conjunction
with `kernelize`.
Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
@ -113,32 +383,141 @@ def register_kernel_mapping(
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
device = (
Device(type=new_device) if isinstance(new_device, str) else new_device
)
if isinstance(new_repo, dict):
kernel_options = new_repo
else:
device_repo[new_device] = new_repo
kernel_options = {Mode.FALLBACK: new_repo}
feature_repos = device_repo.setdefault(device.type, device.create_repo())
feature_repos.insert(device, kernel_options)
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
Args:
cls: The layer class to decorate
layer_name: The name of the layer to use for kernel lookup
"""
cls.kernel_layer_name = layer_name
fallback_forward = cls.forward
cached_layer: Dict[LayerRepository, nn.Module] = {}
_MODE_FALLBACK_PRIORITY = {
Mode.INFERENCE: [
Mode.INFERENCE,
Mode.INFERENCE | Mode.TORCH_COMPILE,
Mode.TRAINING,
Mode.TRAINING | Mode.TORCH_COMPILE,
Mode.FALLBACK,
],
Mode.TRAINING: [
Mode.TRAINING,
Mode.TRAINING | Mode.TORCH_COMPILE,
Mode.FALLBACK,
],
Mode.INFERENCE
| Mode.TORCH_COMPILE: [
Mode.INFERENCE | Mode.TORCH_COMPILE,
Mode.TRAINING | Mode.TORCH_COMPILE,
Mode.FALLBACK,
],
Mode.TRAINING
| Mode.TORCH_COMPILE: [
Mode.TRAINING | Mode.TORCH_COMPILE,
Mode.FALLBACK,
],
}
def _select_repository(
repositories: Dict[Mode, LayerRepositoryProtocol],
*,
mode: Mode,
) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
# Get the fallback priority list for the requested mode
if mode not in _MODE_FALLBACK_PRIORITY:
raise ValueError(f"Unsupported mode: {mode}")
fallback_modes = _MODE_FALLBACK_PRIORITY[mode]
# Try each mode in priority order
for fallback_mode in fallback_modes:
if fallback_mode in repositories:
return (repositories[fallback_mode], fallback_mode)
return None
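A sketch of the fallback resolution (placeholder repositories; only the dictionary keys matter here):

```python
placeholder = LayerRepository(
    repo_id="kernels-community/activation",
    layer_name="SiluAndMul",
)
repos = {
    Mode.TRAINING | Mode.TORCH_COMPILE: placeholder,
    Mode.FALLBACK: placeholder,
}

# Mode.INFERENCE has no exact entry, so the priority list is walked:
# INFERENCE -> INFERENCE | TORCH_COMPILE -> TRAINING -> TRAINING | TORCH_COMPILE
repo, matched_mode = _select_repository(repos, mode=Mode.INFERENCE)
assert matched_mode == Mode.TRAINING | Mode.TORCH_COMPILE
```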
def kernelize(
model: "nn.Module",
*,
mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
mode: the mode that the kernel is going to be used in (e.g.
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
and `torch.compile`).
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
"""
import torch
if mode == Mode.FALLBACK:
raise ValueError("Mode.FALLBACK can only be used to register kernel mappings.")
# Type check ignored because this causes a false negative on Python < 3.11.
# Looks similar to: https://github.com/python/mypy/issues/9642
# Remove once we start doing typing checks on >= 3.11.
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
def forward(self, x, *args, **kwargs):
if _DISABLE_KERNEL_MAPPING:
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
needs_backward = self.training
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
warnings.warn(
"\n"
@ -148,68 +527,85 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, *args, **kwargs)
# Get kernel options for the device
property_repos = kernel.get(device_type.type)
repo = kernel.get(Device(type=device.type))
if repo is None:
if property_repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
# Short-circuit if we already loaded the layer.
layer = cached_layer.get(repo, None)
if layer is not None:
if needs_backward and not getattr(layer, "has_backward", True):
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
repos = property_repos.repos
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
)
_replace_forward(module, module_class)
continue
repo_with_mode = _select_repository(
repos,
mode=mode,
)
# We have to validate against the original signature.
orig_forward = cls.forward
try:
cls.forward = fallback_forward
_validate_layer(check_cls=cls, cls=layer)
finally:
cls.forward = orig_forward
if repo_with_mode is None:
if not use_fallback:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
continue
cached_layer[repo] = layer
repo, repo_mode = repo_with_mode
if needs_backward and not getattr(layer, "has_backward", True):
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
logging.info(
f"Using layer `{repo.layer_name}` from repo `{repo.repo_id}` (revision: {repo.revision}) for layer `{layer_name}`"
)
logging.debug(f"kernelize mode: {mode}, repo mode: {repo_mode}")
cls.forward = forward
layer = _get_layer_memoize(repo, module_class)
# Ideally we would do validation on the mapping where we check that
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
# the actual layer is compatible with that. Unfortunately, this would
# mean that we have to pre-download everything.
_validate_layer_has_mode(
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
)
_conditionally_replace_forward(
module=module,
layer=layer,
mode=mode,
use_fallback=use_fallback,
)
return model
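End to end, usage looks roughly like this (a sketch; `MyModel` is a hypothetical module containing kernel-extensible layers):

```python
model = MyModel()

# The default mode is Mode.TRAINING | Mode.TORCH_COMPILE; the device is
# inferred from the model's parameters when not passed explicitly.
model = kernelize(model)

# kernelize can be re-run to select kernels for a different mode:
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, device="cuda")
```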
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
def use_kernel_forward_from_hub(layer_name: str):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Make a layer extensible using the name `layer_name`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
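For reference, making a layer extensible is a single decoration, mirroring the `SiluAndMul` layer used in the tests below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```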
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
@ -226,13 +622,13 @@ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
# The layer must at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
from torch import nn
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
@ -245,7 +641,8 @@ def _validate_layer(*, check_cls, cls):
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
if difference != set() and difference != {"has_backward"}:
# Reject any extra members beyond the optional `can_torch_compile` and `has_backward` flags.
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
@ -262,3 +659,100 @@ def _validate_layer(*, check_cls, cls):
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)
def _find_device(model: "nn.Module") -> Device:
try:
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)
return Device(type=param.device.type)
@lru_cache
def _find_capability() -> int:
import torch
major, minor = torch.cuda.get_device_capability(device=None)
return major * 10 + minor
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback if the mode is not supported by the layer.
# Note that this is useful even after _validate_layer_has_mode because
# layers registered with the FALLBACK mode never get rejected by
# _validate_layer_has_mode. For such layers, we want to fall back in
# case the layer does not support the given mode.
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
layer, "can_torch_compile", False
)
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
def _validate_layer_has_mode(
*,
layer_name: str,
module: Type["nn.Module"],
repo: LayerRepositoryProtocol,
repo_mode: Mode,
):
"""
Check that a repository supports the mode that it was registered for.
"""
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
if Mode.TORCH_COMPILE in repo_mode and not getattr(
module, "can_torch_compile", False
):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
return True
def _get_layer_memoize(
repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
return layer
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
return layer


@ -4,10 +4,8 @@ from pathlib import Path
from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from kernels._versions import resolve_version_spec_as_ref
from kernels.compat import tomllib
@ -31,20 +29,6 @@ class KernelLock:
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
"""
Get the locks for a kernel with the given version spec.
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
tag_for_newest = versions[accepted_versions[-1]]
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
r = HfApi().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True


@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels._versions import select_revision_or_version
from kernels.lockfile import KernelLock, VariantLock
@ -43,14 +44,25 @@ def build_variant() -> str:
elif torch.version.hip is not None:
rocm_version = parse(torch.version.hip.split("-")[0])
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
compute_framework = "xpu"
else:
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
raise AssertionError(
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
)
torch_version = parse(torch.__version__)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
cpu = platform.machine()
os = platform.system().lower()
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
@ -101,6 +113,23 @@ def install_kernel(
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Redo with more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
@ -119,7 +148,7 @@ def install_kernel(
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
return package_name, variant_path
@ -156,16 +185,63 @@ def install_kernel_all_variants(
return repo_path / "build"
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
def get_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> ModuleType:
"""
Load a kernel from the kernel hub.
This function downloads a kernel to the local Hugging Face Hub cache
directory (if it was not downloaded before) and then loads the kernel.
Args:
repo_id (`str`): The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`): The specific
revision (branch, tag, or commit) to download.
Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
Returns:
`ModuleType`: The imported kernel module.
Example:
```python
from kernels import get_kernel
kernel = get_kernel("username/my-kernel")
result = kernel.kernel_function(input_data)
```
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(repo_id: str, revision: str = "main") -> bool:
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
"""
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> bool:
"""
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
Args:
repo_id (`str`): The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`): The specific
revision (branch, tag, or commit) to download.
Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
Returns:
`bool`: `True` if a kernel is available for the current environment.
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()

src/kernels/wheel.py (new file, 186 lines)

@ -0,0 +1,186 @@
import email.policy
import os
from dataclasses import dataclass
from email.message import Message
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Optional
try:
KERNELS_VERSION = version("kernels")
except PackageNotFoundError:
KERNELS_VERSION = "unknown"
@dataclass
class Metadata:
name: str
version: str
cuda_version: Optional[str]
cxx_abi_version: Optional[str]
torch_version: Optional[str]
os: Optional[str]
platform: Optional[str]
@property
def is_universal(self) -> bool:
return self.platform is None
def build_variant_to_wheel(
repo_id: str,
*,
version: str,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Create a wheel file from the variant path.
"""
name = repo_id.split("/")[-1].replace("_", "-")
metadata = extract_metadata(name, version, variant_path)
return build_wheel(
metadata,
variant_path=variant_path,
wheel_dir=wheel_dir,
manylinux_version=manylinux_version,
python_version=python_version,
)
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
"""
Extract metadata from the variant path.
"""
if variant_path.name == "torch-universal":
return Metadata(
name=name,
version=version,
cuda_version=None,
cxx_abi_version=None,
torch_version=None,
os=None,
platform=None,
)
if not variant_path.name.startswith("torch"):
raise ValueError("Currently only conversion of Torch kernels is supported.")
variant_parts = variant_path.name.removeprefix("torch").split("-")
if len(variant_parts) != 5:
raise ValueError(f"Invalid variant name: {variant_path.name}")
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
cpp_abi_version = variant_parts[1].removeprefix("cxx")
cuda_version = variant_parts[2].removeprefix("cu")
platform = variant_parts[3].replace("-", "_")
os = variant_parts[4]
return Metadata(
name=name,
version=version,
cuda_version=cuda_version,
cxx_abi_version=cpp_abi_version,
torch_version=torch_version,
os=os,
platform=platform,
)
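A sketch of the parsing (the version number is hypothetical; the variant name matches the entries in the lock file above):

```python
from pathlib import Path

meta = extract_metadata(
    "activation", "0.0.2", Path("build/torch26-cxx11-cu126-x86_64-linux")
)
assert meta.torch_version == "2.6"
assert meta.cxx_abi_version == "11"
assert meta.cuda_version == "126"
assert (meta.platform, meta.os) == ("x86_64", "linux")
assert not meta.is_universal
```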
def build_wheel(
metadata: Metadata,
*,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Build the wheel file.
"""
try:
from wheel.wheelfile import WheelFile # type: ignore
except ImportError:
raise ImportError(
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
)
name = metadata.name.replace("-", "_")
python_version_flat = python_version.replace(".", "")
if metadata.is_universal:
python_tag = f"py{python_version_flat}"
abi_tag = "none"
platform_tag = "any"
wheel_filename = (
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
)
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
root_is_purelib = "true"
requires_dist_torch = "torch"
else:
python_tag = f"cp{python_version_flat}"
abi_tag = "abi3"
if (
metadata.torch_version is None
or metadata.cuda_version is None
or metadata.cxx_abi_version is None
or metadata.os is None
or metadata.platform is None
):
raise ValueError(
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
)
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
if metadata.os == "linux":
platform_tag = (
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
)
else:
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
root_is_purelib = "false"
requires_dist_torch = f"torch=={metadata.torch_version}.*"
wheel_path = wheel_dir / wheel_filename
wheel_msg = Message(email.policy.compat32)
wheel_msg.add_header("Wheel-Version", "1.0")
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
metadata_msg = Message(email.policy.compat32)
metadata_msg.add_header("Metadata-Version", "2.1")
metadata_msg.add_header("Name", name)
metadata_msg.add_header("Version", metadata.version)
metadata_msg.add_header("Summary", f"{name} kernel")
metadata_msg.add_header("Requires-Python", ">=3.9")
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
source_pkg_dir = variant_path / name
with WheelFile(wheel_path, "w") as wheel_file:
for root, dirnames, filenames in os.walk(source_pkg_dir):
for filename in filenames:
if filename.endswith(".pyc"):
continue
abs_filepath = os.path.join(root, filename)
entry_name = os.path.relpath(abs_filepath, variant_path)
wheel_file.write(abs_filepath, entry_name)
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
return wheel_path
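With the metadata from the example above and the default Python/manylinux settings, the resulting file would be named `activation-0.0.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl`.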

tests/conftest.py (new file, 10 lines)

@ -0,0 +1,10 @@
import sys
import pytest
def pytest_runtest_setup(item):
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
pytest.skip("skipping Linux-only test on non-Linux platform")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")


@ -1,54 +1,82 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
"hash_type": "git_lfs_concat"
}
}


@ -0,0 +1,12 @@
[
{
"repo_id": "kernels-test/versions",
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
"variants": {
"torch-universal": {
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
"hash_type": "git_lfs_concat"
}
}
}
]


@ -0,0 +1,2 @@
[tool.kernels.dependencies]
"kernels-test/versions" = ">=0.1.0,<0.2.0"


@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel, has_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
@pytest.fixture
@ -9,6 +9,19 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@pytest.fixture
def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@ -21,6 +34,7 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -36,6 +50,31 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
x = torch.arange(-10, 10, dtype=dtype, device="mps")
y = metal_kernel.relu(x)
assert torch.allclose(y, torch.relu(x))
@pytest.mark.linux_only
@pytest.mark.parametrize(
"kernel_exists",
[
@ -52,6 +91,26 @@ def test_has_kernel(kernel_exists):
assert has_kernel(repo_id, revision=revision) == kernel
def test_version():
kernel = get_kernel("kernels-test/versions")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
assert kernel.version() == "0.1.1"
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
assert kernel.version() == "0.1.1"
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
get_kernel("kernels-test/versions", version=">0.2.0")
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
kernel = get_kernel(
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
)
@pytest.mark.linux_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")


@ -16,18 +16,21 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

tests/test_interval_tree.py (new file, 230 lines)

@ -0,0 +1,230 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar
import pytest
from kernels._interval_tree import IntervalTree, _Node
T = TypeVar("T")
class SimpleIntervalStore(Generic[T]):
"""A simple O(n) implementation that stores intervals in a list."""
def __init__(self):
self.intervals: List[Tuple[int, int, T]] = []
def insert(self, start: int, end: int, data: T) -> None:
"""Insert an interval into the store."""
# Replace data if the interval already exists.
for i, (existing_start, existing_end, existing_data) in enumerate(
self.intervals
):
if existing_start == start and existing_end == end:
self.intervals[i] = (start, end, data)
return
self.intervals.append((start, end, data))
def find_smallest_interval(self, point: int) -> Optional[T]:
"""Find the best match using linear search."""
matches = []
for start, end, data in self.intervals:
if start <= point <= end:
matches.append((start, end, data))
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# mirrors the ordering in the interval tree.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def is_balanced(tree: IntervalTree[T]) -> bool:
"""Check if the AVL tree is properly balanced."""
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
if node is None:
return True, 0
# Left and right subtrees should be balanced.
left_balanced, left_height = check_balance(node.left)
if not left_balanced:
return False, -1
right_balanced, right_height = check_balance(node.right)
if not right_balanced:
return False, -1
# The difference in height should not exceed 1.
if abs(left_height - right_height) > 1:
return False, -1
# Check if the height is correct.
expected_height = 1 + max(left_height, right_height)
if node.height != expected_height:
return False, -1
return True, expected_height
balanced, _ = check_balance(tree.root)
return balanced
@pytest.fixture
def populated_tree() -> IntervalTree[str]:
"""Provides a pre-populated IntervalTree for testing."""
tree = IntervalTree[str]()
kernels = [
(80, 89, "Kernel_A_General_80_89"),
(86, 89, "Kernel_B_Ampere_86_89"),
(80, 86, "Kernel_C_Older_Ampere_80_86"),
(70, 75, "Kernel_D_Volta_70_75"),
(86, 87, "Kernel_E_Specific_86_87"),
]
for start, end, name in kernels:
tree.insert(start, end, name)
return tree
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
# Check that the smallest interval is selected when there are
# multiple matching intervals.
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
def test_find_single_match(populated_tree):
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
def test_no_match_outside_all_ranges(populated_tree):
# Check that no interval is found when the value is out of range
# (too small/too large).
assert populated_tree.find_smallest_interval(65) is None
assert populated_tree.find_smallest_interval(95) is None
def test_no_match_in_gap_between_ranges(populated_tree):
# Check that no interval is found when the value is between two
# intervals.
assert populated_tree.find_smallest_interval(78) is None
def test_boundary_conditions_start_and_end(populated_tree):
# Test exact upper/lower bounds of intervals.
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
def test_empty_tree():
# Searching in an empty tree should return None.
empty_tree = IntervalTree[str]()
assert empty_tree.find_smallest_interval(100) is None
def test_multiple_equally_specific_matches():
# Check that we pick the match in a stable way when there are
# multiple matching intervals with the same size.
tree = IntervalTree[str]()
str1 = "First_Narrow_Kernel"
str2 = "Second_Narrow_Kernel"
tree.insert(10, 20, "Wide_Kernel")
tree.insert(12, 17, str1)
tree.insert(14, 19, str2)
if id(str1) < id(str2):
assert tree.find_smallest_interval(15) == str1
else:
assert tree.find_smallest_interval(15) == str2
def test_property_based_interval_tree():
# QuickCheck-style property-based testing:
#
# - Verify that the tree is balanced after each insertion.
# - Verify the query against a simple list-based implementation.
random.seed(42) # For reproducible tests
test_points = list(range(0, 101))
for _ in range(5):
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
intervals = []
for i in range(100):
start = random.randint(0, 90)
end = random.randint(start, 100)
data = f"interval_{i}_s{start}_e{end}"
intervals.append((start, end, data))
for i, (start, end, data) in enumerate(intervals):
tree.insert(start, end, data)
simple.insert(start, end, data)
# Check that tree is still balanced
assert is_balanced(
tree
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
for point in test_points:
tree_result = tree.find_smallest_interval(point)
simple_result = simple.find_smallest_interval(point)
assert tree_result == simple_result, (
f"Mismatch for point {point} after inserting {i+1} intervals. "
f"Tree: {tree_result}, Simple: {simple_result}. "
f"Last inserted: ({start}, {end})"
)
def test_property_based_edge_cases():
random.seed(123)
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
# Single-point intervals.
for i in range(10):
point = random.randint(0, 100)
data = f"single_point_{i}_{point}"
tree.insert(point, point, data)
simple.insert(point, point, data)
assert is_balanced(
tree
), f"Tree unbalanced after inserting single point {point}"
# Test the exact point and neighbors
for test_point in [point - 1, point, point + 1]:
if 0 <= test_point <= 100:
tree_result = tree.find_smallest_interval(test_point)
simple_result = simple.find_smallest_interval(test_point)
assert tree_result == simple_result
def test_unique_intervals_override():
"""Test that inserting an interval with the same start/end overrides the previous value."""
tree = IntervalTree[str]()
tree.insert(10, 20, "original_value")
assert tree.find_smallest_interval(15) == "original_value"
tree.insert(10, 20, "new_value")
assert tree.find_smallest_interval(15) == "new_value"
tree.insert(10, 25, "different_interval")
results = tree.search(15)
assert "new_value" in results
assert "different_interval" in results
assert len(results) == 2
tree.insert(10, 20, "final_value")
assert tree.find_smallest_interval(15) == "final_value"
assert is_balanced(tree)
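For reference, this is roughly the contract that `SimpleIntervalStore` provides as the oracle in the property-based tests above: a linear scan with the same override-on-duplicate and narrowest-match semantics. A minimal sketch under those assumptions (the name `ListIntervalStore` and its internals are hypothetical):

```python
from typing import Generic, List, Optional, Tuple, TypeVar

T = TypeVar("T")


class ListIntervalStore(Generic[T]):
    """List-based oracle with the semantics the tests above exercise."""

    def __init__(self) -> None:
        self._intervals: List[Tuple[int, int, T]] = []

    def insert(self, start: int, end: int, data: T) -> None:
        # Inserting the same (start, end) again overrides the old value.
        self._intervals = [
            iv for iv in self._intervals if (iv[0], iv[1]) != (start, end)
        ]
        self._intervals.append((start, end, data))

    def find_smallest_interval(self, point: int) -> Optional[T]:
        # Inclusive bounds; the narrowest containing interval wins.
        matches = [iv for iv in self._intervals if iv[0] <= point <= iv[1]]
        if not matches:
            return None
        return min(matches, key=lambda iv: iv[1] - iv[0])[2]
```

How ties between equally narrow intervals are broken is implementation-defined; the tests above only require that the choice is stable.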

View File

@@ -1,8 +1,18 @@
from dataclasses import dataclass
from pathlib import Path
import pytest
import torch.nn as nn
from kernels import load_kernel
from kernels.cli import download_kernels
from kernels.layer import (
LockedLayerRepository,
Mode,
kernelize,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
# Mock download arguments class.
@@ -17,8 +27,34 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.linux_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
def test_layer_locked():
project_dir = Path(__file__).parent / "layer_locking"
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
"cuda": LockedLayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
lockfile=project_dir / "kernels.lock",
)
},
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"

View File

@@ -1,3 +1,6 @@
import sys
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
@@ -6,24 +9,35 @@ from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
from kernels.layer import (
_KERNEL_MAPPING,
CUDAProperties,
_validate_layer,
use_kernel_mapping,
)
kernel_layer_mapping = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
"SiluAndMulNoCompile": {
"cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
}
@@ -43,6 +57,11 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
class SiluAndMulNoCompileKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
pass
@@ -53,6 +72,18 @@ class SiluAndMulStringDevice(SiluAndMul):
pass
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Used to check whether the hub kernel was called.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
@@ -71,6 +102,7 @@ def test_arg_kinds():
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
@@ -80,7 +112,7 @@ def test_hub_forward(cls, device):
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
@@ -92,17 +124,126 @@ def test_hub_forward(cls, device):
assert silu_and_mul_with_kernel.n_calls == 1
@pytest.mark.linux_only
def test_capability():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we called out to the kernel.
assert linear.n_calls == 0
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=sys.maxsize, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
        # Check that we didn't call out to the kernel because there
        # is no kernel with a matching capability.
assert linear.n_calls == 1
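The capability bounds above use the compact integer encoding `major * 10 + minor` (so 7.5 becomes 75 and 8.6 becomes 86); that encoding is inferred from the values used in these tests. To compute it for the current GPU:

```python
import torch

# Requires a CUDA device; returns e.g. (8, 6) for an RTX 3090.
major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # 8.6 -> 86
```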
def test_layer_fallback_works():
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
class SiluAndMulWithKernelFallback(SiluAndMul):
pass
# Check that we don't raise an exception for a non-existing kernel.
SiluAndMulWithKernelFallback()
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
ctx = (
pytest.raises(ValueError, match="does not support mode")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
with ctx:
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
extra_mapping1 = {
"TestKernel": {
@@ -118,6 +259,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
@@ -135,20 +277,22 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
== "kernels-community/activation"
)
@@ -157,23 +301,25 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
== "kernels-community/activation"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
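The save/restore behaviour asserted above is what a `ContextVar`-backed mapping gives you: entering `use_kernel_mapping` layers the new entries over the current state, and leaving it restores the previous one. A simplified sketch of that mechanism (names and structure hypothetical, not the library's actual internals):

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Dict, Iterator

_MAPPING: ContextVar[Dict[str, object]] = ContextVar("_MAPPING", default={})


@contextmanager
def scoped_mapping(extra: Dict[str, object]) -> Iterator[None]:
    # Layer the new entries over the current mapping; reset() undoes
    # exactly this set(), which is why nested contexts unwind cleanly.
    token = _MAPPING.set({**_MAPPING.get(), **extra})
    try:
        yield
    finally:
        _MAPPING.reset(token)
```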
@@ -205,39 +351,176 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
def test_fallback_used_when_training():
@use_kernel_forward_from_hub("Linear")
class TorchLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
            # Used to check whether the hub kernel was called.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
linear = TorchLinear(32, 32).to("cuda")
@pytest.mark.linux_only
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)
@pytest.mark.linux_only
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
    # Case 1: a layer registered without further specification
    # becomes the base (FALLBACK) kernel.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
layer_name="LinearBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
        # TRAINING has a registered kernel, so the kernel is used.
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
assert linear.n_calls == 1
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.FALLBACK: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
        # INFERENCE falls back to the TRAINING kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Falls back to the TRAINING kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
assert linear.n_calls == 2
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Falls back to the TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses TRAINING | TORCH_COMPILE kernel.
assert linear.n_calls == 2
kernelize(linear)
linear(X)
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
assert linear.n_calls == 2
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
@@ -249,6 +532,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@@ -257,21 +541,413 @@ def test_fallback_used_when_training():
linear(X)
assert linear.n_calls == 0
# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.FALLBACK | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.FALLBACK)
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
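Spelled out, these checks admit four kernelize modes and reject the rest; a quick reference, assuming `Mode` is the flag enum used throughout:

```python
from kernels import Mode

# Accepted by kernelize():
Mode.INFERENCE
Mode.TRAINING
Mode.INFERENCE | Mode.TORCH_COMPILE
Mode.TRAINING | Mode.TORCH_COMPILE

# Rejected (per the assertions above):
# Mode.INFERENCE | Mode.TRAINING           -> mutually exclusive
# Mode.FALLBACK | Mode.TORCH_COMPILE       -> FALLBACK cannot be combined
# kernelize(..., mode=Mode.FALLBACK)       -> FALLBACK only registers mappings
# kernelize(..., mode=Mode.TORCH_COMPILE)  -> must contain INFERENCE or TRAINING
```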
@pytest.mark.linux_only
def test_kernel_modes_inference():
"""Test inference-specific fallback scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: register a kernel just for inference
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 2
# Case 2: register a kernel just for inference + torch.compile
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 2
kernelize(linear, mode=Mode.INFERENCE)
linear(X)
# INFERENCE falls back to INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 3
# Case 3: register both inference kernels
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses exact INFERENCE kernel
assert linear.n_calls == 3
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# Uses exact INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback to original
assert linear.n_calls == 4
@pytest.mark.linux_only
def test_kernel_modes_mixed():
"""Test mixed training and inference kernel scenarios."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: register both base inference and training kernels
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original
assert linear.n_calls == 2
# Case 2: register all four kernel modes
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses exact INFERENCE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses exact TRAINING kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# Uses exact INFERENCE | TORCH_COMPILE kernel
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses exact TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 2
@pytest.mark.linux_only
def test_kernel_modes_cross_fallback():
"""Test cross-mode fallback scenarios from inference to training modes."""
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: Only training kernel registered - inference should fall back to training
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# INFERENCE falls back to TRAINING kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING uses the kernel directly
assert linear.n_calls == 0
# Case 2: Only training + torch.compile kernel registered
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# INFERENCE falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
linear(X)
# INFERENCE | TORCH_COMPILE falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# TRAINING falls back to TRAINING | TORCH_COMPILE kernel
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE uses the kernel directly
assert linear.n_calls == 0
# Case 3: Test that training modes don't fall back to inference modes
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.TRAINING)
X = torch.randn(10, 32, device="cuda")
linear(X)
# TRAINING should NOT fall back to inference kernels, use original
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
assert linear.n_calls == 2
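Taken together, the assertions in these mode tests imply the following fallback chains. The sketch below is a reconstruction from the observed behaviour, not the library's actual resolution code:

```python
from typing import Dict, List, Optional

from kernels import LayerRepository, Mode

# Each requested mode tries these registered modes in order; the first
# hit wins, and training modes never fall back to inference kernels.
_FALLBACK_ORDER: Dict[Mode, List[Mode]] = {
    Mode.INFERENCE: [
        Mode.INFERENCE,
        Mode.INFERENCE | Mode.TORCH_COMPILE,
        Mode.TRAINING,
        Mode.TRAINING | Mode.TORCH_COMPILE,
        Mode.FALLBACK,
    ],
    Mode.INFERENCE | Mode.TORCH_COMPILE: [
        Mode.INFERENCE | Mode.TORCH_COMPILE,
        Mode.TRAINING | Mode.TORCH_COMPILE,
        Mode.FALLBACK,
    ],
    Mode.TRAINING: [
        Mode.TRAINING,
        Mode.TRAINING | Mode.TORCH_COMPILE,
        Mode.FALLBACK,
    ],
    Mode.TRAINING | Mode.TORCH_COMPILE: [
        Mode.TRAINING | Mode.TORCH_COMPILE,
        Mode.FALLBACK,
    ],
}


def resolve(repos: Dict[Mode, LayerRepository], mode: Mode) -> Optional[LayerRepository]:
    # None means: keep the original torch.nn forward (the cases where
    # n_calls is incremented above).
    for candidate in _FALLBACK_ORDER[mode]:
        if candidate in repos:
            return repos[candidate]
    return None
```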
def test_layer_versions():
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<1.0.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.1.0,<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.2.0",
)
}
}
):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE)
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
revision="v0.1.0",
version="<1.0.0",
)
}
}
)
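The version specifiers above follow PEP 440 semantics and resolve against the repository's `v*` tags, picking the newest tag that satisfies the constraint (so `>0.1.0,<0.2.0` resolves to `v0.1.1`). A sketch of that resolution using the `packaging` library; the tag list is hypothetical:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

tags = ["v0.1.0", "v0.1.1", "v0.2.0"]  # hypothetical tags on a kernel repo
spec = SpecifierSet(">0.1.0,<0.2.0")

matching = [Version(t.removeprefix("v")) for t in tags if Version(t.removeprefix("v")) in spec]
if not matching:
    raise ValueError(f"No version satisfies requirement {spec}")
best = max(matching)  # Version("0.1.1"), i.e. the tag v0.1.1
```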