mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 05:30:30 +08:00
Compare commits
63 Commits
Author | SHA1 | Date | |
---|---|---|---|
d8fefaeef5 | |||
ed048616fe | |||
b182cd3458 | |||
ce77658efc | |||
b96b154e7f | |||
b24ef9fa6b | |||
a7101b2cfd | |||
6241afa06e | |||
34a1932751 | |||
e39eac09c1 | |||
b0c431fee4 | |||
9a188eadbe | |||
457c7c1b8d | |||
fb8cd99a2c | |||
dfee307d54 | |||
93e5765611 | |||
bf488208be | |||
2a14472e4c | |||
055a953552 | |||
692d5ad458 | |||
2139df57f4 | |||
8f9a77bb6a | |||
6c00194680 | |||
d6b51eefb7 | |||
d383fdd4b4 | |||
07e5e8481a | |||
88f55d4728 | |||
e801ebf332 | |||
0ae07f05fc | |||
7611021100 | |||
767e7ccf13 | |||
1caa4c1393 | |||
da701bf58a | |||
703664ed31 | |||
a8a6564fa7 | |||
c89e0fa9b9 | |||
176a601178 | |||
cfa0c76ddc | |||
bcc29915f9 | |||
6fbff7a9cb | |||
f7490bd0a9 | |||
8069e3bf0c | |||
c540d1e1d6 | |||
967ac581b8 | |||
81088d44e8 | |||
4a04c005e3 | |||
6d3c6daf20 | |||
071900fd69 | |||
2d2c6b14e0 | |||
03edc573b1 | |||
c841a6c90d | |||
c7a343f195 | |||
8d838f947d | |||
b87e6fadbe | |||
fc935d9874 | |||
3622e1f8dd | |||
a7f3b2e8ed | |||
a6ab5d83ba | |||
4f9f1abfb9 | |||
f94b7780a6 | |||
bd28883775 | |||
498429e322 | |||
09c991af4b |
17
.github/workflows/build_documentation.yaml
vendored
Normal file
17
.github/workflows/build_documentation.yaml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
name: Build documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- doc-builder*
|
||||
- v*-release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: kernels
|
||||
secrets:
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
15
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
15
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
name: Build PR Documentation
|
||||
|
||||
on: pull_request
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: kernels
|
21
.github/workflows/lint.yml
vendored
21
.github/workflows/lint.yml
vendored
@ -8,3 +8,24 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run ruff
|
||||
uses: astral-sh/ruff-action@v3
|
||||
|
||||
black:
|
||||
name: Run black check
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
UV_PYTHON_PREFERENCE: only-managed
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv and set the python version
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Install black
|
||||
run: uv pip install black
|
||||
|
||||
- name: Check formatting
|
||||
run: |
|
||||
uv run black --check src
|
||||
uv run black --check tests
|
||||
|
19
.github/workflows/test.yml
vendored
19
.github/workflows/test.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
max-parallel: 4
|
||||
matrix:
|
||||
python-version: ["3.10", "3.12"]
|
||||
torch-version: ["2.6.0", "2.7.0"]
|
||||
torch-version: ["2.7.0", "2.8.0"]
|
||||
|
||||
env:
|
||||
UV_PYTHON_PREFERENCE: only-managed
|
||||
@ -51,7 +51,15 @@ jobs:
|
||||
run: uv run mypy src/kernels
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests
|
||||
run: |
|
||||
uv run pytest tests
|
||||
|
||||
- name: Run staging tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
|
||||
run: |
|
||||
HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
|
||||
if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'
|
||||
|
||||
- name: Check kernel conversion
|
||||
run: |
|
||||
@ -63,7 +71,12 @@ jobs:
|
||||
- name: Check README generation
|
||||
# For now, just checks that generation doesn't fail.
|
||||
run: |
|
||||
uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
|
||||
uv run kernels generate-readme kernels-community/triton-layer-norm
|
||||
|
||||
- name: Check kernel check
|
||||
run: |
|
||||
uv pip install kernel-abi-check
|
||||
kernels check kernels-community/activation
|
||||
|
||||
- name: Import check without torch
|
||||
run: |
|
||||
|
16
.github/workflows/upload_pr_documentation.yaml
vendored
Normal file
16
.github/workflows/upload_pr_documentation.yaml
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
name: Upload PR Documentation
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build PR Documentation"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||
with:
|
||||
package_name: kernels
|
||||
secrets:
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
8
Makefile
Normal file
8
Makefile
Normal file
@ -0,0 +1,8 @@
|
||||
.PHONY: style
|
||||
|
||||
export check_dirs := src examples tests
|
||||
|
||||
style:
|
||||
black ${check_dirs}
|
||||
isort ${check_dirs}
|
||||
ruff check ${check_dirs} --fix
|
14
README.md
14
README.md
@ -56,10 +56,12 @@ the Hub.
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [Using layers](docs/layers.md)
|
||||
- [Locking kernel versions](docs/locking.md)
|
||||
- [Environment variables](docs/env.md)
|
||||
- [Using kernels in a Docker container](docs/docker.md)
|
||||
- [Kernel requirements](docs/kernel-requirements.md)
|
||||
- [Frequently Asked Questions](docs/faq.md)
|
||||
- [Introduction](docs/source/index.md)
|
||||
- [Installation](docs/source/installation.md)
|
||||
- [Basic usage](docs/source/basic-usage.md)
|
||||
- [Using layers](docs/source/layers.md)
|
||||
- [Locking kernel/layer versions](docs/source/locking.md)
|
||||
- [Environment variables](docs/source/env.md)
|
||||
- [Kernel requirements](docs/source/kernel-requirements.md)
|
||||
- [Frequently Asked Questions](docs/source/faq.md)
|
||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||
|
@ -1,8 +0,0 @@
|
||||
# Using kernels in a Docker container
|
||||
|
||||
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
||||
|
||||
```bash
|
||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||
```
|
13
docs/faq.md
13
docs/faq.md
@ -1,13 +0,0 @@
|
||||
# FAQ
|
||||
|
||||
## Why is the kernelization step needed?
|
||||
|
||||
In earlier versions of `kernels`, a layer's `forward` was replaced by
|
||||
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
|
||||
new `forward` would dispatch to a kernel based on the device type,
|
||||
whether a model was training, etc. However, this approach was
|
||||
fundamentally incompatible with `torch.compile` since it relied
|
||||
on data-dependent branching.
|
||||
|
||||
To avoid branching, we have to make dispatch decisions ahead of time,
|
||||
which is what the `kernelize` function does.
|
134
docs/layers.md
134
docs/layers.md
@ -1,134 +0,0 @@
|
||||
# Layers
|
||||
|
||||
A kernel can provide layers in addition to kernel functions. A layer from
|
||||
the Hub can replace the `forward` method of an existing layer for a certain
|
||||
device type. This makes it possible to provide more performant kernels for
|
||||
existing layers.
|
||||
|
||||
See [Kernel requirements](kernel-requirements.md) for more information the
|
||||
requirements of Hub layers.
|
||||
|
||||
## Making a layer extensible with kernels from the hub
|
||||
|
||||
### Using a decorator
|
||||
|
||||
A layer can be made extensible with the `use_kernel_forward_from_hub`
|
||||
decorator. For example:
|
||||
|
||||
```python
|
||||
@use_kernel_forward_from_hub("SiluAndMul")
|
||||
class SiluAndMul(nn.Module):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
d = input.shape[-1] // 2
|
||||
return F.silu(input[..., :d]) * input[..., d:]
|
||||
```
|
||||
|
||||
The decorator does not change the behavior of the class -- it annotates
|
||||
the class with the given name (here `SiluAndMul`). The `kernelize` function
|
||||
described below uses this name to look up kernels for the layer.
|
||||
|
||||
### External layers
|
||||
|
||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
||||
decorator can be made extensible using the `replace_kernel_forward_from_hub`
|
||||
function:
|
||||
|
||||
```python
|
||||
from somelibrary import SiluAndMul
|
||||
|
||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
||||
```
|
||||
|
||||
**Warning:** we strongly recommend using layers with a decorator, since
|
||||
it signifies that the maintainer intends to keep the `forward` signature
|
||||
compatible with layers from the hub.
|
||||
|
||||
## Kernelizing a model
|
||||
|
||||
A model will not use Hub kernels by default, even if it contains extensible
|
||||
layers. To enable the use of Hub kernels in the model, it needs to be
|
||||
'kernelized' using the `kernelize` function. This function traverses the
|
||||
model graph and replaces the `forward` methods of extensible layers for which
|
||||
Hub kernels are registered. Kernelize can be used as follows:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model)
|
||||
```
|
||||
|
||||
**Note:** the `kernelize` function modifies the model in-place, the model
|
||||
itself is returned as a convenience.
|
||||
|
||||
### Kernel device
|
||||
|
||||
Kernels can be registered per device type. For instance, separate `cuda` and
|
||||
`metal` kernels could be registered for the name `SiluAndMul`. By default,
|
||||
`kernelize` will try to infer the device type from the model's parameters.
|
||||
You can pass the device type to `kernelize` if the device type cannot be
|
||||
inferred (e.g. because the model has no parameters):
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, device="cuda")
|
||||
```
|
||||
|
||||
### `torch.compile`
|
||||
|
||||
Not all Hub kernels support `torch.compile`. If you want to compile a model
|
||||
after kernelizing it, pass the `needs_torch_compile` argument to ensure that
|
||||
only kernels that support `torch.compile` will be loaded:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, needs_torch_compile=True)
|
||||
```
|
||||
|
||||
### Fallback forward
|
||||
|
||||
The `needs_torch_compile` argument will fall back to the layer's original
|
||||
`forward` if the registered kernels does not support `torch.compile`. You
|
||||
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, needs_torch_compile=True, use_fallback=False)
|
||||
```
|
||||
|
||||
This can be useful if you want to guarantee that Hub kernels are used.
|
||||
|
||||
## Registering a hub kernel for a layer
|
||||
|
||||
`kernelize`` relies on kernel mappings to find Hub kernels for layers.
|
||||
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
|
||||
the Hub. For example:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can register such a mapping using `register_kernel_mapping`:
|
||||
|
||||
```python
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
This will register the kernel mapping in the current context, which is
|
||||
normally global. It is recommended to scope the mapping to where it is
|
||||
used with the `use_kernel_mapping` context manager:
|
||||
|
||||
```python
|
||||
with use_kernel_mapping(kernel_layer_mapping):
|
||||
# Use the layer for which the mapping is applied.
|
||||
model = kernelize(model)
|
||||
```
|
||||
|
||||
This ensures that the mapping is not active anymore outside the
|
||||
`with`-scope.
|
30
docs/source/_toctree.yml
Normal file
30
docs/source/_toctree.yml
Normal file
@ -0,0 +1,30 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: Introduction
|
||||
- local: installation
|
||||
title: Installation
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic-usage
|
||||
title: Basic Usage
|
||||
- local: layers
|
||||
title: Using Layers
|
||||
- local: locking
|
||||
title: Locking Kernel Versions
|
||||
- local: env
|
||||
title: Environment Variables
|
||||
- local: faq
|
||||
title: FAQ
|
||||
title: Usage Guide
|
||||
- sections:
|
||||
- local: api/kernels
|
||||
title: Kernels
|
||||
- local: api/layers
|
||||
title: Layers
|
||||
- local: cli
|
||||
title: Kernels CLI
|
||||
title: API Reference
|
||||
- sections:
|
||||
- local: kernel-requirements
|
||||
title: Kernel Requirements
|
||||
title: Developer Guide
|
25
docs/source/api/kernels.md
Normal file
25
docs/source/api/kernels.md
Normal file
@ -0,0 +1,25 @@
|
||||
# Kernels API Reference
|
||||
|
||||
## Main Functions
|
||||
|
||||
### get_kernel
|
||||
|
||||
[[autodoc]] kernels.get_kernel
|
||||
|
||||
### get_local_kernel
|
||||
|
||||
[[autodoc]] kernels.get_local_kernel
|
||||
|
||||
### has_kernel
|
||||
|
||||
[[autodoc]] kernels.has_kernel
|
||||
|
||||
## Loading locked kernels
|
||||
|
||||
### load_kernel
|
||||
|
||||
[[autodoc]] kernels.load_kernel
|
||||
|
||||
### get_locked_kernel
|
||||
|
||||
[[autodoc]] kernels.get_locked_kernel
|
49
docs/source/api/layers.md
Normal file
49
docs/source/api/layers.md
Normal file
@ -0,0 +1,49 @@
|
||||
# Layers API Reference
|
||||
|
||||
## Making layers kernel-aware
|
||||
|
||||
### use_kernel_forward_from_hub
|
||||
|
||||
[[autodoc]] kernels.use_kernel_forward_from_hub
|
||||
|
||||
### replace_kernel_forward_from_hub
|
||||
|
||||
[[autodoc]] kernels.replace_kernel_forward_from_hub
|
||||
|
||||
## Registering kernel mappings
|
||||
|
||||
### use_kernel_mapping
|
||||
|
||||
[[autodoc]] kernels.use_kernel_mapping
|
||||
|
||||
### register_kernel_mapping
|
||||
|
||||
[[autodoc]] kernels.register_kernel_mapping
|
||||
|
||||
## Kernelizing a model
|
||||
|
||||
### kernelize
|
||||
|
||||
[[autodoc]] kernels.kernelize
|
||||
|
||||
## Classes
|
||||
|
||||
### Device
|
||||
|
||||
[[autodoc]] kernels.Device
|
||||
|
||||
### Mode
|
||||
|
||||
[[autodoc]] kernels.Mode
|
||||
|
||||
### LayerRepository
|
||||
|
||||
[[autodoc]] kernels.LayerRepository
|
||||
|
||||
### LocalLayerRepository
|
||||
|
||||
[[autodoc]] kernels.LocalLayerRepository
|
||||
|
||||
### LockedLayerRepository
|
||||
|
||||
[[autodoc]] kernels.LockedLayerRepository
|
50
docs/source/basic-usage.md
Normal file
50
docs/source/basic-usage.md
Normal file
@ -0,0 +1,50 @@
|
||||
# Basic Usage
|
||||
|
||||
## Loading Kernels
|
||||
|
||||
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from kernels import get_kernel
|
||||
|
||||
# Download optimized kernels from the Hugging Face hub
|
||||
activation = get_kernel("kernels-community/activation")
|
||||
|
||||
# Create a random tensor
|
||||
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
|
||||
|
||||
# Run the kernel
|
||||
y = torch.empty_like(x)
|
||||
activation.gelu_fast(y, x)
|
||||
|
||||
print(y)
|
||||
```
|
||||
|
||||
### Using version bounds
|
||||
|
||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
||||
You can specify which version to download using Python version specifiers:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
|
||||
```
|
||||
|
||||
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
|
||||
is strongly recommended to specify a version bound, since a kernel author
|
||||
might push incompatible changes to the `main` branch.
|
||||
|
||||
## Checking Kernel Availability
|
||||
|
||||
You can check if a specific kernel is available for your environment:
|
||||
|
||||
```python
|
||||
from kernels import has_kernel
|
||||
|
||||
# Check if kernel is available for current environment
|
||||
is_available = has_kernel("kernels-community/activation")
|
||||
print(f"Kernel available: {is_available}")
|
||||
```
|
58
docs/source/cli.md
Normal file
58
docs/source/cli.md
Normal file
@ -0,0 +1,58 @@
|
||||
# Kernels CLI Reference
|
||||
|
||||
## Main Functions
|
||||
|
||||
### kernels check
|
||||
|
||||
You can use `kernels check` to test compliance of a kernel on the Hub.
|
||||
This currently checks that the kernel:
|
||||
|
||||
- Supports the currently-required Python ABI version.
|
||||
- Works on supported operating system versions.
|
||||
|
||||
For example:
|
||||
|
||||
```bash
|
||||
$ kernels check kernels-community/flash-attn3
|
||||
Checking variant: torch28-cxx11-cu128-aarch64-linux
|
||||
🐍 Python ABI 3.9 compatible
|
||||
🐧 manylinux_2_28 compatible
|
||||
[...]
|
||||
```
|
||||
|
||||
### kernels to-wheel
|
||||
|
||||
We strongly recommend downloading kernels from the Hub using the `kernels`
|
||||
package, since this comes with large [benefits](index.md) over using Python
|
||||
wheels. That said, some projects may require deployment of kernels as
|
||||
wheels. The `kernels` utility provides a simple solution to this. You can
|
||||
convert any Hub kernel into a set of wheels with the `to-wheel` command:
|
||||
|
||||
```bash
|
||||
$ kernels to-wheel drbh/img2grey 1.1.2
|
||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||
```
|
||||
|
||||
### kernels upload
|
||||
|
||||
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
|
||||
your kernel builds to the Hub. To know the supported arguments run: `kernels upload -h`.
|
||||
|
||||
**Notes**:
|
||||
|
||||
- This will take care of creating a repository on the Hub with the `repo_id` provided.
|
||||
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
|
||||
being uploaded, it will attempt to delete the files existing under it.
|
||||
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
|
51
docs/source/faq.md
Normal file
51
docs/source/faq.md
Normal file
@ -0,0 +1,51 @@
|
||||
# FAQ
|
||||
|
||||
## Kernel layers
|
||||
|
||||
### Why is the kernelization step needed as a separate step?
|
||||
|
||||
In earlier versions of `kernels`, a layer's `forward` method was replaced
|
||||
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
|
||||
The new `forward` would dispatch to a kernel based on the device type,
|
||||
whether a model was training, etc. However, this approach was
|
||||
fundamentally incompatible with `torch.compile` since it relied
|
||||
on data-dependent branching.
|
||||
|
||||
To avoid branching, we have to make dispatch decisions ahead of time,
|
||||
which is what the `kernelize` function does.
|
||||
|
||||
### Why does kernelization only replace `forward` methods?
|
||||
|
||||
There are some other possible approaches. The first is to completely
|
||||
replace existing layers by kernel layers. However, since this would
|
||||
permit free-form layer classes, it would be much harder to validate
|
||||
that layers are fully compatible with the layers that they are
|
||||
replacing. For instance, they could have completely different member
|
||||
variables. Besides that, we would also need to hold on to the original
|
||||
layers, in case we need to revert to the base layers when the model
|
||||
is `kernelize`d again with different options.
|
||||
|
||||
A second approach would be to make an auxiliary layer that wraps the
|
||||
original layer and the kernel layer and dispatches to the kernel layer.
|
||||
This wouldn't have the issues of the first approach, because kernel layers
|
||||
could be similarly strict as they are now, and we would still have access
|
||||
to the original layers when `kernelize`-ing the model again. However,
|
||||
this would change the graph structure of the model and would break use
|
||||
cases where programs access the model internals (e.g.
|
||||
`model.layers[0].attention.query_weight`) or rely on the graph structure
|
||||
in other ways.
|
||||
|
||||
The approach of `forward`-replacement is the least invasive, because
|
||||
it preserves the original model graph. It is also reversible, since
|
||||
even though the `forward` of a layer _instance_ might be replaced,
|
||||
the corresponding class still has the original `forward`.
|
||||
|
||||
## Misc
|
||||
|
||||
### How can I disable kernel reporting in the user-agent?
|
||||
|
||||
By default, we collect telemetry when a call to `get_kernel()` is made.
|
||||
This only includes the `kernels` version, `torch` version, and the build
|
||||
information for the kernel being requested.
|
||||
|
||||
You can disable this by setting `export DISABLE_TELEMETRY=yes`.
|
20
docs/source/index.md
Normal file
20
docs/source/index.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Kernels
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||
</div>
|
||||
|
||||
The Kernel Hub allows Python libraries and applications to load compute
|
||||
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||
packages in that they are made to be:
|
||||
|
||||
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
|
||||
- **Unique**: multiple versions of the same kernel can be loaded in the
|
||||
same Python process.
|
||||
- **Compatible**: kernels must support all recent versions of Python and
|
||||
the different PyTorch build configurations (various CUDA versions
|
||||
and C++ ABIs). Furthermore, older C library versions must be supported.
|
||||
|
||||
You can [search for kernels](https://huggingface.co/models?other=kernel) on
|
||||
the Hub.
|
16
docs/source/installation.md
Normal file
16
docs/source/installation.md
Normal file
@ -0,0 +1,16 @@
|
||||
# Installation
|
||||
|
||||
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
|
||||
|
||||
```bash
|
||||
pip install kernels
|
||||
```
|
||||
|
||||
# Using kernels in a Docker container
|
||||
|
||||
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
|
||||
|
||||
```bash
|
||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||
```
|
@ -34,28 +34,51 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of
|
||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||
to resolve the version constraints.
|
||||
|
||||
We recommend using [semver](https://semver.org/) to version kernels.
|
||||
|
||||
## Native Python module
|
||||
|
||||
Kernels will typically contain a native Python module with precompiled
|
||||
compute kernels and bindings. This module must fulfill the following
|
||||
requirements:
|
||||
compute kernels and bindings. This module must fulfill the requirements
|
||||
outlined in this section. For all operating systems, a kernel must not
|
||||
have dynamic library dependencies outside:
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
|
||||
## Compatibility with torch.compile
|
||||
|
||||
The Kernel Hub also encourages to write the kernels in a `torch.compile`
|
||||
compliant way. This helps to ensure that the kernels are compatible with
|
||||
`torch.compile` without introducing any graph breaks and triggering
|
||||
recompilation which can limit the benefits of compilation.
|
||||
|
||||
[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and
|
||||
recompilation triggers during `torch.compile`.
|
||||
|
||||
### Linux
|
||||
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||
This means that the extension **must not** use symbols versions higher than:
|
||||
|
||||
- GLIBC 2.28
|
||||
- GLIBCXX 3.4.24
|
||||
- CXXABI 1.3.11
|
||||
- GCC 7.0.0
|
||||
|
||||
These requirement can be checked with the ABI checker (see below).
|
||||
These requirements can be checked with the ABI checker (see below).
|
||||
|
||||
- No dynamic library dependencies outside:
|
||||
### macOS
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
- macOS deployment target 15.0.
|
||||
- Metal 3.0 (`-std=metal3.0`).
|
||||
|
||||
The ABI3 requirement can be checked with the ABI checker (see below).
|
||||
|
||||
### ABI checker
|
||||
|
||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
323
docs/source/layers.md
Normal file
323
docs/source/layers.md
Normal file
@ -0,0 +1,323 @@
|
||||
# Layers
|
||||
|
||||
A kernel can provide layers in addition to kernel functions. A layer from
|
||||
the Hub can replace the `forward` method of an existing layer for a certain
|
||||
device type. This makes it possible to provide more performant kernels for
|
||||
existing layers.
|
||||
|
||||
See [Kernel requirements](kernel-requirements.md) for more information on the
|
||||
requirements of Hub layers.
|
||||
|
||||
## Making a layer extensible with kernels from the hub
|
||||
|
||||
### Using a decorator
|
||||
|
||||
A layer can be made extensible with the `use_kernel_forward_from_hub`
|
||||
decorator. For example:
|
||||
|
||||
```python
|
||||
@use_kernel_forward_from_hub("SiluAndMul")
|
||||
class SiluAndMul(nn.Module):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
d = input.shape[-1] // 2
|
||||
return F.silu(input[..., :d]) * input[..., d:]
|
||||
```
|
||||
|
||||
The decorator does not change the behavior of the class -- it annotates
|
||||
the class with the given name (here `SiluAndMul`). The `kernelize` function
|
||||
described below uses this name to look up kernels for the layer.
|
||||
|
||||
### External layers
|
||||
|
||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
||||
decorator can be made extensible using the `replace_kernel_forward_from_hub`
|
||||
function:
|
||||
|
||||
```python
|
||||
from somelibrary import SiluAndMul
|
||||
|
||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
||||
```
|
||||
|
||||
**Warning:** we strongly recommend using layers with a decorator, since
|
||||
it signifies that the maintainer intends to keep the `forward` signature
|
||||
compatible with layers from the hub.
|
||||
|
||||
## Kernelizing a model
|
||||
|
||||
A model will not use Hub kernels by default, even if it contains extensible
|
||||
layers. To enable the use of Hub kernels in the model, it needs to be
|
||||
'kernelized' using the `kernelize` function. This function traverses the
|
||||
model graph and replaces the `forward` methods of extensible layers for which
|
||||
Hub kernels are registered. `kernelize` can be used as follows:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, mode=Mode.INFERENCE)
|
||||
```
|
||||
|
||||
The `kernelize` function modifies the model in-place, the model itself is
|
||||
returned as a convenience. The `mode` specifies that the model will be used
|
||||
in inference. Similarly, you can ask `kernelize` to prepare the model for
|
||||
training:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, mode=Mode.TRAINING)
|
||||
```
|
||||
|
||||
A model that is kernelized for training can also be used for inference, but
|
||||
not the other way around. If you want to change the mode of the kernelized
|
||||
model, you can just run `kernelize` on the model again with the new mode.
|
||||
|
||||
If you want to compile a model with `torch.compile`, this should be indicated
|
||||
in the mode as well. You can do this by combining `Mode.INFERENCE` or
|
||||
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
|
||||
# Inference
|
||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||
|
||||
# Training
|
||||
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||
```
|
||||
|
||||
### Kernel device
|
||||
|
||||
Kernels can be registered per device type. For instance, separate `cuda` and
|
||||
`metal` kernels could be registered for the name `SiluAndMul`. By default,
|
||||
`kernelize` will try to infer the device type from the model's parameters.
|
||||
You can pass the device type to `kernelize` if the device type cannot be
|
||||
inferred (e.g. because the model has no parameters):
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
|
||||
```
|
||||
|
||||
### Fallback `forward`
|
||||
|
||||
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
|
||||
kernel does not support backward passes or `torch.compile` respectively,
|
||||
`kernelize` will fall back to the original, non-kernelized, layer. You
|
||||
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
||||
|
||||
```python
|
||||
model = MyModel(...)
|
||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
|
||||
```
|
||||
|
||||
This can be useful if you want to guarantee that Hub kernels are used.
|
||||
|
||||
### Inspecting which kernels are used
|
||||
|
||||
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
||||
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
||||
documentation for information on how to configure logging.
|
||||
|
||||
## Registering a hub kernel for a layer
|
||||
|
||||
`kernelize` relies on kernel mappings to find Hub kernels for layers.
|
||||
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
|
||||
the Hub. For example:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
"rocm": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can register such a mapping using `register_kernel_mapping`:
|
||||
|
||||
```python
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
This will register the kernel mapping in the current context, which is
|
||||
normally global. It is recommended to scope the mapping to where it is
|
||||
used with the `use_kernel_mapping` context manager:
|
||||
|
||||
```python
|
||||
with use_kernel_mapping(kernel_layer_mapping):
|
||||
# Use the layer for which the mapping is applied.
|
||||
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||
```
|
||||
|
||||
This ensures that the mapping is not active anymore outside the
|
||||
`with`-scope.
|
||||
|
||||
### Using version bounds
|
||||
|
||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
||||
You can specify which version of the kernel to download using Python version
|
||||
specifiers:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
version=">=0.0.4,<0.1.0",
|
||||
),
|
||||
"rocm": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
version=">=0.0.4,<0.1.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
|
||||
least 4. It is strongly recommended to specify a version bound, since a
|
||||
kernel author might push incompatible changes to the `main` branch.
|
||||
|
||||
### Registering kernels for specific modes
|
||||
|
||||
You might want to register two different kernels for a particular layer,
|
||||
where one kernel is optimized for a specific mode. You can do so by
|
||||
registering layer repositories for specific modes. For example:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/activation-inference-optimized",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
||||
repo_id="kernels-community/activation-training-optimized",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `kernelize` function will attempt to use the following registered
|
||||
kernels for a given mode:
|
||||
|
||||
- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
|
||||
`TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
|
||||
`TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||
|
||||
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
|
||||
is also used when a kernel is registered without a mode, as described in the
|
||||
previous section.
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": {
|
||||
Mode.FALLBACK: LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/activation-inference-optimized",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
Mode.TRAINING: LayerRepository(
|
||||
repo_id="kernels-community/activation-training-optimized",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
|
||||
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
|
||||
since the other kernels do not support `torch.compile`.
|
||||
|
||||
### Registering kernels for specific CUDA capabilities
|
||||
|
||||
Some kernels only work with newer CUDA architectures. For instance, some
|
||||
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
|
||||
supports registering layers for a range of CUDA capabilities. To do so,
|
||||
you need to register the layer for a `Device` with type `cuda` and
|
||||
set the supported range of CUDA capabilities with using `CUDAProperties`:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
Device(
|
||||
type="cuda",
|
||||
properties=CUDAProperties(
|
||||
min_capability=75, max_capability=89
|
||||
),
|
||||
): LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
Device(
|
||||
type="cuda",
|
||||
properties=CUDAProperties(
|
||||
min_capability=90, max_capability=sys.maxsize
|
||||
),
|
||||
): LayerRepository(
|
||||
repo_id="kernels-community/activation-hopper",
|
||||
layer_name="SiluAndMul",
|
||||
),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Capabilities behave as follows:
|
||||
|
||||
- The minimum and maximum capabilities are inclusive.
|
||||
- When a new kernel is registered with the same min/max capabilities as
|
||||
an existing kernel, the new kernel will replace the old kernel.
|
||||
- When there are multiple kernels that support a capability, the kernel
|
||||
with the smaller capability interval will be used. E.g. given:
|
||||
- `KernelA` with `min_capability=80` and `max_capability=89`;
|
||||
- `KernelB` with `min_capability=75` and `max_capability=89`;
|
||||
- `kernelize` runs on a system with capability 8.6.
|
||||
|
||||
Then `KernelA` will be used because the interval 80..89 is smaller
|
||||
than 75..89. The motivation is that kernels with smaller ranges
|
||||
tend to be more optimized for a specific set of GPUs. **This behavior
|
||||
might still change in the future.**
|
||||
|
||||
### Registering kernels for specific ROCm capabilities
|
||||
|
||||
Registering kernels for the ROCm architecture follows the exact same
|
||||
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
|
||||
a kernel to a range of ROCm capabilities.
|
||||
|
||||
### Loading from a local repository for testing
|
||||
|
||||
The `LocalLayerRepository` class is provided to load a repository from
|
||||
a local directory. For example:
|
||||
|
||||
```python
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"SiluAndMul": {
|
||||
"cuda": LocalLayerRepository(
|
||||
repo_path="/home/daniel/kernels/activation",
|
||||
package_name="activation",
|
||||
layer_name="SiluAndMul",
|
||||
)
|
||||
}
|
||||
},
|
||||
inherit_mapping=False,
|
||||
):
|
||||
kernelize(linear, mode=Mode.INFERENCE)
|
||||
```
|
@ -1,4 +1,4 @@
|
||||
# Locking kernel versions
|
||||
# Locking kernel/layer versions
|
||||
|
||||
Projects that use `setuptools` can lock the kernel versions that should be
|
||||
used. First specify the accepted versions in `pyproject.toml` and make
|
||||
@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")
|
||||
**Note:** the lock file is included in the package metadata, so it will only be visible
|
||||
to `kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Locked kernel layers
|
||||
|
||||
Locking is also supported for kernel layers. To use locked layers, register them
|
||||
with the `LockedLayerRepository` class:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LockedLayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
|
||||
print("Kernel successfully executed")
|
||||
|
||||
# Check results
|
||||
expected = torch.tensor([
|
||||
[0.8408, 1.9551, 2.9961],
|
||||
[4.0000, 5.0000, 6.0000],
|
||||
[7.0000, 8.0000, 9.0000]
|
||||
], device='cuda:0', dtype=torch.float16)
|
||||
expected = torch.tensor(
|
||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||
device="cuda:0",
|
||||
dtype=torch.float16,
|
||||
)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
print("Calculated values are exact")
|
||||
|
19
flake.lock
generated
19
flake.lock
generated
@ -58,33 +58,32 @@
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1749025620,
|
||||
"narHash": "sha256-V/r5KOp8FRC5n3MINDzTeS3pZz57SasFVzx12WQRQ8U=",
|
||||
"lastModified": 1754038838,
|
||||
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"rev": "7ab84ffad440c530162f528a96fa062530a6c8e4",
|
||||
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "torch-cxx11",
|
||||
"repo": "hf-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1747820358,
|
||||
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
||||
"owner": "danieldk",
|
||||
"lastModified": 1752785354,
|
||||
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
||||
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "cudatoolkit-12.9-kernel-builder",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
|
11
flake.nix
11
flake.nix
@ -1,6 +1,6 @@
|
||||
{
|
||||
inputs = {
|
||||
hf-nix.url = "github:huggingface/hf-nix/torch-cxx11";
|
||||
hf-nix.url = "github:huggingface/hf-nix";
|
||||
nixpkgs.follows = "hf-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
@ -16,7 +16,7 @@
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
inherit (hf-nix.lib) config;
|
||||
config = hf-nix.lib.config system;
|
||||
overlays = [
|
||||
hf-nix.overlays.default
|
||||
];
|
||||
@ -24,8 +24,13 @@
|
||||
in
|
||||
{
|
||||
formatter = pkgs.nixfmt-tree;
|
||||
packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {};
|
||||
devShells = with pkgs; rec {
|
||||
default = mkShell {
|
||||
nativeBuildInputs = [
|
||||
# For hf-doc-builder.
|
||||
nodejs
|
||||
];
|
||||
buildInputs =
|
||||
[
|
||||
black
|
||||
@ -36,6 +41,8 @@
|
||||
++ (with python3.pkgs; [
|
||||
docutils
|
||||
huggingface-hub
|
||||
(callPackage ./nix/kernel-abi-check.nix {})
|
||||
mktestdocs
|
||||
pytest
|
||||
pytest-benchmark
|
||||
pyyaml
|
||||
|
27
nix/kernel-abi-check.nix
Normal file
27
nix/kernel-abi-check.nix
Normal file
@ -0,0 +1,27 @@
|
||||
{
|
||||
buildPythonPackage,
|
||||
fetchPypi,
|
||||
rustPlatform,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "kernel-abi-check";
|
||||
version = "0.6.2";
|
||||
|
||||
src = fetchPypi {
|
||||
inherit version;
|
||||
pname = "kernel_abi_check";
|
||||
hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys=";
|
||||
};
|
||||
|
||||
cargoDeps = rustPlatform.fetchCargoVendor {
|
||||
inherit pname version src sourceRoot;
|
||||
hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A=";
|
||||
};
|
||||
|
||||
sourceRoot = "kernel_abi_check-${version}/bindings/python";
|
||||
|
||||
pyproject = true;
|
||||
|
||||
nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ];
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.6.0.dev0"
|
||||
version = "0.10.4"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
@ -12,7 +12,7 @@ license = { text = "Apache-2.0" }
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.9"
|
||||
dependencies = [
|
||||
"huggingface_hub>=0.26.0,<1.0",
|
||||
"huggingface_hub>=0.26.0,<2.0",
|
||||
"packaging>=20.0",
|
||||
"pyyaml>=6",
|
||||
"tomli>=2.0; python_version<'3.11'",
|
||||
@ -24,16 +24,21 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"mypy == 1.14.1",
|
||||
"pytest >=8",
|
||||
"mktestdocs>=0.2.5",
|
||||
"mypy>=1.15.0",
|
||||
"pytest>=8",
|
||||
# Whatever version is compatible with pytest.
|
||||
"pytest-benchmark",
|
||||
"torch >=2.5",
|
||||
"torch>=2.5",
|
||||
"types-pyyaml"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
|
||||
torch = ["torch"]
|
||||
docs = [
|
||||
"hf-doc-builder",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
kernels = "kernels.cli:main"
|
||||
@ -41,6 +46,9 @@ kernels = "kernels.cli:main"
|
||||
[project.entry-points."egg_info.writers"]
|
||||
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 119
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
@ -67,4 +75,4 @@ line-length = 119
|
||||
# Ignored rules:
|
||||
# "E501" -> line length violation
|
||||
lint.ignore = ["E501"]
|
||||
lint.select = ["E", "F", "I", "W"]
|
||||
lint.select = ["E", "F", "W"]
|
||||
|
@ -1,4 +1,9 @@
|
||||
[pytest]
|
||||
markers =
|
||||
cuda_only: marks tests that should only hosts with CUDA GPUs
|
||||
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
||||
darwin_only: marks tests that should only run on macOS
|
||||
linux_only: marks tests that should only run on Linux
|
||||
xpu_only: marks tests that should only run on hosts with Intel XPUs
|
||||
npu_only: marks tests that should only run on Ascend NPUs
|
||||
token: enable tests that require a write token
|
||||
is_staging_test: Marks tests that should only run on a staging environment
|
||||
|
@ -1,6 +1,14 @@
|
||||
import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version("kernels")
|
||||
|
||||
from kernels.layer import (
|
||||
CUDAProperties,
|
||||
Device,
|
||||
LayerRepository,
|
||||
LocalLayerRepository,
|
||||
LockedLayerRepository,
|
||||
Mode,
|
||||
kernelize,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
@ -9,6 +17,7 @@ from kernels.layer import (
|
||||
)
|
||||
from kernels.utils import (
|
||||
get_kernel,
|
||||
get_local_kernel,
|
||||
get_locked_kernel,
|
||||
has_kernel,
|
||||
install_kernel,
|
||||
@ -16,16 +25,22 @@ from kernels.utils import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"CUDAProperties",
|
||||
"Device",
|
||||
"LayerRepository",
|
||||
"LocalLayerRepository",
|
||||
"LockedLayerRepository",
|
||||
"Mode",
|
||||
"get_kernel",
|
||||
"get_local_kernel",
|
||||
"get_locked_kernel",
|
||||
"has_kernel",
|
||||
"load_kernel",
|
||||
"install_kernel",
|
||||
"use_kernel_forward_from_hub",
|
||||
"use_kernel_mapping",
|
||||
"kernelize",
|
||||
"load_kernel",
|
||||
"register_kernel_mapping",
|
||||
"replace_kernel_forward_from_hub",
|
||||
"LayerRepository",
|
||||
"Device",
|
||||
"kernelize",
|
||||
"use_kernel_forward_from_hub",
|
||||
"use_kernel_mapping",
|
||||
]
|
||||
|
200
src/kernels/_interval_tree.py
Normal file
200
src/kernels/_interval_tree.py
Normal file
@ -0,0 +1,200 @@
|
||||
# AVL-balanced interval trees. We could use the intervaltree
|
||||
# packages, but it seems unmaintained and does not have type
|
||||
# annotations.
|
||||
|
||||
from typing import Generic, List, Optional, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _Node(Generic[T]):
|
||||
"""A node in the interval tree."""
|
||||
|
||||
def __init__(self, start: int, end: int, data: T):
|
||||
self.start: int = start
|
||||
self.end: int = end
|
||||
self.data: T = data
|
||||
self.max_end: int = end
|
||||
self.left: Optional["_Node[T]"] = None
|
||||
self.right: Optional["_Node[T]"] = None
|
||||
self.height: int = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node({self.start}, {self.end})"
|
||||
|
||||
|
||||
class IntervalTree(Generic[T]):
|
||||
"""A data structure to hold and query (unique) intervals."""
|
||||
|
||||
root: Optional[_Node[T]]
|
||||
|
||||
def __init__(self):
|
||||
self.root = None
|
||||
|
||||
def insert(self, start: int, end: int, data: T) -> None:
|
||||
"""
|
||||
Inserts a new interval into the tree.
|
||||
|
||||
Args:
|
||||
start: The starting point of the interval.
|
||||
end: The ending point of the interval.
|
||||
data: The data associated with this interval.
|
||||
"""
|
||||
self.root = self._insert(self.root, start, end, data)
|
||||
|
||||
def _get_height(self, node: Optional[_Node[T]]) -> int:
|
||||
if not node:
|
||||
return 0
|
||||
return node.height
|
||||
|
||||
def _get_balance(self, node: Optional[_Node[T]]) -> int:
|
||||
if not node:
|
||||
return 0
|
||||
return self._get_height(node.left) - self._get_height(node.right)
|
||||
|
||||
def _update_node_attributes(self, node: _Node[T]) -> None:
|
||||
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
|
||||
node.max_end = node.end
|
||||
if node.left:
|
||||
node.max_end = max(node.max_end, node.left.max_end)
|
||||
if node.right:
|
||||
node.max_end = max(node.max_end, node.right.max_end)
|
||||
|
||||
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
|
||||
"""Performs a right rotation."""
|
||||
x = y.left
|
||||
assert x is not None
|
||||
T2 = x.right
|
||||
|
||||
x.right = y
|
||||
y.left = T2
|
||||
|
||||
self._update_node_attributes(y)
|
||||
self._update_node_attributes(x)
|
||||
|
||||
return x
|
||||
|
||||
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
|
||||
"""Performs a left rotation."""
|
||||
y = x.right
|
||||
assert y is not None
|
||||
T2 = y.left
|
||||
|
||||
y.left = x
|
||||
x.right = T2
|
||||
|
||||
self._update_node_attributes(x)
|
||||
self._update_node_attributes(y)
|
||||
|
||||
return y
|
||||
|
||||
def _insert(
|
||||
self, node: Optional[_Node[T]], start: int, end: int, data: T
|
||||
) -> _Node[T]:
|
||||
"""Recursive helper to insert a new node and balance the tree."""
|
||||
if not node:
|
||||
return _Node(start, end, data)
|
||||
|
||||
# Replace the data if the interval already exists.
|
||||
if start == node.start and end == node.end:
|
||||
node.data = data
|
||||
return node
|
||||
|
||||
if start < node.start:
|
||||
node.left = self._insert(node.left, start, end, data)
|
||||
else:
|
||||
node.right = self._insert(node.right, start, end, data)
|
||||
|
||||
self._update_node_attributes(node)
|
||||
|
||||
balance = self._get_balance(node)
|
||||
|
||||
# Left Left Case
|
||||
if balance > 1 and node.left and start < node.left.start:
|
||||
return self._right_rotate(node)
|
||||
|
||||
# Right Right Case
|
||||
if balance < -1 and node.right and start >= node.right.start:
|
||||
return self._left_rotate(node)
|
||||
|
||||
# Left Right Case
|
||||
if balance > 1 and node.left and start >= node.left.start:
|
||||
node.left = self._left_rotate(node.left)
|
||||
return self._right_rotate(node)
|
||||
|
||||
# Right Left Case
|
||||
if balance < -1 and node.right and start < node.right.start:
|
||||
node.right = self._right_rotate(node.right)
|
||||
return self._left_rotate(node)
|
||||
|
||||
return node
|
||||
|
||||
def search(self, point: int) -> List[T]:
|
||||
"""
|
||||
Searches for all intervals that contain the given point.
|
||||
|
||||
Args:
|
||||
point: The point to search for.
|
||||
|
||||
Returns:
|
||||
A list of data items from all matching intervals.
|
||||
"""
|
||||
results: List[T] = []
|
||||
self._search(self.root, point, results)
|
||||
return results
|
||||
|
||||
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
|
||||
"""Recursive helper to find all overlapping intervals."""
|
||||
if node is None or point > node.max_end:
|
||||
return
|
||||
|
||||
if node.left:
|
||||
self._search(node.left, point, results)
|
||||
|
||||
if node.start <= point <= node.end:
|
||||
results.append(node.data)
|
||||
|
||||
if point >= node.start and node.right:
|
||||
self._search(node.right, point, results)
|
||||
|
||||
def find_smallest_interval(self, point: int) -> Optional[T]:
|
||||
"""
|
||||
Finds the item with the most specific (smallest) range for a given point.
|
||||
|
||||
Args:
|
||||
point: The capability to look up.
|
||||
|
||||
Returns:
|
||||
The data of the best-matching item, or None if no match is found.
|
||||
"""
|
||||
matches: List[Tuple[int, int, T]] = []
|
||||
self._find_with_intervals(self.root, point, matches)
|
||||
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
# Return the smallest interval, sort by memory location when
|
||||
# there are multiple matches with the same interval size. This
|
||||
# is just to ensure that we can compare against a trivial
|
||||
# implementation in tests.
|
||||
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
|
||||
return best_match[2]
|
||||
|
||||
def _find_with_intervals(
|
||||
self,
|
||||
node: Optional[_Node[T]],
|
||||
point: int,
|
||||
results: List[Tuple[int, int, T]],
|
||||
) -> None:
|
||||
"""A modified search that collects interval ranges along with data."""
|
||||
if node is None or point > node.max_end:
|
||||
return
|
||||
|
||||
if node.left:
|
||||
self._find_with_intervals(node.left, point, results)
|
||||
|
||||
if node.start <= point <= node.end:
|
||||
results.append((node.start, node.end, node.data))
|
||||
|
||||
if point >= node.start and node.right:
|
||||
self._find_with_intervals(node.right, point, results)
|
52
src/kernels/_versions.py
Normal file
52
src/kernels/_versions.py
Normal file
@ -0,0 +1,52 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
return versions[accepted_versions[-1]]
|
||||
|
||||
|
||||
def select_revision_or_version(
|
||||
repo_id: str, revision: Optional[str], version: Optional[str]
|
||||
) -> str:
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError("Either a revision or a version must be specified, not both.")
|
||||
elif revision is None and version is None:
|
||||
revision = "main"
|
||||
elif version is not None:
|
||||
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
|
||||
assert revision is not None
|
||||
return revision
|
142
src/kernels/check.py
Normal file
142
src/kernels/check.py
Normal file
@ -0,0 +1,142 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from kernel_abi_check import (
|
||||
BinaryFormat,
|
||||
IncompatibleAbi3Symbol,
|
||||
IncompatibleMacOSVersion,
|
||||
IncompatibleManylinuxSymbol,
|
||||
MissingMacOSVersion,
|
||||
NonAbi3Symbol,
|
||||
ObjectFile,
|
||||
)
|
||||
|
||||
from kernels.utils import CACHE_DIR
|
||||
|
||||
|
||||
def check_kernel(
|
||||
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
|
||||
):
|
||||
variants_path = (
|
||||
Path(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=["build/*"],
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
/ "build"
|
||||
)
|
||||
|
||||
has_issues = False
|
||||
for variant_path in variants_path.iterdir():
|
||||
if not variant_path.is_dir():
|
||||
print(
|
||||
f"⛔ `build/` must only contain directories, found: {variant_path.name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
has_issues = True
|
||||
continue
|
||||
|
||||
print(f"Checking variant: {variant_path.name}", file=sys.stderr)
|
||||
|
||||
indent = 2
|
||||
|
||||
for dylib_path in variant_path.rglob("*.so"):
|
||||
print_with_indent(
|
||||
indent,
|
||||
f"Dynamic library {dylib_path.relative_to(variant_path)}:",
|
||||
)
|
||||
|
||||
o = ObjectFile(dylib_path)
|
||||
has_issues |= check_abi3(o, python_abi, indent + 2)
|
||||
|
||||
# TODO: also check operating system
|
||||
if o.format() == BinaryFormat.ELF:
|
||||
has_issues |= check_manylinux(o, manylinux, indent + 2)
|
||||
elif o.format() == BinaryFormat.MACH_O:
|
||||
has_issues |= check_macos(o, macos, indent + 2)
|
||||
|
||||
if has_issues:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool:
|
||||
has_issues = False
|
||||
violations = object_file.check_python_abi(python_abi)
|
||||
if violations != []:
|
||||
has_issues = True
|
||||
print_with_indent(
|
||||
indent,
|
||||
f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:",
|
||||
)
|
||||
for violation in violations:
|
||||
if isinstance(violation, IncompatibleAbi3Symbol):
|
||||
print_with_indent(
|
||||
indent + 3,
|
||||
f"{violation.name}: {violation.version_added}",
|
||||
)
|
||||
elif isinstance(violation, NonAbi3Symbol):
|
||||
print_with_indent(
|
||||
indent + 3,
|
||||
f"{violation.name}",
|
||||
)
|
||||
else:
|
||||
print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible")
|
||||
|
||||
return has_issues
|
||||
|
||||
|
||||
def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool:
|
||||
has_issues = False
|
||||
violations = object_file.check_macos(macos)
|
||||
if violations != []:
|
||||
has_issues = True
|
||||
print_with_indent(
|
||||
indent,
|
||||
f"⛔ Found incompatibility with macOS {macos}:",
|
||||
)
|
||||
|
||||
for violation in violations:
|
||||
if isinstance(violation, MissingMacOSVersion):
|
||||
print_with_indent(
|
||||
indent + 3,
|
||||
"shared library does not contain macOS version",
|
||||
)
|
||||
elif isinstance(violation, IncompatibleMacOSVersion):
|
||||
print_with_indent(
|
||||
indent + 3,
|
||||
f"shared library requires macOS {violation.version}",
|
||||
)
|
||||
else:
|
||||
print_with_indent(indent, f"🍏 compatible with macOS {macos}")
|
||||
|
||||
return has_issues
|
||||
|
||||
|
||||
def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool:
|
||||
has_issues = False
|
||||
violations = object_file.check_manylinux(manylinux)
|
||||
if violations != []:
|
||||
has_issues = True
|
||||
print_with_indent(
|
||||
indent,
|
||||
f"⛔ Found symbols that are incompatible with {manylinux}:",
|
||||
)
|
||||
|
||||
for violation in violations:
|
||||
if isinstance(violation, IncompatibleManylinuxSymbol):
|
||||
print_with_indent(
|
||||
indent + 3,
|
||||
f"{violation.name}_{violation.dep}: {violation.version}",
|
||||
)
|
||||
else:
|
||||
print_with_indent(indent, f"🐧 {manylinux} compatible")
|
||||
|
||||
return has_issues
|
||||
|
||||
|
||||
def print_with_indent(indent: int, message: str):
|
||||
print(f"{' ' * indent}{message}", file=sys.stderr)
|
@ -4,6 +4,8 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import create_repo, upload_folder, create_branch
|
||||
|
||||
from kernels.compat import tomllib
|
||||
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||
@ -18,6 +20,31 @@ def main():
|
||||
)
|
||||
subparsers = parser.add_subparsers(required=True)
|
||||
|
||||
check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
|
||||
check_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
|
||||
check_parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default="main",
|
||||
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
|
||||
)
|
||||
check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
|
||||
check_parser.add_argument(
|
||||
"--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
|
||||
)
|
||||
check_parser.add_argument(
|
||||
"--python-abi", type=str, help="Python ABI version", default="3.9"
|
||||
)
|
||||
check_parser.set_defaults(
|
||||
func=lambda args: check_kernel(
|
||||
macos=args.macos,
|
||||
manylinux=args.manylinux,
|
||||
python_abi=args.python_abi,
|
||||
repo_id=args.repo_id,
|
||||
revision=args.revision,
|
||||
)
|
||||
)
|
||||
|
||||
download_parser = subparsers.add_parser("download", help="Download locked kernels")
|
||||
download_parser.add_argument(
|
||||
"project_dir",
|
||||
@ -31,6 +58,29 @@ def main():
|
||||
)
|
||||
download_parser.set_defaults(func=download_kernels)
|
||||
|
||||
upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
|
||||
upload_parser.add_argument(
|
||||
"kernel_dir",
|
||||
type=Path,
|
||||
help="Directory of the kernel build",
|
||||
)
|
||||
upload_parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
help="Repository ID to use to upload to the Hugging Face Hub",
|
||||
)
|
||||
upload_parser.add_argument(
|
||||
"--branch",
|
||||
type=None,
|
||||
help="If set, the upload will be made to a particular branch of the provided `repo_id`.",
|
||||
)
|
||||
upload_parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="If the repository should be private.",
|
||||
)
|
||||
upload_parser.set_defaults(func=upload_kernels)
|
||||
|
||||
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
||||
lock_parser.add_argument(
|
||||
"project_dir",
|
||||
@ -153,8 +203,61 @@ def lock_kernels(args):
|
||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||
|
||||
|
||||
def upload_kernels(args):
|
||||
# Resolve `kernel_dir` to be uploaded.
|
||||
kernel_dir = Path(args.kernel_dir).resolve()
|
||||
build_dir = kernel_dir / "build"
|
||||
if not kernel_dir.is_dir():
|
||||
raise ValueError(f"{kernel_dir} is not a directory")
|
||||
if not build_dir.is_dir():
|
||||
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
|
||||
|
||||
repo_id = create_repo(
|
||||
repo_id=args.repo_id, private=args.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
if args.branch is not None:
|
||||
create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True)
|
||||
|
||||
delete_patterns: set[str] = set()
|
||||
for build_variant in build_dir.iterdir():
|
||||
if build_variant.is_dir():
|
||||
delete_patterns.add(f"{build_variant.name}/**")
|
||||
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=build_dir,
|
||||
revision=args.branch,
|
||||
path_in_repo="build",
|
||||
delete_patterns=list(delete_patterns),
|
||||
commit_message="Build uploaded using `kernels`.",
|
||||
)
|
||||
print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
|
||||
|
||||
|
||||
class _JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if dataclasses.is_dataclass(o):
|
||||
return dataclasses.asdict(o)
|
||||
return super().default(o)
|
||||
|
||||
|
||||
def check_kernel(
|
||||
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
|
||||
):
|
||||
try:
|
||||
import kernels.check
|
||||
except ImportError:
|
||||
print(
|
||||
"`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
kernels.check.check_kernel(
|
||||
macos=macos,
|
||||
manylinux=manylinux,
|
||||
python_abi=python_abi,
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
)
|
||||
|
@ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _extract_description_before_tags(docstring_mdx: str) -> str:
|
||||
"""Extract the description part of a docstring before any tags."""
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
return docstring_mdx[:first_tag_pos].strip()
|
||||
else:
|
||||
return docstring_mdx.strip()
|
||||
|
||||
|
||||
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
|
||||
"""Print the parameters section from a docstring."""
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _print_returns_section(
|
||||
docstring_mdx: str, *, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Print the returns section from a docstring."""
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Returns")
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {context_name}"
|
||||
)
|
||||
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _get_docstring(obj, use_dict_check: bool = False) -> str:
|
||||
"""Get docstring from an object, with fallback to default message."""
|
||||
# Check whether the class/method itself has docs and not just
|
||||
# the superclass.
|
||||
if use_dict_check:
|
||||
has_doc = obj.__dict__.get("__doc__", None) is not None
|
||||
else:
|
||||
has_doc = getattr(obj, "__doc__", None) is not None
|
||||
|
||||
# We use inspect.getdoc because it does normalization.
|
||||
doc = inspect.getdoc(obj)
|
||||
|
||||
return doc if has_doc and doc is not None else "No documentation available."
|
||||
|
||||
|
||||
def _process_and_print_docstring(
|
||||
docstring: str, *, kernel_name: str, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Convert docstring to MDX and print description, parameters, and returns sections."""
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
|
||||
# Print the description
|
||||
description = _extract_description_before_tags(docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Print parameters and returns sections
|
||||
_print_parameters_section(docstring_mdx, header_level=header_level)
|
||||
_print_returns_section(
|
||||
docstring_mdx, context_name=context_name, header_level=header_level
|
||||
)
|
||||
|
||||
|
||||
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
|
||||
kernel_name = repo_id.split("/")[-1].replace("-", "_")
|
||||
@ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
generate_metadata(kernel_module)
|
||||
generate_kernel_doc(kernel_module, kernel_name)
|
||||
generate_function_doc(kernel_module, kernel_name)
|
||||
generate_layers_doc(kernel_module, kernel_name)
|
||||
|
||||
|
||||
def generate_metadata(module: ModuleType):
|
||||
def generate_metadata(module: ModuleType) -> None:
|
||||
metadata = getattr(module, "__kernel_metadata__", {})
|
||||
if "tags" not in metadata:
|
||||
metadata["tags"] = ["kernel"]
|
||||
@ -39,7 +121,7 @@ def generate_metadata(module: ModuleType):
|
||||
print("---")
|
||||
|
||||
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
|
||||
docstring = module.__doc__.strip() if module.__doc__ is not None else None
|
||||
if docstring:
|
||||
title, rest = docstring.split("\n", 1)
|
||||
@ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
)
|
||||
|
||||
|
||||
def generate_function_doc(kernel_module, kernel_name):
|
||||
functions_info = []
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ == kernel_module.__name__:
|
||||
# Exclude private functions.
|
||||
if not name.startswith("_"):
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func) or "No documentation available."
|
||||
functions_info.append((name, sig, docstring))
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
print("\n## Functions")
|
||||
|
||||
if not functions_info:
|
||||
print(
|
||||
"\nNo public top-level functions.",
|
||||
)
|
||||
return
|
||||
# Track if we found any functions
|
||||
found_functions = False
|
||||
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ != kernel_module.__name__:
|
||||
continue
|
||||
|
||||
# Exclude private functions.
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
found_functions = True
|
||||
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = _get_docstring(func)
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
for name, sig, docstring in functions_info:
|
||||
print(f"\n### Function `{name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
_process_and_print_docstring(
|
||||
docstring, kernel_name=kernel_name, context_name=name, header_level=3
|
||||
)
|
||||
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [
|
||||
pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
|
||||
]
|
||||
if not found_functions:
|
||||
print("\nNo public top-level functions.")
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
# The function description is anything before the first tag.
|
||||
print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
|
||||
else:
|
||||
print(f"\n{docstring_mdx.strip()}")
|
||||
|
||||
# Extract parameters
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
print("\n### Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
# Check if layers module is available
|
||||
layers_module = getattr(kernel_module, "layers", None)
|
||||
if layers_module is None:
|
||||
return
|
||||
|
||||
# Extract return information
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
print("\n## Layers")
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
print("\n### Returns", file=sys.stdout)
|
||||
# Track if we found any classes
|
||||
found_classes = False
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
|
||||
)
|
||||
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
|
||||
# Exclude classes that were imported.
|
||||
if cls.__module__ != layers_module.__name__:
|
||||
continue
|
||||
|
||||
found_classes = True
|
||||
|
||||
try:
|
||||
# Get docstring, but not from superclasses.
|
||||
class_docstring = _get_docstring(cls, use_dict_check=True)
|
||||
except Exception:
|
||||
print(
|
||||
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
print(f"\n### Class `{class_name}`")
|
||||
|
||||
# Always print class description (helper handles conversion and formatting)
|
||||
class_docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
class_docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
description = _extract_description_before_tags(class_docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Document methods
|
||||
print("\n#### Methods")
|
||||
|
||||
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||
# Note: also skip __init__, since extension layers cannot have a constructor.
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Skip methods from superclasses.
|
||||
if method_name not in cls.__dict__:
|
||||
continue
|
||||
|
||||
try:
|
||||
sig = inspect.signature(method)
|
||||
method_docstring = _get_docstring(method)
|
||||
except ValueError:
|
||||
print(
|
||||
f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
|
||||
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
print(f"\n##### Method `{method_name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
_process_and_print_docstring(
|
||||
method_docstring,
|
||||
kernel_name=kernel_name,
|
||||
context_name=method_name,
|
||||
header_level=6,
|
||||
)
|
||||
|
||||
if not found_classes:
|
||||
print("\nNo layers defined.")
|
||||
|
1095
src/kernels/layer.py
1095
src/kernels/layer.py
File diff suppressed because it is too large
Load Diff
@ -4,10 +4,8 @@ from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from kernels._versions import resolve_version_spec_as_ref
|
||||
from kernels.compat import tomllib
|
||||
|
||||
|
||||
@ -31,20 +29,6 @@ class KernelLock:
|
||||
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
tag_for_newest = versions[accepted_versions[-1]]
|
||||
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
|
||||
|
||||
r = HfApi().repo_info(
|
||||
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
||||
|
@ -11,13 +11,16 @@ import sys
|
||||
from importlib.metadata import Distribution
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels._versions import select_revision_or_version
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||
|
||||
|
||||
def _get_cache_dir() -> Optional[str]:
|
||||
"""Returns the kernels cache directory."""
|
||||
@ -34,6 +37,14 @@ def _get_cache_dir() -> Optional[str]:
|
||||
CACHE_DIR: Optional[str] = _get_cache_dir()
|
||||
|
||||
|
||||
def _get_privateuse_backend_name() -> Optional[str]:
|
||||
import torch
|
||||
|
||||
if hasattr(torch._C, "_get_privateuse1_backend_name"):
|
||||
return torch._C._get_privateuse1_backend_name()
|
||||
return None
|
||||
|
||||
|
||||
def build_variant() -> str:
|
||||
import torch
|
||||
|
||||
@ -45,9 +56,17 @@ def build_variant() -> str:
|
||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||
elif torch.backends.mps.is_available():
|
||||
compute_framework = "metal"
|
||||
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
|
||||
version = torch.version.xpu
|
||||
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
|
||||
elif _get_privateuse_backend_name() == "npu":
|
||||
from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found]
|
||||
|
||||
cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
|
||||
compute_framework = f"cann{cann_major}{cann_minor}"
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
|
||||
"Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
|
||||
)
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
@ -55,6 +74,7 @@ def build_variant() -> str:
|
||||
os = platform.system().lower()
|
||||
|
||||
if os == "darwin":
|
||||
cpu = "aarch64" if cpu == "arm64" else cpu
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
|
||||
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
@ -90,15 +110,32 @@ def install_kernel(
|
||||
revision: str,
|
||||
local_files_only: bool = False,
|
||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||
user_agent: Optional[Union[str, dict]] = None,
|
||||
) -> Tuple[str, Path]:
|
||||
"""
|
||||
Download a kernel for the current environment to the cache.
|
||||
|
||||
The output path is validated againt `hash` when set.
|
||||
The output path is validated against the hashes in `variant_locks` when provided.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
revision (`str`):
|
||||
The specific revision (branch, tag, or commit) to download.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only use local files and not download from the Hub.
|
||||
variant_locks (`Dict[str, VariantLock]`, *optional*):
|
||||
Optional dictionary of variant locks for validation.
|
||||
user_agent (`Union[str, dict]`, *optional*):
|
||||
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
|
||||
|
||||
Returns:
|
||||
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
|
||||
"""
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
user_agent = _get_user_agent(user_agent=user_agent)
|
||||
repo_path = Path(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
@ -106,9 +143,27 @@ def install_kernel(
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
return _load_kernel_from_path(repo_path, package_name, variant_locks)
|
||||
except FileNotFoundError:
|
||||
# Redo with more specific error message.
|
||||
raise FileNotFoundError(
|
||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||
)
|
||||
|
||||
|
||||
def _load_kernel_from_path(
|
||||
repo_path: Path,
|
||||
package_name: str,
|
||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||
) -> Tuple[str, Path]:
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
variant_path = repo_path / "build" / variant
|
||||
universal_variant_path = repo_path / "build" / universal_variant
|
||||
|
||||
@ -127,7 +182,7 @@ def install_kernel(
|
||||
|
||||
if not os.path.exists(module_init_path):
|
||||
raise FileNotFoundError(
|
||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||
f"Kernel at path `{repo_path}` does not have build: {variant}"
|
||||
)
|
||||
|
||||
return package_name, variant_path
|
||||
@ -164,16 +219,103 @@ def install_kernel_all_variants(
|
||||
return repo_path / "build"
|
||||
|
||||
|
||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
def get_kernel(
|
||||
repo_id: str,
|
||||
revision: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
user_agent: Optional[Union[str, dict]] = None,
|
||||
) -> ModuleType:
|
||||
"""
|
||||
Load a kernel from the kernel hub.
|
||||
|
||||
This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
|
||||
and then loads the kernel.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
|
||||
version (`str`, *optional*):
|
||||
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
user_agent (`Union[str, dict]`, *optional*):
|
||||
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import torch
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation")
|
||||
x = torch.randn(10, 20, device="cuda")
|
||||
out = torch.empty_like(x)
|
||||
result = activation.silu_and_mul(out, x)
|
||||
```
|
||||
"""
|
||||
revision = select_revision_or_version(repo_id, revision, version)
|
||||
package_name, package_path = install_kernel(
|
||||
repo_id, revision=revision, user_agent=user_agent
|
||||
)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
(Torch version and compute framework).
|
||||
Import a kernel from a local kernel repository path.
|
||||
|
||||
Args:
|
||||
repo_path (`Path`):
|
||||
The local path to the kernel repository.
|
||||
package_name (`str`):
|
||||
The name of the package to import from the repository.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
"""
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
# Presume we were given the top level path of the kernel repository.
|
||||
for base_path in [repo_path, repo_path / "build"]:
|
||||
# Prefer the universal variant if it exists.
|
||||
for v in [universal_variant, variant]:
|
||||
package_path = base_path / v / package_name / "__init__.py"
|
||||
if package_path.exists():
|
||||
return import_from_path(package_name, package_path)
|
||||
|
||||
# If we didn't find the package in the repo we may have a explicit
|
||||
# package path.
|
||||
package_path = repo_path / package_name / "__init__.py"
|
||||
if package_path.exists():
|
||||
return import_from_path(package_name, package_path)
|
||||
|
||||
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
|
||||
|
||||
|
||||
def has_kernel(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment (Torch version and compute framework).
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
|
||||
version (`str`, *optional*):
|
||||
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
|
||||
Returns:
|
||||
`bool`: `True` if a kernel is available for the current environment.
|
||||
"""
|
||||
revision = select_revision_or_version(repo_id, revision, version)
|
||||
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
@ -196,8 +338,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
"""
|
||||
Get a pre-downloaded, locked kernel.
|
||||
|
||||
If `lockfile` is not specified, the lockfile will be loaded from the
|
||||
caller's package metadata.
|
||||
If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
lockfile (`Path`, *optional*):
|
||||
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
"""
|
||||
if lockfile is None:
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
@ -242,7 +392,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
|
||||
|
||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||
"""Get a kernel using a lock file."""
|
||||
"""
|
||||
Get a kernel using a lock file.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only use local files and not download from the Hub.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
"""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
if locked_sha is None:
|
||||
@ -354,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
|
||||
|
||||
def package_name_from_repo_id(repo_id: str) -> str:
|
||||
return repo_id.split("/")[-1].replace("-", "_")
|
||||
|
||||
|
||||
def _get_user_agent(
|
||||
user_agent: Optional[Union[dict, str]] = None,
|
||||
) -> Union[None, dict, str]:
|
||||
import torch
|
||||
|
||||
from . import __version__
|
||||
|
||||
if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
|
||||
return None
|
||||
|
||||
if user_agent is None:
|
||||
user_agent = {
|
||||
"kernels": __version__,
|
||||
"torch": torch.__version__,
|
||||
"build_variant": build_variant(),
|
||||
"file_type": "kernel",
|
||||
}
|
||||
|
||||
return user_agent
|
||||
|
@ -1,10 +1,46 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.utils import _get_privateuse_backend_name
|
||||
|
||||
has_cuda = (
|
||||
hasattr(torch.version, "cuda")
|
||||
and torch.version.cuda is not None
|
||||
and torch.cuda.device_count() > 0
|
||||
)
|
||||
has_rocm = (
|
||||
hasattr(torch.version, "hip")
|
||||
and torch.version.hip is not None
|
||||
and torch.cuda.device_count() > 0
|
||||
)
|
||||
has_xpu = (
|
||||
hasattr(torch.version, "xpu")
|
||||
and torch.version.xpu is not None
|
||||
and torch.xpu.device_count() > 0
|
||||
)
|
||||
has_npu = _get_privateuse_backend_name() == "npu"
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--token",
|
||||
action="store_true",
|
||||
help="run tests that require a token with write permissions",
|
||||
)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
|
||||
pytest.skip("skipping Linux-only test on non-Linux platform")
|
||||
if "cuda_only" in item.keywords and not has_cuda:
|
||||
pytest.skip("skipping CUDA-only test on host without CUDA")
|
||||
if "rocm_only" in item.keywords and not has_rocm:
|
||||
pytest.skip("skipping ROCm-only test on host without ROCm")
|
||||
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
||||
if "xpu_only" in item.keywords and not has_xpu:
|
||||
pytest.skip("skipping XPU-only test on host without XPU")
|
||||
if "npu_only" in item.keywords and not has_npu:
|
||||
pytest.skip("skipping NPU-only test on host without NPU")
|
||||
if "token" in item.keywords and not item.config.getoption("--token"):
|
||||
pytest.skip("need --token option to run this test")
|
||||
|
@ -1,82 +1,70 @@
|
||||
[
|
||||
{
|
||||
"repo_id": "kernels-community/activation",
|
||||
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
|
||||
"sha": "83046852be158d525114f68513cd79fd88911b37",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-aarch64-linux": {
|
||||
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
|
||||
"hash": "sha256-e34965c814c4c092fcb634ebadefe82ea9a05b98343f8ebdefa7305dcc05359e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
|
||||
"hash": "sha256-5f92b35922b37224a416398a39a29b7e5f1aca1df17d5c69f1b9e9cdb7033561",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-aarch64-linux": {
|
||||
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
|
||||
"hash": "sha256-125967cb23bacd2cec443799f184ac08247dfff33f5027e54ee16d3779ca5986",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-x86_64-linux": {
|
||||
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
|
||||
"hash": "sha256-496a84c99d7035a1b6f0ea1c026b751c3a2677956f4c1be546d3cc1505a5fdbb",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-f0775a30ffa290c90aba3a41037e3ca91edb15b4a9367561fafd5f25455e117a",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-081995e6230f306bdf6111186618794f2411cf0ffd9b4800330df60b4ebe1927",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu128-aarch64-linux": {
|
||||
"hash": "sha256-b937fef62a0c1cd71ab98490b651c473577af209b9a3e2a6b452350283d8812c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu128-x86_64-linux": {
|
||||
"hash": "sha256-a3915686cc58641a3361ece63ab77b33e9d30315dea12547e4bda008d8810a01",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu129-aarch64-linux": {
|
||||
"hash": "sha256-a24dca8e998f88be42491921c9df89d88a6112ca630acd2efc2dd34a64b91fcb",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch28-cxx11-cu129-x86_64-linux": {
|
||||
"hash": "sha256-df6c70a70f425db2f68b86561c6f93c5675c1d5e5d058766d88ab17472229907",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-c120011c201072b4cfd70c2ba2d45c2f05337feaf604ddec3c6c4987def33ab3",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-765a7f3279009979be4001a23c5c70e5e6ab9553098d67886731a5275a6d4b32",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu128-aarch64-linux": {
|
||||
"hash": "sha256-266d057a9cd82b872a0e02f09ac5e2660fcffcf9a7b7fa1fa8ff33dc19c0f5c2",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu128-x86_64-linux": {
|
||||
"hash": "sha256-6850e594ba4588f289b5904eb88eda5a41870ee20a3bf1586f3268307caf4b53",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu130-aarch64-linux": {
|
||||
"hash": "sha256-23741b935462b53bdf868f8d1c9c8cff5f02f71ea3b0550df41dc8b030b0b474",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch29-cxx11-cu130-x86_64-linux": {
|
||||
"hash": "sha256-b884ae792dc1eada071f31645add0c2c76d479864f25aebcdd8318b675aaaf29",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
|
12
tests/layer_locking/kernels.lock
Normal file
12
tests/layer_locking/kernels.lock
Normal file
@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"repo_id": "kernels-test/versions",
|
||||
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
|
||||
"variants": {
|
||||
"torch-universal": {
|
||||
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
2
tests/layer_locking/pyproject.toml
Normal file
2
tests/layer_locking/pyproject.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[tool.kernels.dependencies]
|
||||
"kernels-test/versions" = ">=0.1.0,<0.2.0"
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel, has_kernel
|
||||
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -9,6 +9,20 @@ def kernel():
|
||||
return get_kernel("kernels-community/activation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_kernel_path():
|
||||
package_name, path = install_kernel("kernels-community/activation", "main")
|
||||
# Path is the build variant path (build/torch-<...>), so the grandparent
|
||||
# is the kernel repository path.
|
||||
return package_name, path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_kernel(local_kernel_path):
|
||||
package_name, path = local_kernel_path
|
||||
return get_local_kernel(path.parent.parent, package_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metal_kernel():
|
||||
return get_kernel("kernels-test/relu-metal")
|
||||
@ -26,7 +40,7 @@ def device():
|
||||
return "cuda"
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
def test_gelu_fast(kernel, device):
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||
y = torch.empty_like(x)
|
||||
@ -42,6 +56,55 @@ def test_gelu_fast(kernel, device):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_local_kernel(local_kernel, device):
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
local_kernel.gelu_fast(y, x)
|
||||
|
||||
expected = torch.tensor(
|
||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_local_kernel_path_types(local_kernel_path, device):
|
||||
package_name, path = local_kernel_path
|
||||
|
||||
# Top-level repo path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
|
||||
kernel = get_local_kernel(path.parent.parent, package_name)
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
kernel.gelu_fast(y, x)
|
||||
expected = torch.tensor(
|
||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
# Build directory path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
|
||||
kernel = get_local_kernel(path.parent.parent / "build", package_name)
|
||||
y = torch.empty_like(x)
|
||||
kernel.gelu_fast(y, x)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
# Explicit package path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
|
||||
kernel = get_local_kernel(path, package_name)
|
||||
y = torch.empty_like(x)
|
||||
kernel.gelu_fast(y, x)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.darwin_only
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_relu_metal(metal_kernel, dtype):
|
||||
@ -50,7 +113,7 @@ def test_relu_metal(metal_kernel, dtype):
|
||||
assert torch.allclose(y, torch.relu(x))
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_exists",
|
||||
[
|
||||
@ -67,7 +130,26 @@ def test_has_kernel(kernel_exists):
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
def test_version():
|
||||
kernel = get_kernel("kernels-test/versions")
|
||||
assert kernel.version() == "0.2.0"
|
||||
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
|
||||
assert kernel.version() == "0.2.0"
|
||||
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
|
||||
assert kernel.version() == "0.1.1"
|
||||
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
|
||||
assert kernel.version() == "0.1.1"
|
||||
|
||||
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||
get_kernel("kernels-test/versions", version=">0.2.0")
|
||||
|
||||
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||
kernel = get_kernel(
|
||||
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||
|
@ -16,21 +16,21 @@ def device():
|
||||
return "cuda"
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
def test_gelu_small(kernel, device, benchmark):
|
||||
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
benchmark(kernel.gelu_fast, y, x)
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
def test_gelu_medium(kernel, device, benchmark):
|
||||
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
benchmark(kernel.gelu_fast, y, x)
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
def test_gelu_large(kernel, device, benchmark):
|
||||
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
|
49
tests/test_doctest.py
Normal file
49
tests/test_doctest.py
Normal file
@ -0,0 +1,49 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
from mktestdocs import check_docstring, get_codeblock_members
|
||||
|
||||
import kernels
|
||||
|
||||
|
||||
def all_public_functions():
|
||||
function_list = inspect.getmembers(kernels, inspect.isfunction)
|
||||
return [func for _, func in function_list]
|
||||
|
||||
|
||||
def all_public_classes():
|
||||
class_list = inspect.getmembers(kernels, inspect.isclass)
|
||||
return [cls for _, cls in class_list]
|
||||
|
||||
|
||||
def all_public_class_members():
|
||||
members = get_codeblock_members(*all_public_classes())
|
||||
return members
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
all_public_functions(),
|
||||
ids=lambda d: d.__name__,
|
||||
)
|
||||
def test_func_docstring(func):
|
||||
check_docstring(obj=func)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
@pytest.mark.parametrize(
|
||||
"cls",
|
||||
all_public_classes(),
|
||||
ids=lambda d: d.__name__,
|
||||
)
|
||||
def test_class_docstring(cls):
|
||||
check_docstring(obj=cls)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
@pytest.mark.parametrize(
|
||||
"member", all_public_class_members(), ids=lambda d: d.__qualname__
|
||||
)
|
||||
def test_member_docstring(member):
|
||||
check_docstring(member)
|
230
tests/test_interval_tree.py
Normal file
230
tests/test_interval_tree.py
Normal file
@ -0,0 +1,230 @@
|
||||
import random
|
||||
from typing import Generic, List, Optional, Tuple, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
from kernels._interval_tree import IntervalTree, _Node
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SimpleIntervalStore(Generic[T]):
|
||||
"""A simple O(n) implementation that stores intervals in a list."""
|
||||
|
||||
def __init__(self):
|
||||
self.intervals: List[Tuple[int, int, T]] = []
|
||||
|
||||
def insert(self, start: int, end: int, data: T) -> None:
|
||||
"""Insert an interval into the store."""
|
||||
# Replace data if the interval already exists.
|
||||
for i, (existing_start, existing_end, existing_data) in enumerate(
|
||||
self.intervals
|
||||
):
|
||||
if existing_start == start and existing_end == end:
|
||||
self.intervals[i] = (start, end, data)
|
||||
return
|
||||
|
||||
self.intervals.append((start, end, data))
|
||||
|
||||
def find_smallest_interval(self, point: int) -> Optional[T]:
|
||||
"""Find the best match using linear search."""
|
||||
matches = []
|
||||
for start, end, data in self.intervals:
|
||||
if start <= point <= end:
|
||||
matches.append((start, end, data))
|
||||
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
# Return the smallest interval, sort by memory location when
|
||||
# there are multiple matches with the same interval size. This
|
||||
# mirrors the ordering in the intervan tree.
|
||||
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
|
||||
return best_match[2]
|
||||
|
||||
|
||||
def is_balanced(tree: IntervalTree[T]) -> bool:
|
||||
"""Check if the AVL tree is properly balanced."""
|
||||
|
||||
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
|
||||
if node is None:
|
||||
return True, 0
|
||||
|
||||
# Left and right subtrees should be balanced.
|
||||
left_balanced, left_height = check_balance(node.left)
|
||||
if not left_balanced:
|
||||
return False, -1
|
||||
|
||||
right_balanced, right_height = check_balance(node.right)
|
||||
if not right_balanced:
|
||||
return False, -1
|
||||
|
||||
# The difference in height should not exceed 1.
|
||||
if abs(left_height - right_height) > 1:
|
||||
return False, -1
|
||||
|
||||
# Check if the height is correct.
|
||||
expected_height = 1 + max(left_height, right_height)
|
||||
if node.height != expected_height:
|
||||
return False, -1
|
||||
|
||||
return True, expected_height
|
||||
|
||||
balanced, _ = check_balance(tree.root)
|
||||
return balanced
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def populated_tree() -> IntervalTree[str]:
|
||||
"""Provides a pre-populated IntervalTree for testing."""
|
||||
tree = IntervalTree[str]()
|
||||
kernels = [
|
||||
(80, 89, "Kernel_A_General_80_89"),
|
||||
(86, 89, "Kernel_B_Ampere_86_89"),
|
||||
(80, 86, "Kernel_C_Older_Ampere_80_86"),
|
||||
(70, 75, "Kernel_D_Volta_70_75"),
|
||||
(86, 87, "Kernel_E_Specific_86_87"),
|
||||
]
|
||||
for start, end, name in kernels:
|
||||
tree.insert(start, end, name)
|
||||
return tree
|
||||
|
||||
|
||||
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
|
||||
# Check that the smallest inteval is selected when there are
|
||||
# multiple matching intervals.
|
||||
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
|
||||
|
||||
|
||||
def test_find_single_match(populated_tree):
|
||||
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
|
||||
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
|
||||
|
||||
|
||||
def test_no_match_outside_all_ranges(populated_tree):
|
||||
# Check that no interval is found when the value is out of range
|
||||
# (too small/too large).
|
||||
assert populated_tree.find_smallest_interval(65) is None
|
||||
assert populated_tree.find_smallest_interval(95) is None
|
||||
|
||||
|
||||
def test_no_match_in_gap_between_ranges(populated_tree):
|
||||
# Check that no interval is found when the value is between two
|
||||
# intervals.
|
||||
assert populated_tree.find_smallest_interval(78) is None
|
||||
|
||||
|
||||
def test_boundary_conditions_start_and_end(populated_tree):
|
||||
# Test exact upper/lower bounds of intervals.
|
||||
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
|
||||
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
|
||||
|
||||
|
||||
def test_empty_tree():
|
||||
# Searching in an empty tree should return None.
|
||||
empty_tree = IntervalTree[str]()
|
||||
assert empty_tree.find_smallest_interval(100) is None
|
||||
|
||||
|
||||
def test_multiple_equally_specific_matches():
|
||||
# Check that we pick the match in a stable way when there is are
|
||||
# multiple matching intervals with the same size.
|
||||
tree = IntervalTree[str]()
|
||||
str1 = "First_Narrow_Kernel"
|
||||
str2 = "Second_Narrow_Kernel"
|
||||
tree.insert(10, 20, "Wide_Kernel")
|
||||
tree.insert(12, 17, str1)
|
||||
tree.insert(14, 19, str2)
|
||||
|
||||
if id(str1) < id(str2):
|
||||
assert tree.find_smallest_interval(15) == str1
|
||||
else:
|
||||
assert tree.find_smallest_interval(15) == str2
|
||||
|
||||
|
||||
def test_property_based_interval_tree():
|
||||
# Quick-check property-based testing:
|
||||
#
|
||||
# - Verify that the tree is balanced after each insertion.
|
||||
# - Verify the query against a simple list-based implementation.
|
||||
|
||||
random.seed(42) # For reproducible tests
|
||||
|
||||
test_points = list(range(0, 101))
|
||||
|
||||
for _ in range(5):
|
||||
tree = IntervalTree[str]()
|
||||
simple = SimpleIntervalStore[str]()
|
||||
|
||||
intervals = []
|
||||
for i in range(100):
|
||||
start = random.randint(0, 90)
|
||||
end = random.randint(start, 100)
|
||||
data = f"interval_{i}_s{start}_e{end}"
|
||||
intervals.append((start, end, data))
|
||||
|
||||
for i, (start, end, data) in enumerate(intervals):
|
||||
tree.insert(start, end, data)
|
||||
simple.insert(start, end, data)
|
||||
|
||||
# Check that tree is still balanced
|
||||
assert is_balanced(
|
||||
tree
|
||||
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
|
||||
|
||||
for point in test_points:
|
||||
tree_result = tree.find_smallest_interval(point)
|
||||
simple_result = simple.find_smallest_interval(point)
|
||||
|
||||
assert tree_result == simple_result, (
|
||||
f"Mismatch for point {point} after inserting {i+1} intervals. "
|
||||
f"Tree: {tree_result}, Simple: {simple_result}. "
|
||||
f"Last inserted: ({start}, {end})"
|
||||
)
|
||||
|
||||
|
||||
def test_property_based_edge_cases():
|
||||
random.seed(123)
|
||||
|
||||
tree = IntervalTree[str]()
|
||||
simple = SimpleIntervalStore[str]()
|
||||
|
||||
# Single-point intervals.
|
||||
for i in range(10):
|
||||
point = random.randint(0, 100)
|
||||
data = f"single_point_{i}_{point}"
|
||||
tree.insert(point, point, data)
|
||||
simple.insert(point, point, data)
|
||||
|
||||
assert is_balanced(
|
||||
tree
|
||||
), f"Tree unbalanced after inserting single point {point}"
|
||||
|
||||
# Test the exact point and neighbors
|
||||
for test_point in [point - 1, point, point + 1]:
|
||||
if 0 <= test_point <= 100:
|
||||
tree_result = tree.find_smallest_interval(test_point)
|
||||
simple_result = simple.find_smallest_interval(test_point)
|
||||
assert tree_result == simple_result
|
||||
|
||||
|
||||
def test_unique_intervals_override():
|
||||
"""Test that inserting an interval with the same start/end overrides the previous value."""
|
||||
tree = IntervalTree[str]()
|
||||
|
||||
tree.insert(10, 20, "original_value")
|
||||
assert tree.find_smallest_interval(15) == "original_value"
|
||||
|
||||
tree.insert(10, 20, "new_value")
|
||||
assert tree.find_smallest_interval(15) == "new_value"
|
||||
|
||||
tree.insert(10, 25, "different_interval")
|
||||
results = tree.search(15)
|
||||
assert "new_value" in results
|
||||
assert "different_interval" in results
|
||||
assert len(results) == 2
|
||||
|
||||
tree.insert(10, 20, "final_value")
|
||||
assert tree.find_smallest_interval(15) == "final_value"
|
||||
|
||||
assert is_balanced(tree)
|
@ -2,9 +2,17 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
|
||||
from kernels import load_kernel
|
||||
from kernels.cli import download_kernels
|
||||
from kernels.layer import (
|
||||
LockedLayerRepository,
|
||||
Mode,
|
||||
kernelize,
|
||||
use_kernel_forward_from_hub,
|
||||
use_kernel_mapping,
|
||||
)
|
||||
|
||||
|
||||
# Mock download arguments class.
|
||||
@ -19,9 +27,35 @@ def test_download_all_hash_validation():
|
||||
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.cuda_only
|
||||
def test_load_locked():
|
||||
project_dir = Path(__file__).parent / "kernel_locking"
|
||||
# Also validates that hashing works correctly.
|
||||
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
|
||||
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_layer_locked():
|
||||
project_dir = Path(__file__).parent / "layer_locking"
|
||||
|
||||
@use_kernel_forward_from_hub("Version")
|
||||
class Version(nn.Module):
|
||||
def forward(self) -> str:
|
||||
return "0.0.0"
|
||||
|
||||
version = Version()
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
"cuda": LockedLayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
lockfile=project_dir / "kernels.lock",
|
||||
)
|
||||
},
|
||||
}
|
||||
):
|
||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
assert version() == "0.1.1"
|
||||
|
122
tests/test_kernel_upload.py
Normal file
122
tests/test_kernel_upload.py
Normal file
@ -0,0 +1,122 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import delete_repo, model_info, list_repo_refs
|
||||
|
||||
from kernels.cli import upload_kernels
|
||||
|
||||
REPO_ID = "valid_org/kernels-upload-test"
|
||||
|
||||
|
||||
PY_CONTENT = """\
|
||||
#!/usr/bin/env python3
|
||||
|
||||
def main():
|
||||
print("Hello from torch-universal!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class UploadArgs:
|
||||
kernel_dir: None
|
||||
repo_id: None
|
||||
private: False
|
||||
branch: None
|
||||
|
||||
|
||||
def next_filename(path: Path) -> Path:
|
||||
"""
|
||||
Given a path like foo_2050.py, return foo_2051.py.
|
||||
"""
|
||||
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
|
||||
if not m:
|
||||
raise ValueError(
|
||||
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
|
||||
)
|
||||
|
||||
prefix, number, suffix = m.groups()
|
||||
new_number = str(int(number) + 1).zfill(len(number))
|
||||
return path.with_name(f"{prefix}{new_number}{suffix}")
|
||||
|
||||
|
||||
def get_filename_to_change(repo_filenames):
|
||||
for f in repo_filenames:
|
||||
if "foo" in f and f.endswith(".py"):
|
||||
filename_to_change = os.path.basename(f)
|
||||
break
|
||||
assert filename_to_change
|
||||
return filename_to_change
|
||||
|
||||
|
||||
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
|
||||
try:
|
||||
repo_info = model_info(repo_id=repo_id, files_metadata=True)
|
||||
repo_siblings = repo_info.siblings
|
||||
if repo_siblings is not None:
|
||||
return [f.rfilename for f in repo_siblings]
|
||||
else:
|
||||
raise ValueError("No repo siblings found.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error connecting to the Hub: {e}.")
|
||||
|
||||
|
||||
@pytest.mark.token
|
||||
@pytest.mark.is_staging_test
|
||||
@pytest.mark.parametrize("branch", (None, "foo"))
|
||||
def test_kernel_upload_works_as_expected(branch):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/build/torch-universal/upload_test"
|
||||
build_dir = Path(path)
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
script_path = build_dir / "foo.py"
|
||||
script_path.write_text(PY_CONTENT)
|
||||
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, branch))
|
||||
|
||||
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||
assert any(str(script_path.name) for f in repo_filenames)
|
||||
|
||||
if branch is not None:
|
||||
refs = list_repo_refs(repo_id=REPO_ID)
|
||||
assert any(ref_branch.name == branch for ref_branch in refs.branches)
|
||||
|
||||
delete_repo(repo_id=REPO_ID)
|
||||
|
||||
|
||||
@pytest.mark.token
|
||||
@pytest.mark.is_staging_test
|
||||
def test_kernel_upload_deletes_as_expected():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/build/torch-universal/upload_test"
|
||||
build_dir = Path(path)
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
script_path = build_dir / "foo_2025.py"
|
||||
script_path.write_text(PY_CONTENT)
|
||||
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
|
||||
|
||||
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||
filename_to_change = get_filename_to_change(repo_filenames)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/build/torch-universal/upload_test"
|
||||
build_dir = Path(path)
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
changed_filename = next_filename(Path(filename_to_change))
|
||||
script_path = build_dir / changed_filename
|
||||
script_path.write_text(PY_CONTENT)
|
||||
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
|
||||
|
||||
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
|
||||
assert not any(
|
||||
str(filename_to_change) in k for k in repo_filenames
|
||||
), f"{repo_filenames=}"
|
||||
delete_repo(repo_id=REPO_ID)
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user