Compare commits

...

77 Commits

Author SHA1 Message Date
ed048616fe Set version to 0.10.4.dev0 (#169) 2025-10-16 20:21:35 +02:00
b182cd3458 feat: allow get_kernel to log telemetry. (#167)
* feat: allow get_kernel to log telemetry.

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* doc

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-10-16 20:16:41 +02:00
ce77658efc fix: kernels upload to a repo branch (#168)
* fix: kernels upload to a repo branch

* up
2025-10-16 16:01:00 +02:00
b96b154e7f Avoid exception when detecting XPU on Torch <= 2.6 (#165)
torch.version has no xpu field in torch<=2.6

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-10-14 09:01:53 +02:00
b24ef9fa6b Set version to 0.10.3.dev0 (#164) 2025-10-13 17:23:39 +02:00
a7101b2cfd feat: allow kernels to be uploaded to a revision (#161)
* feat: allow kernels to be uploaded to a revision

* revision -> branch
2025-10-13 10:31:11 +02:00
6241afa06e Bump torch version in runner (#162)
* bump torch version

* run kernels lock tests/kernel_locking
2025-10-09 11:04:52 +02:00
34a1932751 Link local kernel and local/locked kernel API docs (#160) 2025-10-02 14:38:47 +02:00
e39eac09c1 up (#159) 2025-09-30 17:42:09 +02:00
b0c431fee4 Add the kernels check subcommand (#158)
* Add the `kernels check` subcommand

This subcommand checks a given kernel. Currently it applies the same ABI
checks as `kernel-abi-check` in `kernel-builder`.

* Print an error when `build` contains files

* Forgot to update has_issues in two places
2025-09-25 19:05:29 +02:00
9a188eadbe up (#157) 2025-09-24 11:39:07 +02:00
457c7c1b8d Only run staging tests in one configuration (#156) 2025-09-23 10:52:47 +02:00
fb8cd99a2c Add support for NPU kernelize/layers (#155)
This change add support for Huawei Ascend NPUs. This is #146 with some formatting/typing fixes.

Co-authored-by: zheliuyu <15750543867@163.com>
2025-09-23 10:46:41 +02:00
dfee307d54 Set version to 0.10.2.dev0 (#154) 2025-09-22 18:54:09 +02:00
93e5765611 [tests] turn the kernels upload tests to be staging tests (#152) 2025-09-22 18:53:53 +02:00
bf488208be faq: why only replace forward methods? (#153) 2025-09-19 17:38:03 +02:00
2a14472e4c Bump huggingface_hub upper bound <2.0 (#151) 2025-09-19 16:56:30 +02:00
055a953552 Document the to-wheel subcommand (#149)
* Document the `to-wheel` subcommand

* Capitalization
2025-09-17 17:02:41 +02:00
692d5ad458 Fix some spelling errors to check docs CI is working (#120) 2025-09-17 13:44:09 +02:00
2139df57f4 rm link (#148) 2025-09-17 12:46:49 +02:00
8f9a77bb6a Describe the get_kernel/LayerRepository (#147)
This was already in the API documentation, but describe this in the
guides as well (since we want people to use versions).
2025-09-16 16:06:40 +02:00
6c00194680 Improve errors for layer validation (#145)
* Improve errors for layer validation

Include the repo and layer name as well as the name of the class
that is being compared to (when applicable).

* Remove upload xfail

* Only enable tests that require a token with `--token`
2025-09-16 14:40:54 +02:00
d6b51eefb7 [feat] add an uploading utility (#138)
* add an uploading utility.

* format

* remove stale files.

* black format

* sorted imports.

* up

* up

* add a test

* propagate.

* remove duplicate imports.

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* up

* up

* up

* command to format all files at once would be nice.

* up

* up

* up

* Use token for upload test

* assign env better.

* docs

* polish

* up

* xfail the test for now.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-09-16 08:56:54 +02:00
d383fdd4b4 Add support for XPU layer repostories (#142)
This change adds support for XPU layer repositories, e.g.:

```
kernel_mapping = {
    "LigerRMSNorm": {
        "xpu": LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
}

Co-authored-by: YangKai0616 <kai.yang@intel.com>
2025-09-11 15:51:02 +02:00
07e5e8481a Set version to 0.10.1.dev0 (#140)
* Set version to 0.10.1.dev0

* Add `__version__` attribute to top-level module

This is needed for doc generation.
2025-09-10 09:08:02 +02:00
88f55d4728 XPU: look up kernel by framework version (#139)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-09-09 13:10:11 +02:00
e801ebf332 Set version to v0.10.0.dev0 (#137) 2025-09-05 10:48:41 +02:00
0ae07f05fc Remove default for mode argument of kernelize (#136) 2025-08-29 17:44:20 +02:00
7611021100 cpu is not (yet) a supported device type (#132)
Fixes #131.
2025-08-25 16:25:58 +02:00
767e7ccf13 fix: add get local tests (#134)
* fix: add tests for get local kernel

* fix: update test and add path example comments

* fix: run black linter
2025-08-21 13:01:48 -04:00
1caa4c1393 feat: improve get local kernel importing (#129)
* feat: improve get local kernel importing

* fix: adjust for linter
2025-08-08 10:22:29 -04:00
da701bf58a Small markup fixes of the local kernel repo example (#127) 2025-08-06 08:02:28 +02:00
703664ed31 Set version to 0.9.0.dev0 (#126) 2025-08-01 16:37:30 +02:00
a8a6564fa7 Add ROCm device discovery (#122)
* Add ROCm device discovery

* Ruff

* Address review comments

* Ruff

* Reorg torch import

* Remove redundant import

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* Address review comments

* Validat device type

* Clean diff

* black

* Sync test with repo changes

* black again

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-08-01 16:09:45 +02:00
c89e0fa9b9 Nix: go back to hf-nix main (#125) 2025-08-01 15:56:02 +02:00
176a601178 Run black check (#124) 2025-08-01 15:42:38 +02:00
cfa0c76ddc Add LocalLayerRepository to load from a local repo (#123) 2025-08-01 14:03:11 +02:00
bcc29915f9 Log when using fallback layer (#121) 2025-07-31 17:18:00 +02:00
6fbff7a9cb Add doc build to CI (#119)
* Add doc build to CI

* Trigger doc build

* No path scoping
2025-07-29 16:01:05 +02:00
f7490bd0a9 Test examples in docstrings using mktestdocs (#118)
Also adjust examples so that they are correct.
2025-07-28 17:31:34 +02:00
8069e3bf0c Update documentation for compatibility with doc-builder (#117) 2025-07-24 16:21:54 +02:00
c540d1e1d6 Fix typo in layers documentation (#116) 2025-07-23 17:13:14 +02:00
967ac581b8 Set version to 0.8.1.dev0 (#115) 2025-07-23 14:42:24 +02:00
81088d44e8 Add support for project-wide locking of layers (#114)
This change adds `LockedLayerRepository` as an alternative to
`LayerRepository`. `LockedLayerRepository` allows for locking all kernel
layers that are used at the project level. Example usage:

```
with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",
                layer_name="SomeLayer",
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```

This requires that the project has a `pyproject.toml` with kernel
version specifications and `kernel.lock` with the locked kernels.
2025-07-23 09:37:05 +02:00
4a04c005e3 Add version support to LayerRepository (#113)
* Add version support to `LayerRepository`

* Remove some docs that do not apply

* Removed unused member variable
2025-07-22 17:02:39 +02:00
6d3c6daf20 triton based kernel could also run in xpu (#112)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-07-22 10:03:34 +02:00
071900fd69 get_kernel: allow Python-style version specifiers (#111)
Use Python-style version specifiers to resolve to tags. E.g., given
the presence of the tags `v0.1.0`, `v0.1.1`, and `v0.2.0`,

get_kernel("my/kernel", version=">=0.1.0,<0.2.0")

would resolve to `v0.1.1`.
2025-07-21 17:18:35 +02:00
2d2c6b14e0 Set version to 0.8.0.dev0 (#110) 2025-07-15 18:45:03 +02:00
03edc573b1 Log kernel layer selection (#109) 2025-07-15 18:38:17 +02:00
c841a6c90d Improve mode handling (#108)
* Set `kernelize` default mode to `Mode.TRAINING | Mode.TORCH_COMPILE`

Also update docs and tests.

* Rename `Mode.DEFAULT` to `Mode.FALLBACK`

* More fine-grained fallbacks

For instance, INFERENCE can fall back to INFERENCE | TORCH_COMPILE,
TRAINING, TRAINING | TORCH_COMPILE, and FALLBACK.

* Update documtenation for mode fallback

* Mention that you can rerun `kernelize` to change the mode
2025-07-15 16:10:43 +02:00
c7a343f195 Support registering layers with a range of CUDA capabilities (#106)
* Add interval tree implementation

* Support registering layers with a range of CUDA capabilities

This change adds support for registering a layers for ranges
of CUDA capabilities. This makes it possible to use newer, faster
kernels for new GPUs, while falling back to another implementation
on older GPUs.

* Add docs for registering kernels with CUDA capabilities

* Fix typing errors
2025-07-14 16:59:21 +02:00
8d838f947d Fix macOS tests by marking some CUDA-only tests (#105) 2025-07-10 12:24:25 +02:00
b87e6fadbe Set version to 0.7.0.dev0 (#104) 2025-07-07 14:56:43 +02:00
fc935d9874 Support registering inference/training-specific layers (#103)
* Support registering inference/training-specific layers

This change makes it possible to register kernels specialized for
inference, training, and/or `torch.compile`. To do so, the mapping
notation is extended to support registering specialized kernels
for a specific 'mode'. For instance, the following mapping,

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
          Mode.DEFAULT: LayerRepository(
              repo_id="kernels-community/activation",
              layer_name="SiluAndMul",
          ),
          Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
              repo_id="kernels-community/activation-training-optimized",
              layer_name="SiluAndMul",
          ),
      }
    }
}
```

uses `kernels-community/activation` by default, but will switch to
using `kernels-community/activation-training-optimized` if a model
is kernelized for training and `torch.compile`.

To make it easier to add more modes in the future and to unify the
`register_kernel_mapping` and `kernelize` signatures, the `training`
and `needs_torch_compile` arguments of `kernelize` are replaced by
a single `mode` argument:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

* Documentation fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Add note on when the fallback is used

* Tighten up some Mode checks

* Fix ruff check

* Attempt to fix mypy errors

* More typing fixes

* Ignore Python < 3.11 type check SNAFU

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-07-04 19:57:14 +02:00
3622e1f8dd Add get_local_kernel function (#102)
This function loads a kernel from a local repository (e.g. the output
of kernel-builder), which can be handy for testing.
2025-07-01 13:58:47 +02:00
a7f3b2e8ed Set version to 0.6.2.dev0 (#100) 2025-06-25 09:48:09 +02:00
a6ab5d83ba Make the flake work on Darwin (#98) 2025-06-24 20:35:21 +02:00
4f9f1abfb9 darwin: fix variant CPU for aarch64 (#97) 2025-06-24 20:35:07 +02:00
f94b7780a6 CI: main triton-layer-norm has docs, branch is gone (#99) 2025-06-24 16:40:36 +02:00
bd28883775 Set version to 0.6.1.dev1 (#96) 2025-06-20 11:43:26 +02:00
498429e322 Add README generation for layers (#94) 2025-06-20 10:16:50 +02:00
09c991af4b Add macOS requirements (#95) 2025-06-16 17:20:47 +02:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
848c6db87b Add support for Metal builds (#89)
* Add support for Metal builds

* Add Metal test, gate tests by OS where necessary
2025-05-30 15:54:28 +02:00
fabb8c52d1 Add generate-readme subcommand for generating a README (#88)
* Add `generate-readme` subcommand for generating a README

This README includes all the top-level functions with docs (if
docstrings are available).

* CI: attempt README generation

* Add PyYAML dependencies

* Typing fixes
2025-05-21 15:43:53 +02:00
d66260dd83 kernels: add the to-wheel subcommand (#84)
* kernels: add the `to-wheel` subcommand

This subcommand accepts a kernel repo and version as arguments:

    kernels to-wheel kernels-community/activation 0.0.3

Wheels will then be generated for every build variant.

* CI: check kernel -> wheel conversion

* No typing for wheel.wheelfile
2025-05-08 17:30:06 +02:00
daac8078fc CI: fix some stubs (#83) 2025-05-07 14:43:57 +02:00
fcb9a80ce6 Set version to 0.5.0 (#82) 2025-05-06 11:45:26 +02:00
c25bb32e6e Add publishing workflow (#81) 2025-05-06 09:29:08 +00:00
2036892762 Allow layers to opt in to torch.compile (#79)
* Allow layers to opt in to `torch.compile`

This change allows a layer to set the `can_torch_compile` class
variable to indicate that the layer is compatible with `torch.compile`.
When enabled, the layer does not fall back to the original
implementation when `torch.compile` is used.

* Comment fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-06 09:36:33 +02:00
0f0de049cf docs: link to autogenerated build variant list (#77) 2025-04-16 17:25:51 +02:00
59597df03e Specify required aarch64 build variants (#76) 2025-04-15 16:09:44 +02:00
5e938ede40 locking docs: fix command name (kernel -> kernels) (#74) 2025-04-14 16:02:00 +02:00
50 changed files with 5588 additions and 408 deletions

View File

@ -0,0 +1,17 @@
name: Build documentation
on:
push:
branches:
- main
- doc-builder*
- v*-release
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@ -0,0 +1,15 @@
name: Build PR Documentation
on: pull_request
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: kernels

View File

@ -8,3 +8,24 @@ jobs:
- uses: actions/checkout@v4
- name: Run ruff
uses: astral-sh/ruff-action@v3
black:
name: Run black check
runs-on: ubuntu-latest
env:
UV_PYTHON_PREFERENCE: only-managed
steps:
- uses: actions/checkout@v4
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v5
with:
python-version: 3.12
- name: Install black
run: uv pip install black
- name: Check formatting
run: |
uv run black --check src
uv run black --check tests

120
.github/workflows/publish.yml vendored Normal file
View File

@ -0,0 +1,120 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
on: push
jobs:
build:
name: Build distribution 📦
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
github-release:
name: >-
Sign the Python 🐍 distribution 📦 with Sigstore
and upload them to GitHub Release
needs:
- publish-to-pypi
runs-on: ubuntu-latest
permissions:
contents: write # IMPORTANT: mandatory for making GitHub Releases
id-token: write # IMPORTANT: mandatory for sigstore
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
./dist/*.whl
- name: Create GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
run: >-
gh release create
"$GITHUB_REF_NAME"
--repo "$GITHUB_REPOSITORY"
--notes ""
- name: Upload artifact signatures to GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
# Upload to GitHub Release using the `gh` CLI.
# `dist/` contains the built packages, and the
# sigstore-produced signatures and certificates.
run: >-
gh release upload
"$GITHUB_REF_NAME" dist/**
--repo "$GITHUB_REPOSITORY"
publish-to-testpypi:
name: Publish Python 🐍 distribution 📦 to TestPyPI
needs:
- build
runs-on: ubuntu-latest
environment:
name: testpypi
url: https://test.pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
skip-existing: true # Only upload when the version is unique.

View File

@ -24,7 +24,7 @@ jobs:
max-parallel: 4
matrix:
python-version: ["3.10", "3.12"]
torch-version: ["2.5.1", "2.6.0"]
torch-version: ["2.7.0", "2.8.0"]
env:
UV_PYTHON_PREFERENCE: only-managed
@ -51,7 +51,32 @@ jobs:
run: uv run mypy src/kernels
- name: Run tests
run: uv run pytest tests
run: |
uv run pytest tests
- name: Run staging tests
env:
HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
run: |
HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'
- name: Check kernel conversion
run: |
uv pip install wheel
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
uv pip install triton_layer_norm-0.0.1*.whl
uv run python -c "import triton_layer_norm"
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm
- name: Check kernel check
run: |
uv pip install kernel-abi-check
kernels check kernels-community/activation
- name: Import check without torch
run: |

View File

@ -0,0 +1,16 @@
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: kernels
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

8
Makefile Normal file
View File

@ -0,0 +1,8 @@
.PHONY: style
export check_dirs := src examples tests
style:
black ${check_dirs}
isort ${check_dirs}
ruff check ${check_dirs} --fix

View File

@ -56,9 +56,12 @@ the Hub.
## 📚 Documentation
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Introduction](docs/source/index.md)
- [Installation](docs/source/installation.md)
- [Basic usage](docs/source/basic-usage.md)
- [Using layers](docs/source/layers.md)
- [Locking kernel/layer versions](docs/source/locking.md)
- [Environment variables](docs/source/env.md)
- [Kernel requirements](docs/source/kernel-requirements.md)
- [Frequently Asked Questions](docs/source/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

View File

@ -1,8 +0,0 @@
# Using kernels in a Docker container
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

View File

@ -1,79 +0,0 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
...
```
This ensures that the mapping is not active anymore outside the
`with`-scope.

30
docs/source/_toctree.yml Normal file
View File

@ -0,0 +1,30 @@
- sections:
- local: index
title: Introduction
- local: installation
title: Installation
title: Getting started
- sections:
- local: basic-usage
title: Basic Usage
- local: layers
title: Using Layers
- local: locking
title: Locking Kernel Versions
- local: env
title: Environment Variables
- local: faq
title: FAQ
title: Usage Guide
- sections:
- local: api/kernels
title: Kernels
- local: api/layers
title: Layers
- local: cli
title: Kernels CLI
title: API Reference
- sections:
- local: kernel-requirements
title: Kernel Requirements
title: Developer Guide

View File

@ -0,0 +1,25 @@
# Kernels API Reference
## Main Functions
### get_kernel
[[autodoc]] kernels.get_kernel
### get_local_kernel
[[autodoc]] kernels.get_local_kernel
### has_kernel
[[autodoc]] kernels.has_kernel
## Loading locked kernels
### load_kernel
[[autodoc]] kernels.load_kernel
### get_locked_kernel
[[autodoc]] kernels.get_locked_kernel

49
docs/source/api/layers.md Normal file
View File

@ -0,0 +1,49 @@
# Layers API Reference
## Making layers kernel-aware
### use_kernel_forward_from_hub
[[autodoc]] kernels.use_kernel_forward_from_hub
### replace_kernel_forward_from_hub
[[autodoc]] kernels.replace_kernel_forward_from_hub
## Registering kernel mappings
### use_kernel_mapping
[[autodoc]] kernels.use_kernel_mapping
### register_kernel_mapping
[[autodoc]] kernels.register_kernel_mapping
## Kernelizing a model
### kernelize
[[autodoc]] kernels.kernelize
## Classes
### Device
[[autodoc]] kernels.Device
### Mode
[[autodoc]] kernels.Mode
### LayerRepository
[[autodoc]] kernels.LayerRepository
### LocalLayerRepository
[[autodoc]] kernels.LocalLayerRepository
### LockedLayerRepository
[[autodoc]] kernels.LockedLayerRepository

View File

@ -0,0 +1,50 @@
# Basic Usage
## Loading Kernels
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
```python
import torch
from kernels import get_kernel
# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")
# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)
print(y)
```
### Using version bounds
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version to download using Python version specifiers:
```python
import torch
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
```
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
is strongly recommended to specify a version bound, since a kernel author
might push incompatible changes to the `main` branch.
## Checking Kernel Availability
You can check if a specific kernel is available for your environment:
```python
from kernels import has_kernel
# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```

58
docs/source/cli.md Normal file
View File

@ -0,0 +1,58 @@
# Kernels CLI Reference
## Main Functions
### kernels check
You can use `kernels check` to test compliance of a kernel on the Hub.
This currently checks that the kernel:
- Supports the currently-required Python ABI version.
- Works on supported operating system versions.
For example:
```bash
$ kernels check kernels-community/flash-attn3
Checking variant: torch28-cxx11-cu128-aarch64-linux
🐍 Python ABI 3.9 compatible
🐧 manylinux_2_28 compatible
[...]
```
### kernels to-wheel
We strongly recommend downloading kernels from the Hub using the `kernels`
package, since this comes with large [benefits](index.md) over using Python
wheels. That said, some projects may require deployment of kernels as
wheels. The `kernels` utility provides a simple solution to this. You can
convert any Hub kernel into a set of wheels with the `to-wheel` command:
```bash
$ kernels to-wheel drbh/img2grey 1.1.2
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
```
### kernels upload
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
your kernel builds to the Hub. To know the supported arguments run: `kernels upload -h`.
**Notes**:
- This will take care of creating a repository on the Hub with the `repo_id` provided.
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
being uploaded, it will attempt to delete the files existing under it.
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.

51
docs/source/faq.md Normal file
View File

@ -0,0 +1,51 @@
# FAQ
## Kernel layers
### Why is the kernelization step needed as a separate step?
In earlier versions of `kernels`, a layer's `forward` method was replaced
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
The new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
### Why does kernelization only replace `forward` methods?
There are some other possible approaches. The first is to completely
replace existing layers by kernel layers. However, since this would
permit free-form layer classes, it would be much harder to validate
that layers are fully compatible with the layers that they are
replacing. For instance, they could have completely different member
variables. Besides that, we would also need to hold on to the original
layers, in case we need to revert to the base layers when the model
is `kernelize`d again with different options.
A second approach would be to make an auxiliary layer that wraps the
original layer and the kernel layer and dispatches to the kernel layer.
This wouldn't have the issues of the first approach, because kernel layers
could be similarly strict as they are now, and we would still have access
to the original layers when `kernelize`-ing the model again. However,
this would change the graph structure of the model and would break use
cases where programs access the model internals (e.g.
`model.layers[0].attention.query_weight`) or rely on the graph structure
in other ways.
The approach of `forward`-replacement is the least invasive, because
it preserves the original model graph. It is also reversible, since
even though the `forward` of a layer _instance_ might be replaced,
the corresponding class still has the original `forward`.
## Misc
### How can I disable kernel reporting in the user-agent?
By default, we collect telemetry when a call to `get_kernel()` is made.
This only includes the `kernels` version, `torch` version, and the build
information for the kernel being requested.
You can disable this by setting `export DISABLE_TELEMETRY=yes`.

20
docs/source/index.md Normal file
View File

@ -0,0 +1,20 @@
# Kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
same Python process.
- **Compatible**: kernels must support all recent versions of Python and
the different PyTorch build configurations (various CUDA versions
and C++ ABIs). Furthermore, older C library versions must be supported.
You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.

View File

@ -0,0 +1,16 @@
# Installation
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
```bash
pip install kernels
```
# Using kernels in a Docker container
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

View File

@ -1,8 +1,11 @@
# Kernel requirements
Kernels on the Hub must fulfill the requirements outlined on this page.
Kernels on the Hub must fulfill the requirements outlined on this page. By
ensuring kernels are compliant, they can be used on a wide range of Linux
systems and Torch builds.
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build conforming kernels.
to build compliant kernels.
## Directory layout
@ -10,62 +13,72 @@ A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
recommended build variants are:
For example `build/torch26-cxx98-cu118-x86_64-linux`.
- `torch25-cxx11-cu118-x86_64-linux`
- `torch25-cxx11-cu121-x86_64-linux`
- `torch25-cxx11-cu124-x86_64-linux`
- `torch25-cxx98-cu118-x86_64-linux`
- `torch25-cxx98-cu121-x86_64-linux`
- `torch25-cxx98-cu124-x86_64-linux`
- `torch26-cxx11-cu118-x86_64-linux`
- `torch26-cxx11-cu124-x86_64-linux`
- `torch26-cxx11-cu126-x86_64-linux`
- `torch26-cxx98-cu118-x86_64-linux`
- `torch26-cxx98-cu124-x86_64-linux`
- `torch26-cxx98-cu126-x86_64-linux`
This list will be updated as new PyTorch versions are released. Kernels
that are in pure Python (e.g. Triton kernels) only need to provide a
single build variant:
- `torch-universal`
Each variant directory should contain a single directory with the same name
Each variant directory must contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have a directories like
`build/<variant>/activation`. This directory
must be a Python package with an `__init__.py` file.
## Build variants
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
must be available for that combination.
## Versioning
Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve the version constraints.
We recommend using [semver](https://semver.org/) to version kernels.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
## Compatibility with torch.compile
The Kernel Hub also encourages to write the kernels in a `torch.compile`
compliant way. This helps to ensure that the kernels are compatible with
`torch.compile` without introducing any graph breaks and triggering
recompilation which can limit the benefits of compilation.
[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and
recompilation triggers during `torch.compile`.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
This means that the extension **must not** use symbols versions higher than:
- GLIBC 2.28
- GLIBCXX 3.4.24
- CXXABI 1.3.11
- GCC 7.0.0
These requirement can be checked with the ABI checker (see below).
These requirements can be checked with the ABI checker (see below).
- No dynamic library dependencies outside:
### macOS
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).
The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
@ -119,9 +132,12 @@ requirements:
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
The only exception to the _no class variables rule_ is addition of a
`has_backward` class variable. This variable is used to indicate whether
the layer has a backward pass implemented (`True` when absent).
There are two exceptions to the _no class variables rule_:
1. The `has_backward` variable can be used to indicate whether the layer has
a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
supports `torch.compile` (`False` when absent).
This is an example of a pure layer:

323
docs/source/layers.md Normal file
View File

@ -0,0 +1,323 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information on the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
```
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `kernelize` function modifies the model in-place, the model itself is
returned as a convenience. The `mode` specifies that the model will be used
in inference. Similarly, you can ask `kernelize` to prepare the model for
training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
A model that is kernelized for training can also be used for inference, but
not the other way around. If you want to change the mode of the kernelized
model, you can just run `kernelize` on the model again with the new mode.
If you want to compile a model with `torch.compile`, this should be indicated
in the mode as well. You can do this by combining `Mode.INFERENCE` or
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
```python
model = MyModel(...)
# Inference
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
# Training
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized, layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
### Inspecting which kernels are used
The kernels that are used are logged at the `INFO` level by `kernelize`.
See the [Python logging](https://docs.python.org/3/library/logging.html)
documentation for information on how to configure logging.
## Registering a hub kernel for a layer
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
This ensures that the mapping is not active anymore outside the
`with`-scope.
### Using version bounds
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version of the kernel to download using Python version
specifiers:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
version=">=0.0.4,<0.1.0",
),
"rocm": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
version=">=0.0.4,<0.1.0",
)
}
}
```
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
least 4. It is strongly recommended to specify a version bound, since a
kernel author might push incompatible changes to the `main` branch.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
The `kernelize` function will attempt to use the following registered
kernels for a given mode:
- `INFERENCE`: `INFERENCE``INFERENCE | TORCH_COMPILE``TRAINING`
`TRAINING | TORCH_COMPILE``FALLBACK`
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE`
`TRAINING | TORCH_COMPILE``FALLBACK`
- `TRAINING`: `TRAINING``TRAINING | TORCH_COMPILE``FALLBACK`
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE``FALLBACK`
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
is also used when a kernel is registered without a mode, as described in the
previous section.
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.FALLBACK: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
since the other kernels do not support `torch.compile`.
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
you need to register the layer for a `Device` with type `cuda` and
set the supported range of CUDA capabilities with using `CUDAProperties`:
```python
kernel_layer_mapping = {
"SiluAndMul": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=89
),
): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Device(
type="cuda",
properties=CUDAProperties(
min_capability=90, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-community/activation-hopper",
layer_name="SiluAndMul",
),
}
}
```
Capabilities behave as follows:
- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
with the smaller capability interval will be used. E.g. given:
- `KernelA` with `min_capability=80` and `max_capability=89`;
- `KernelB` with `min_capability=75` and `max_capability=89`;
- `kernelize` runs on a system with capability 8.6.
Then `KernelA` will be used because the interval 80..89 is smaller
than 75..89. The motivation is that kernels with smaller ranges
tend to be more optimized for a specific set of GPUs. **This behavior
might still change in the future.**
### Registering kernels for specific ROCm capabilities
Registering kernels for the ROCm architecture follows the exact same
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
a kernel to a range of ROCm capabilities.
### Loading from a local repository for testing
The `LocalLayerRepository` class is provided to load a repository from
a local directory. For example:
```python
with use_kernel_mapping(
{
"SiluAndMul": {
"cuda": LocalLayerRepository(
repo_path="/home/daniel/kernels/activation",
package_name="activation",
layer_name="SiluAndMul",
)
}
},
inherit_mapping=False,
):
kernelize(linear, mode=Mode.INFERENCE)
```

View File

@ -1,4 +1,4 @@
# Locking kernel versions
# Locking kernel/layer versions
Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
"kernels-community/activation" = ">=0.0.1"
```
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
@ -26,9 +26,27 @@ activation = get_locked_kernel("kernels-community/activation")
**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.
## Locked kernel layers
Locking is also supported for kernel layers. To use locked layers, register them
with the `LockedLayerRepository` class:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LockedLayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
register_kernel_mapping(kernel_layer_mapping)
```
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernel download .` in your
Locked kernels can be pre-downloaded by running `kernels download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.

View File

@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
print("Kernel successfully executed")
# Check results
expected = torch.tensor([
[0.8408, 1.9551, 2.9961],
[4.0000, 5.0000, 6.0000],
[7.0000, 8.0000, 9.0000]
], device='cuda:0', dtype=torch.float16)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device="cuda:0",
dtype=torch.float16,
)
assert torch.allclose(y, expected)
print("Calculated values are exact")

63
flake.lock generated
View File

@ -51,30 +51,50 @@
"type": "github"
}
},
"nixpkgs": {
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"lastModified": 1754038838,
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1752785354,
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
"type": "github"
},
"original": {
"owner": "nixos",
"repo": "nixpkgs",
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"tgi-nix",
"hf-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
]
}
},
"systems": {
@ -106,27 +126,6 @@
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",

View File

@ -1,7 +1,7 @@
{
inputs = {
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
@ -9,23 +9,28 @@
self,
nixpkgs,
flake-utils,
tgi-nix,
hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
config = hf-nix.lib.config system;
overlays = [
tgi-nix.overlays.default
hf-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-rfc-style;
formatter = pkgs.nixfmt-tree;
packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {};
devShells = with pkgs; rec {
default = mkShell {
nativeBuildInputs = [
# For hf-doc-builder.
nodejs
];
buildInputs =
[
black
@ -34,10 +39,15 @@
ruff
]
++ (with python3.pkgs; [
docutils
huggingface-hub
(callPackage ./nix/kernel-abi-check.nix {})
mktestdocs
pytest
pytest-benchmark
pyyaml
torch
types-pyyaml
venvShellHook
]);

27
nix/kernel-abi-check.nix Normal file
View File

@ -0,0 +1,27 @@
{
buildPythonPackage,
fetchPypi,
rustPlatform,
}:
buildPythonPackage rec {
pname = "kernel-abi-check";
version = "0.6.2";
src = fetchPypi {
inherit version;
pname = "kernel_abi_check";
hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys=";
};
cargoDeps = rustPlatform.fetchCargoVendor {
inherit pname version src sourceRoot;
hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A=";
};
sourceRoot = "kernel_abi_check-${version}/bindings/python";
pyproject = true;
nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ];
}

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.4.4"
version = "0.10.4.dev0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -12,8 +12,9 @@ license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">= 3.9"
dependencies = [
"huggingface_hub>=0.26.0,<1.0",
"huggingface_hub>=0.26.0,<2.0",
"packaging>=20.0",
"pyyaml>=6",
"tomli>=2.0; python_version<'3.11'",
]
@ -23,15 +24,21 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mypy == 1.14.1",
"pytest >=8",
"mktestdocs>=0.2.5",
"mypy>=1.15.0",
"pytest>=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",
"torch >=2.5",
"torch>=2.5",
"types-pyyaml"
]
[project.optional-dependencies]
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
torch = ["torch"]
docs = [
"hf-doc-builder",
]
[project.scripts]
kernels = "kernels.cli:main"
@ -39,6 +46,9 @@ kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"]
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
[tool.isort]
profile = "black"
line_length = 119
[tool.ruff]
exclude = [
@ -65,4 +75,4 @@ line-length = 119
# Ignored rules:
# "E501" -> line length violation
lint.ignore = ["E501"]
lint.select = ["E", "F", "I", "W"]
lint.select = ["E", "F", "W"]

9
pytest.ini Normal file
View File

@ -0,0 +1,9 @@
[pytest]
markers =
cuda_only: marks tests that should only hosts with CUDA GPUs
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
xpu_only: marks tests that should only run on hosts with Intel XPUs
npu_only: marks tests that should only run on Ascend NPUs
token: enable tests that require a write token
is_staging_test: Marks tests that should only run on a staging environment

View File

@ -1,6 +1,15 @@
import importlib.metadata
__version__ = importlib.metadata.version("kernels")
from kernels.layer import (
CUDAProperties,
Device,
LayerRepository,
LocalLayerRepository,
LockedLayerRepository,
Mode,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
@ -8,6 +17,7 @@ from kernels.layer import (
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
@ -15,15 +25,22 @@ from kernels.utils import (
)
__all__ = [
"__version__",
"CUDAProperties",
"Device",
"LayerRepository",
"LocalLayerRepository",
"LockedLayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
"load_kernel",
"install_kernel",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
"kernelize",
"load_kernel",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"LayerRepository",
"Device",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
]

View File

@ -0,0 +1,200 @@
# AVL-balanced interval trees. We could use the intervaltree
# packages, but it seems unmaintained and does not have type
# annotations.
from typing import Generic, List, Optional, Tuple, TypeVar
T = TypeVar("T")
class _Node(Generic[T]):
"""A node in the interval tree."""
def __init__(self, start: int, end: int, data: T):
self.start: int = start
self.end: int = end
self.data: T = data
self.max_end: int = end
self.left: Optional["_Node[T]"] = None
self.right: Optional["_Node[T]"] = None
self.height: int = 1
def __repr__(self) -> str:
return f"Node({self.start}, {self.end})"
class IntervalTree(Generic[T]):
"""A data structure to hold and query (unique) intervals."""
root: Optional[_Node[T]]
def __init__(self):
self.root = None
def insert(self, start: int, end: int, data: T) -> None:
"""
Inserts a new interval into the tree.
Args:
start: The starting point of the interval.
end: The ending point of the interval.
data: The data associated with this interval.
"""
self.root = self._insert(self.root, start, end, data)
def _get_height(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return node.height
def _get_balance(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return self._get_height(node.left) - self._get_height(node.right)
def _update_node_attributes(self, node: _Node[T]) -> None:
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
node.max_end = node.end
if node.left:
node.max_end = max(node.max_end, node.left.max_end)
if node.right:
node.max_end = max(node.max_end, node.right.max_end)
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
"""Performs a right rotation."""
x = y.left
assert x is not None
T2 = x.right
x.right = y
y.left = T2
self._update_node_attributes(y)
self._update_node_attributes(x)
return x
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
"""Performs a left rotation."""
y = x.right
assert y is not None
T2 = y.left
y.left = x
x.right = T2
self._update_node_attributes(x)
self._update_node_attributes(y)
return y
def _insert(
self, node: Optional[_Node[T]], start: int, end: int, data: T
) -> _Node[T]:
"""Recursive helper to insert a new node and balance the tree."""
if not node:
return _Node(start, end, data)
# Replace the data if the interval already exists.
if start == node.start and end == node.end:
node.data = data
return node
if start < node.start:
node.left = self._insert(node.left, start, end, data)
else:
node.right = self._insert(node.right, start, end, data)
self._update_node_attributes(node)
balance = self._get_balance(node)
# Left Left Case
if balance > 1 and node.left and start < node.left.start:
return self._right_rotate(node)
# Right Right Case
if balance < -1 and node.right and start >= node.right.start:
return self._left_rotate(node)
# Left Right Case
if balance > 1 and node.left and start >= node.left.start:
node.left = self._left_rotate(node.left)
return self._right_rotate(node)
# Right Left Case
if balance < -1 and node.right and start < node.right.start:
node.right = self._right_rotate(node.right)
return self._left_rotate(node)
return node
def search(self, point: int) -> List[T]:
"""
Searches for all intervals that contain the given point.
Args:
point: The point to search for.
Returns:
A list of data items from all matching intervals.
"""
results: List[T] = []
self._search(self.root, point, results)
return results
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
"""Recursive helper to find all overlapping intervals."""
if node is None or point > node.max_end:
return
if node.left:
self._search(node.left, point, results)
if node.start <= point <= node.end:
results.append(node.data)
if point >= node.start and node.right:
self._search(node.right, point, results)
def find_smallest_interval(self, point: int) -> Optional[T]:
"""
Finds the item with the most specific (smallest) range for a given point.
Args:
point: The capability to look up.
Returns:
The data of the best-matching item, or None if no match is found.
"""
matches: List[Tuple[int, int, T]] = []
self._find_with_intervals(self.root, point, matches)
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# is just to ensure that we can compare against a trivial
# implementation in tests.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def _find_with_intervals(
self,
node: Optional[_Node[T]],
point: int,
results: List[Tuple[int, int, T]],
) -> None:
"""A modified search that collects interval ranges along with data."""
if node is None or point > node.max_end:
return
if node.left:
self._find_with_intervals(node.left, point, results)
if node.start <= point <= node.end:
results.append((node.start, node.end, node.data))
if point >= node.start and node.right:
self._find_with_intervals(node.right, point, results)

View File

@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
import re
# Re pattern to catch things inside ` ` in :obj:`thing`.
_re_obj = re.compile(r":obj:`([^`]+)`")
# Re pattern to catch things inside ` ` in :math:`thing`.
_re_math = re.compile(r":math:`([^`]+)`")
# Re pattern to catch things between single backquotes.
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
def convert_rst_formatting(text):
"""
Convert rst syntax for formatting to markdown in a given text.
"""
# Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
# (to not be caught by the italic conversion).
text = _re_func_class.sub(r"[``\1``]", text)
# Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
# (to not be caught by the italic conversion).
text = _re_obj.sub(r"``\1``", text)
# Remove :math: markers.
text = _re_math.sub(r"\\\\(\1\\\\)", text)
# Convert content in single backquotes to italic.
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
# Convert content in double backquotes to single backquotes.
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
# Remove remaining ::
text = re.sub(r"::\n", "", text)
# Remove new lines inside blocks in backsticks as they will be kept.
lines = text.split("\n")
in_code = False
text = None
for line in lines:
if in_code:
splits = line.split("`")
in_code = len(splits) > 1 and len(splits) % 2 == 1
if len(splits) == 1:
# Some forgotten lone backstick
text += "\n" + line
else:
text += " " + line.lstrip()
else:
if text is not None:
text += "\n" + line
else:
text = line
splits = line.split("`")
in_code = len(splits) % 2 == 0
return text
# Re pattern to catch description and url in links of the form `description <url>`_.
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :doc:`reference`.
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :ref:`reference`.
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
def convert_rst_links(text, page_info):
"""
Convert the rst links in text to markdown.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
no_prefix = page_info.get("no_prefix", False)
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
# Links of the form :doc:`page`
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
# Links of the form :doc:`text <page>`
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
if "page" in page_info and not no_prefix:
page = str(page_info["page"])
if page.endswith(".html"):
page = page[:-5]
prefix = f"{prefix}{page}"
else:
prefix = ""
# Refs of the form :ref:`page`
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
# Refs of the form :ref:`text <page>`
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
# Links with a prefix
# TODO: when it exists, use the API to deal with prefix links properly.
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
# Other links
text = _re_links.sub(r"[\1](\2)", text)
# Relative links or Transformers links need to remove the .html
if (
"(https://https://huggingface.co/" in text
or re.search(r"\(\.+/", text) is not None
):
text = text.replace(".html", "")
return text
# Re pattern that catches examples blocks of the form `Example::`.
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
# Re pattern that catches rst blocks of the form `.. block_name::`.
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
def is_empty_line(line):
return len(line) == 0 or line.isspace()
def find_indent(line):
"""
Returns the number of spaces that start a line indent.
"""
search = re.search(r"^(\s*)(?:\S|$)", line)
if search is None:
return 0
return len(search.groups()[0])
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
def convert_special_chars(text):
"""
Converts { and < that have special meanings in MDX.
"""
text = text.replace("{", "&amp;lcub;")
# We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
text = re.sub(
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
) # html void elements with no closing counterpart
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
while _re_lt_html.search(text):
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&amp;lt;\2", text)
text = text.replace("LTHTML", "<")
return text
def parse_options(block_content):
"""
Parses the option in some rst block content.
"""
block_lines = block_content.split("\n")
block_indent = find_indent(block_lines[0])
current_option = None
result = {}
for line in block_lines:
if _re_rst_option.search(line) is not None:
current_option, value = _re_rst_option.search(line).groups()
result[current_option] = value.lstrip()
elif find_indent(line) > block_indent:
result[current_option] += " " + line.lstrip()
return result
def apply_min_indent(text, min_indent):
"""
Make sure all lines in a text are have a minimum indentation.
Args:
text (`str`): The text to treat.
min_indent (`int`): The minimal indentation.
Returns:
`str`: The processed text.
"""
lines = text.split("\n")
idx = 0
while idx < len(lines):
if is_empty_line(lines[idx]):
idx += 1
continue
indent = find_indent(lines[idx])
if indent < min_indent:
while idx < len(lines) and (
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
):
if not is_empty_line(lines[idx]):
lines[idx] = " " * (min_indent - indent) + lines[idx]
idx += 1
else:
idx += 1
return "\n".join(lines)
def convert_rst_blocks(text, page_info):
"""
Converts rst special blocks (examples, notes) into MDX.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
lines = text.split("\n")
idx = 0
new_lines = []
while idx < len(lines):
block_type = None
block_info = None
if _re_block.search(lines[idx]) is not None:
block_type = _re_block.search(lines[idx]).groups()[0]
if _re_block_info.search(lines[idx]) is not None:
block_info = _re_block_info.search(lines[idx]).groups()[0]
elif _re_example.search(lines[idx]) is not None:
block_type = "code-block-example"
block_info = "python"
example_name = _re_example.search(lines[idx]).groups()[0]
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
elif lines[idx].strip() == "..":
block_type = "comment"
elif lines[idx].strip() == "::":
block_type = "code-block"
if block_type is not None:
block_indent = find_indent(lines[idx])
# Find the next nonempty line
idx += 1
while idx < len(lines) and is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it (or has already)
example_indent = (
find_indent(lines[idx]) if idx < len(lines) else block_indent
)
if example_indent == block_indent:
block_content = ""
else:
block_lines = []
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= example_indent
):
block_lines.append(lines[idx][example_indent:])
idx += 1
block_content = "\n".join(block_lines)
if block_type in ["code", "code-block"]:
prefix = "```" if block_info is None else f"```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
elif block_type == "code-block-example":
prefix = f"<example>```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
elif block_type == "note":
new_lines.append(
apply_min_indent(
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
)
)
elif block_type == "warning":
new_lines.append(
apply_min_indent(
"<Tip warning={true}>\n\n"
+ f"{block_content.strip()}\n\n</Tip>\n",
block_indent,
)
)
elif block_type == "raw":
new_lines.append(block_content.strip() + "\n")
elif block_type == "math":
new_lines.append(f"$${block_content.strip()}$$\n")
elif block_type == "comment":
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
elif block_type == "autofunction":
if block_info is not None:
new_lines.append(f"[[autodoc]] {block_info}\n")
elif block_type == "autoclass":
if block_info is not None:
block = f"[[autodoc]] {block_info}\n"
options = parse_options(block_content)
if "special-members" in options:
special_members = options["special-members"].split(", ")
for special_member in special_members:
block += f" - {special_member}\n"
if "members" in options:
members = options["members"]
if len(members) == 0:
block += " - all\n"
else:
for member in members.split(", "):
block += f" - {member}\n"
new_lines.append(block)
elif block_type == "image":
options = parse_options(block_content)
target = options.pop("target", None)
if block_info is not None:
options["src"] = block_info
else:
if target is None:
raise ValueError("Image source not defined.")
options["src"] = target
# Adapt path
options["src"] = options["src"].replace(
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
)
html_code = " ".join(
[f'{key}="{value}"' for key, value in options.items()]
)
new_lines.append(f"<img {html_code}/>\n")
else:
new_lines.append(
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
)
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
def split_return_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
return line, ""
return None, line
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
def split_raise_line(line):
"""
Split the raise line with format `SomeError some doc`.
"""
splits_on_colon = line.strip().split(" ")
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
if error_type and error_type[-1] == ":":
error_type = error_type[:-1]
return error_type, doc
def split_arg_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
return line, ""
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
class InvalidRstDocstringError(ValueError):
pass
_re_parameters = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
def parse_rst_docstring(docstring):
"""
Parses a docstring written in rst, in particular the list of arguments and the return type.
"""
lines = docstring.split("\n")
idx = 0
while idx < len(lines):
# Parameters section
if _re_args.search(lines[idx]) is not None:
# Title of the section.
lines[idx] = "<parameters>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
# Returns or Raises block.
param_indent = find_indent(lines[idx])
while (
idx < len(lines)
and find_indent(lines[idx]) == param_indent
and _re_returns.search(lines[idx]) is None
):
intro, doc = split_arg_line(lines[idx])
# Line starting with a > after indent indicate a "section title" in the parameters.
if intro.lstrip().startswith(">"):
lines[idx] = intro.lstrip()
else:
lines[idx] = (
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
):
idx += 1
lines.insert(idx, "</parameters>\n")
idx += 1
# Returns section
elif _re_returns.search(lines[idx]) is not None:
# tag is either `return` or `yield`
tag = _re_returns.match(lines[idx]).group(1).lower()
# Title of the section.
lines[idx] = f"<{tag}s>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it.
return_indent = find_indent(lines[idx])
raised_errors = []
# The line may contain the return type.
if tag in ["return", "yield"]:
return_type, return_description = split_return_line(lines[idx])
lines[idx] = return_description
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= return_indent
):
idx += 1
else:
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
return_type, return_description = split_raise_line(lines[idx])
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
lines[idx] = "- " + raised_error + " -- " + return_description
md_link = _re_md_link.match(raised_error)
if md_link:
raised_error = md_link[1]
raised_error = re.sub(
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
)
if raised_error not in raised_errors:
raised_errors.append(raised_error)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) > return_indent
):
idx += 1
lines.insert(idx, f"</{tag}s>\n")
idx += 1
# Return block finished, we insert the return type if one was specified
if tag in ["return", "yield"] and return_type is not None:
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
elif len(raised_errors) > 0:
# raised errors
lines[
idx - 1
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
else:
idx += 1
result = "\n".join(lines)
# combine multiple <parameters> blocks into one block
if result.count("<parameters>") > 1:
parameters_blocks = _re_parameters.findall(result)
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
parameters_str = "\n".join(parameters_blocks)
result = _re_parameters.sub("", result)
result += f"\n<parameters>{parameters_str}</parameters>\n"
return result
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
def remove_indent(text):
"""
Remove indents in text, except the one linked to lists (or sublists).
"""
lines = text.split("\n")
# List of indents to remember for nested lists
current_indents = []
# List of new indents to remember for nested lists
new_indents = []
is_inside_code = False
code_indent = 0
for idx, line in enumerate(lines):
# Line is an item in a list.
if _re_list.search(line) is not None:
indent = find_indent(line)
# Is it a new list / new level of nestedness?
if len(current_indents) == 0 or indent > current_indents[-1]:
current_indents.append(indent)
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Otherwise it's an existing level of list (current one, or previous one)
else:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] != indent:
level -= 1
current_indents = current_indents[: level + 1]
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Line is an autodoc, we keep the indent for the list just after if there is one.
elif _re_autodoc.search(line) is not None:
indent = find_indent(line)
current_indents = [indent]
new_indents = [4]
lines[idx] = line.strip()
# Deal with empty lines separately
elif is_empty_line(line):
lines[idx] = ""
# Code blocks
elif line.lstrip().startswith("```"):
is_inside_code = not is_inside_code
if is_inside_code:
code_indent = find_indent(line)
lines[idx] = line[code_indent:]
elif is_inside_code:
lines[idx] = line[code_indent:]
else:
indent = find_indent(line)
if len(current_indents) > 0 and indent > current_indents[-1]:
lines[idx] = " " * new_indents[-1] + line[indent:]
elif len(current_indents) > 0:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] > indent:
level -= 1
current_indents = current_indents[: level + 1]
if level >= 0:
if current_indents[level] < indent:
new_indents = new_indents[: level + 1]
else:
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indents.append(new_indent)
else:
new_indents = []
lines[idx] = line[indent:]
else:
lines[idx] = line[indent:]
return "\n".join(lines)
def base_rst_to_mdx(text, page_info, unindent=True):
"""
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
"""
text = convert_rst_links(text, page_info)
text = convert_special_chars(text)
text = convert_rst_blocks(text, page_info)
# Convert * in lists to - to avoid the formatting conversion treat them as bold.
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
text = convert_rst_formatting(text)
return remove_indent(text) if unindent else text
def convert_rst_docstring_to_mdx(docstring, page_info):
"""
Convert a docstring written in rst to mdx.
"""
text = parse_rst_docstring(docstring)
return base_rst_to_mdx(text, page_info)
def process_titles(lines):
"""Converts rst titles to markdown titles."""
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
title_levels = {}
new_lines = []
for line in lines:
if (
len(new_lines) > 0
and len(line) >= len(new_lines[-1])
and len(set(line)) == 1
and line[0] in title_chars
and line != "::"
):
char = line[0]
level = title_levels.get(char, len(title_levels) + 1)
if level not in title_levels:
title_levels[char] = level
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
else:
new_lines.append(line)
return new_lines
# Matches lines with a pattern of a table new line in rst.
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a table new line in rst, with a first column empty.
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a first table line in rst.
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
# Re pattern that catches anchors of the type .. reference:
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
def split_pt_tf_code_blocks(text):
"""
Split PyTorch and TensorFlow specific block codes.
"""
lines = text.split("\n")
new_lines = []
idx = 0
while idx < len(lines):
if lines[idx].startswith("```"):
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
is_pytorch = False
is_tensorflow = False
idx += 1
while idx < len(lines) and lines[idx].strip() != "```":
if "## PYTORCH CODE" in lines[idx]:
is_pytorch = True
is_tensorflow = False
elif "## TENSORFLOW CODE" in lines[idx]:
is_tensorflow = True
is_pytorch = False
elif is_pytorch:
code_lines["pytorch"].append(lines[idx])
elif is_tensorflow:
code_lines["tensorflow"].append(lines[idx])
else:
code_lines["common"].append(lines[idx])
idx += 1
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
block_lines = ["<frameworkcontent>", "<pt>"]
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
block_lines.extend(["```", "</pt>", "<tf>"])
block_lines.extend(
code_lines["common"].copy() + code_lines["tensorflow"]
)
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
new_lines.extend(block_lines)
else:
block_lines = code_lines["common"] + ["```"]
new_lines.extend(block_lines)
idx += 1
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
"""
Convert a document written in rst to mdx.
"""
lines = rst_text.split("\n")
lines = process_titles(lines)
if add_imports:
new_lines = [
'<script lang="ts">',
' import Tip from "$lib/Tip.svelte";',
' import Youtube from "$lib/Youtube.svelte";',
' import Docstring from "$lib/Docstring.svelte";',
' import CodeBlock from "$lib/CodeBlock.svelte";',
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
' import Markdown from "$lib/Markdown.svelte";',
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
' import Added from "$lib/Added.svelte";',
' import Changed from "$lib/Changed.svelte";',
' import Deprecated from "$lib/Deprecated.svelte";',
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
' import PipelineTag from "$lib/PipelineTag.svelte";',
" ",
' export let fw: "pt" | "tf"',
"</script>",
"<svelte:head>",
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
"</svelte:head>",
"",
]
else:
new_lines = []
for line in lines:
if _re_ignore_line_table.search(line) is not None:
continue
elif _re_ignore_line_table1.search(line) is not None:
continue
elif _re_sep_line_table.search(line) is not None:
line = line.replace("=", "-").replace("+", "|")
elif _re_anchor_section.search(line) is not None:
anchor_name = _re_anchor_section.search(line).groups()[0]
line = f"<a id='{anchor_name}'></a>"
new_lines.append(line)
text = "\n".join(new_lines)
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))

52
src/kernels/_versions.py Normal file
View File

@ -0,0 +1,52 @@
from typing import Dict, Optional
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
"""
Get the locks for a kernel with the given version spec.
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
return versions[accepted_versions[-1]]
def select_revision_or_version(
repo_id: str, revision: Optional[str], version: Optional[str]
) -> str:
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")
elif revision is None and version is None:
revision = "main"
elif version is not None:
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
assert revision is not None
return revision

142
src/kernels/check.py Normal file
View File

@ -0,0 +1,142 @@
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
from kernel_abi_check import (
BinaryFormat,
IncompatibleAbi3Symbol,
IncompatibleMacOSVersion,
IncompatibleManylinuxSymbol,
MissingMacOSVersion,
NonAbi3Symbol,
ObjectFile,
)
from kernels.utils import CACHE_DIR
def check_kernel(
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
variants_path = (
Path(
snapshot_download(
repo_id,
allow_patterns=["build/*"],
cache_dir=CACHE_DIR,
revision=revision,
)
)
/ "build"
)
has_issues = False
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
print(
f"⛔ `build/` must only contain directories, found: {variant_path.name}",
file=sys.stderr,
)
has_issues = True
continue
print(f"Checking variant: {variant_path.name}", file=sys.stderr)
indent = 2
for dylib_path in variant_path.rglob("*.so"):
print_with_indent(
indent,
f"Dynamic library {dylib_path.relative_to(variant_path)}:",
)
o = ObjectFile(dylib_path)
has_issues |= check_abi3(o, python_abi, indent + 2)
# TODO: also check operating system
if o.format() == BinaryFormat.ELF:
has_issues |= check_manylinux(o, manylinux, indent + 2)
elif o.format() == BinaryFormat.MACH_O:
has_issues |= check_macos(o, macos, indent + 2)
if has_issues:
sys.exit(1)
def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool:
has_issues = False
violations = object_file.check_python_abi(python_abi)
if violations != []:
has_issues = True
print_with_indent(
indent,
f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:",
)
for violation in violations:
if isinstance(violation, IncompatibleAbi3Symbol):
print_with_indent(
indent + 3,
f"{violation.name}: {violation.version_added}",
)
elif isinstance(violation, NonAbi3Symbol):
print_with_indent(
indent + 3,
f"{violation.name}",
)
else:
print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible")
return has_issues
def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool:
has_issues = False
violations = object_file.check_macos(macos)
if violations != []:
has_issues = True
print_with_indent(
indent,
f"⛔ Found incompatibility with macOS {macos}:",
)
for violation in violations:
if isinstance(violation, MissingMacOSVersion):
print_with_indent(
indent + 3,
"shared library does not contain macOS version",
)
elif isinstance(violation, IncompatibleMacOSVersion):
print_with_indent(
indent + 3,
f"shared library requires macOS {violation.version}",
)
else:
print_with_indent(indent, f"🍏 compatible with macOS {macos}")
return has_issues
def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool:
has_issues = False
violations = object_file.check_manylinux(manylinux)
if violations != []:
has_issues = True
print_with_indent(
indent,
f"⛔ Found symbols that are incompatible with {manylinux}:",
)
for violation in violations:
if isinstance(violation, IncompatibleManylinuxSymbol):
print_with_indent(
indent + 3,
f"{violation.name}_{violation.dep}: {violation.version}",
)
else:
print_with_indent(indent, f"🐧 {manylinux} compatible")
return has_issues
def print_with_indent(indent: int, message: str):
print(f"{' ' * indent}{message}", file=sys.stderr)

View File

@ -4,10 +4,15 @@ import json
import sys
from pathlib import Path
from huggingface_hub import create_repo, upload_folder, create_branch
from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
from .doc import generate_readme_for_kernel
from .wheel import build_variant_to_wheel
def main():
parser = argparse.ArgumentParser(
@ -15,6 +20,31 @@ def main():
)
subparsers = parser.add_subparsers(required=True)
check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
check_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
check_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
check_parser.add_argument(
"--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
)
check_parser.add_argument(
"--python-abi", type=str, help="Python ABI version", default="3.9"
)
check_parser.set_defaults(
func=lambda args: check_kernel(
macos=args.macos,
manylinux=args.manylinux,
python_abi=args.python_abi,
repo_id=args.repo_id,
revision=args.revision,
)
)
download_parser = subparsers.add_parser("download", help="Download locked kernels")
download_parser.add_argument(
"project_dir",
@ -28,6 +58,29 @@ def main():
)
download_parser.set_defaults(func=download_kernels)
upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
upload_parser.add_argument(
"kernel_dir",
type=Path,
help="Directory of the kernel build",
)
upload_parser.add_argument(
"--repo_id",
type=str,
help="Repository ID to use to upload to the Hugging Face Hub",
)
upload_parser.add_argument(
"--branch",
type=None,
help="If set, the upload will be made to a particular branch of the provided `repo_id`.",
)
upload_parser.add_argument(
"--private",
action="store_true",
help="If the repository should be private.",
)
upload_parser.set_defaults(func=upload_kernels)
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument(
"project_dir",
@ -36,6 +89,47 @@ def main():
)
lock_parser.set_defaults(func=lock_kernels)
to_wheel_parser = subparsers.add_parser(
"to-wheel", help="Convert a kernel to a wheel file"
)
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
to_wheel_parser.add_argument(
"--python-version",
type=str,
default="3.9",
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
)
to_wheel_parser.add_argument(
"--manylinux-version",
type=str,
default="2.28",
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
)
to_wheel_parser.set_defaults(func=kernels_to_wheel)
# Add generate-readme subcommand parser
generate_readme_parser = subparsers.add_parser(
"generate-readme",
help="Generate README snippets for a kernel's public functions",
)
generate_readme_parser.add_argument(
"repo_id",
type=str,
help="The kernel repo ID (e.g., kernels-community/activation)",
)
generate_readme_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
generate_readme_parser.set_defaults(
func=lambda args: generate_readme_for_kernel(
repo_id=args.repo_id, revision=args.revision
)
)
args = parser.parse_args()
args.func(args)
@ -77,6 +171,24 @@ def download_kernels(args):
sys.exit(1)
def kernels_to_wheel(args):
variants_path = install_kernel_all_variants(
repo_id=args.repo_id, revision=f"v{args.version}"
)
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
continue
wheel_path = build_variant_to_wheel(
manylinux_version=args.manylinux_version,
python_version=args.python_version,
repo_id=args.repo_id,
version=args.version,
variant_path=variant_path,
wheel_dir=Path("."),
)
print(f"☸️ {wheel_path.name}", file=sys.stderr)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)
@ -91,8 +203,61 @@ def lock_kernels(args):
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
def upload_kernels(args):
# Resolve `kernel_dir` to be uploaded.
kernel_dir = Path(args.kernel_dir).resolve()
build_dir = kernel_dir / "build"
if not kernel_dir.is_dir():
raise ValueError(f"{kernel_dir} is not a directory")
if not build_dir.is_dir():
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
repo_id = create_repo(
repo_id=args.repo_id, private=args.private, exist_ok=True
).repo_id
if args.branch is not None:
create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True)
delete_patterns: set[str] = set()
for build_variant in build_dir.iterdir():
if build_variant.is_dir():
delete_patterns.add(f"{build_variant.name}/**")
upload_folder(
repo_id=repo_id,
folder_path=build_dir,
revision=args.branch,
path_in_repo="build",
delete_patterns=list(delete_patterns),
commit_message="Build uploaded using `kernels`.",
)
print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
def check_kernel(
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
try:
import kernels.check
except ImportError:
print(
"`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check",
file=sys.stderr,
)
sys.exit(1)
kernels.check.check_kernel(
macos=macos,
manylinux=manylinux,
python_abi=python_abi,
repo_id=repo_id,
revision=revision,
)

242
src/kernels/doc.py Normal file
View File

@ -0,0 +1,242 @@
import inspect
import re
import sys
from types import ModuleType
import yaml
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
from .utils import get_kernel
_RE_PARAMETERS = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
_RE_RETURNTYPE = re.compile(
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
else:
if "kernel" not in metadata["tags"]:
metadata["tags"].append("kernel")
print("---")
print(yaml.dump(metadata), end="")
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
print(f"# {title.strip()}")
print(
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
if not found_functions:
print("\nNo public top-level functions.")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
print("\n## Layers")
# Track if we found any classes
found_classes = False
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")

File diff suppressed because it is too large Load Diff

View File

@ -4,10 +4,8 @@ from pathlib import Path
from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from kernels._versions import resolve_version_spec_as_ref
from kernels.compat import tomllib
@ -31,20 +29,6 @@ class KernelLock:
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
versions[Version(tag.name[1:])] = tag
except InvalidVersion:
continue
return versions
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
"""
Get the locks for a kernel with the given version spec.
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
versions = _get_available_versions(repo_id)
requirement = SpecifierSet(version_spec)
accepted_versions = sorted(requirement.filter(versions.keys()))
if len(accepted_versions) == 0:
raise ValueError(
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
)
tag_for_newest = versions[accepted_versions[-1]]
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
r = HfApi().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True

View File

@ -11,13 +11,16 @@ import sys
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels._versions import select_revision_or_version
from kernels.lockfile import KernelLock, VariantLock
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
def _get_cache_dir() -> Optional[str]:
"""Returns the kernels cache directory."""
@ -34,6 +37,14 @@ def _get_cache_dir() -> Optional[str]:
CACHE_DIR: Optional[str] = _get_cache_dir()
def _get_privateuse_backend_name() -> Optional[str]:
import torch
if hasattr(torch._C, "_get_privateuse1_backend_name"):
return torch._C._get_privateuse1_backend_name()
return None
def build_variant() -> str:
import torch
@ -43,14 +54,31 @@ def build_variant() -> str:
elif torch.version.hip is not None:
rocm_version = parse(torch.version.hip.split("-")[0])
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
version = torch.version.xpu
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
elif _get_privateuse_backend_name() == "npu":
from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found]
cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
compute_framework = f"cann{cann_major}{cann_minor}"
else:
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
raise AssertionError(
"Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
)
torch_version = parse(torch.__version__)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
cpu = platform.machine()
os = platform.system().lower()
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
@ -82,15 +110,32 @@ def install_kernel(
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
user_agent: Optional[Union[str, dict]] = None,
) -> Tuple[str, Path]:
"""
Download a kernel for the current environment to the cache.
The output path is validated againt `hash` when set.
The output path is validated against the hashes in `variant_locks` when provided.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`):
The specific revision (branch, tag, or commit) to download.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
variant_locks (`Dict[str, VariantLock]`, *optional*):
Optional dictionary of variant locks for validation.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
Returns:
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
user_agent = _get_user_agent(user_agent=user_agent)
repo_path = Path(
snapshot_download(
repo_id,
@ -98,9 +143,27 @@ def install_kernel(
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
user_agent=user_agent,
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Redo with more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
@ -119,7 +182,7 @@ def install_kernel(
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
return package_name, variant_path
@ -156,16 +219,103 @@ def install_kernel_all_variants(
return repo_path / "build"
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
package_name, package_path = install_kernel(repo_id, revision=revision)
def get_kernel(
repo_id: str,
revision: Optional[str] = None,
version: Optional[str] = None,
user_agent: Optional[Union[str, dict]] = None,
) -> ModuleType:
"""
Load a kernel from the kernel hub.
This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
and then loads the kernel.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`):
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*):
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
Returns:
`ModuleType`: The imported kernel module.
Example:
```python
import torch
from kernels import get_kernel
activation = get_kernel("kernels-community/activation")
x = torch.randn(10, 20, device="cuda")
out = torch.empty_like(x)
result = activation.silu_and_mul(out, x)
```
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(
repo_id, revision=revision, user_agent=user_agent
)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(repo_id: str, revision: str = "main") -> bool:
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
Import a kernel from a local kernel repository path.
Args:
repo_path (`Path`):
The local path to the kernel repository.
package_name (`str`):
The name of the package to import from the repository.
Returns:
`ModuleType`: The imported kernel module.
"""
variant = build_variant()
universal_variant = universal_build_variant()
# Presume we were given the top level path of the kernel repository.
for base_path in [repo_path, repo_path / "build"]:
# Prefer the universal variant if it exists.
for v in [universal_variant, variant]:
package_path = base_path / v / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
# If we didn't find the package in the repo we may have a explicit
# package path.
package_path = repo_path / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
def has_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> bool:
"""
Check whether a kernel build exists for the current environment (Torch version and compute framework).
Args:
repo_id (`str`):
The Hub repository containing the kernel.
revision (`str`, *optional*, defaults to `"main"`):
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`str`, *optional*):
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
Returns:
`bool`: `True` if a kernel is available for the current environment.
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
@ -188,8 +338,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
"""
Get a pre-downloaded, locked kernel.
If `lockfile` is not specified, the lockfile will be loaded from the
caller's package metadata.
If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
lockfile (`Path`, *optional*):
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
Returns:
`ModuleType`: The imported kernel module.
"""
if lockfile is None:
locked_sha = _get_caller_locked_kernel(repo_id)
@ -234,7 +392,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""Get a kernel using a lock file."""
"""
Get a kernel using a lock file.
Args:
repo_id (`str`):
The Hub repository containing the kernel.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only use local files and not download from the Hub.
Returns:
`ModuleType`: The imported kernel module.
"""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:
@ -346,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
def package_name_from_repo_id(repo_id: str) -> str:
return repo_id.split("/")[-1].replace("-", "_")
def _get_user_agent(
user_agent: Optional[Union[dict, str]] = None,
) -> Union[None, dict, str]:
import torch
from . import __version__
if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
return None
if user_agent is None:
user_agent = {
"kernels": __version__,
"torch": torch.__version__,
"build_variant": build_variant(),
"file_type": "kernel",
}
return user_agent

186
src/kernels/wheel.py Normal file
View File

@ -0,0 +1,186 @@
import email.policy
import os
from dataclasses import dataclass
from email.message import Message
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Optional
try:
KERNELS_VERSION = version("kernels")
except PackageNotFoundError:
KERNELS_VERSION = "unknown"
@dataclass
class Metadata:
name: str
version: str
cuda_version: Optional[str]
cxx_abi_version: Optional[str]
torch_version: Optional[str]
os: Optional[str]
platform: Optional[str]
@property
def is_universal(self) -> bool:
return self.platform is None
def build_variant_to_wheel(
repo_id: str,
*,
version: str,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Create a wheel file from the variant path.
"""
name = repo_id.split("/")[-1].replace("_", "-")
metadata = extract_metadata(name, version, variant_path)
return build_wheel(
metadata,
variant_path=variant_path,
wheel_dir=wheel_dir,
manylinux_version=manylinux_version,
python_version=python_version,
)
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
"""
Extract metadata from the variant path.
"""
if variant_path.name == "torch-universal":
return Metadata(
name=name,
version=version,
cuda_version=None,
cxx_abi_version=None,
torch_version=None,
os=None,
platform=None,
)
if not variant_path.name.startswith("torch"):
raise ValueError("Currently only conversion of Torch kernels is supported.")
variant_parts = variant_path.name.removeprefix("torch").split("-")
if len(variant_parts) != 5:
raise ValueError(f"Invalid variant name: {variant_path.name}")
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
cpp_abi_version = variant_parts[1].removeprefix("cxx")
cuda_version = variant_parts[2].removeprefix("cu")
platform = variant_parts[3].replace("-", "_")
os = variant_parts[4]
return Metadata(
name=name,
version=version,
cuda_version=cuda_version,
cxx_abi_version=cpp_abi_version,
torch_version=torch_version,
os=os,
platform=platform,
)
def build_wheel(
metadata: Metadata,
*,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Build the wheel file.
"""
try:
from wheel.wheelfile import WheelFile # type: ignore
except ImportError:
raise ImportError(
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
)
name = metadata.name.replace("-", "_")
python_version_flat = python_version.replace(".", "")
if metadata.is_universal:
python_tag = f"py{python_version_flat}"
abi_tag = "none"
platform_tag = "any"
wheel_filename = (
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
)
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
root_is_purelib = "true"
requires_dist_torch = "torch"
else:
python_tag = f"cp{python_version_flat}"
abi_tag = "abi3"
if (
metadata.torch_version is None
or metadata.cuda_version is None
or metadata.cxx_abi_version is None
or metadata.os is None
or metadata.platform is None
):
raise ValueError(
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
)
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
if metadata.os == "linux":
platform_tag = (
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
)
else:
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
root_is_purelib = "false"
requires_dist_torch = f"torch=={metadata.torch_version}.*"
wheel_path = wheel_dir / wheel_filename
wheel_msg = Message(email.policy.compat32)
wheel_msg.add_header("Wheel-Version", "1.0")
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
metadata_msg = Message(email.policy.compat32)
metadata_msg.add_header("Metadata-Version", "2.1")
metadata_msg.add_header("Name", name)
metadata_msg.add_header("Version", metadata.version)
metadata_msg.add_header("Summary", f"{name} kernel")
metadata_msg.add_header("Requires-Python", ">=3.9")
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
source_pkg_dir = variant_path / name
with WheelFile(wheel_path, "w") as wheel_file:
for root, dirnames, filenames in os.walk(source_pkg_dir):
for filename in filenames:
if filename.endswith(".pyc"):
continue
abs_filepath = os.path.join(root, filename)
entry_name = os.path.relpath(abs_filepath, variant_path)
wheel_file.write(abs_filepath, entry_name)
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
return wheel_path

46
tests/conftest.py Normal file
View File

@ -0,0 +1,46 @@
import sys
import pytest
import torch
from kernels.utils import _get_privateuse_backend_name
has_cuda = (
hasattr(torch.version, "cuda")
and torch.version.cuda is not None
and torch.cuda.device_count() > 0
)
has_rocm = (
hasattr(torch.version, "hip")
and torch.version.hip is not None
and torch.cuda.device_count() > 0
)
has_xpu = (
hasattr(torch.version, "xpu")
and torch.version.xpu is not None
and torch.xpu.device_count() > 0
)
has_npu = _get_privateuse_backend_name() == "npu"
def pytest_addoption(parser):
parser.addoption(
"--token",
action="store_true",
help="run tests that require a token with write permissions",
)
def pytest_runtest_setup(item):
if "cuda_only" in item.keywords and not has_cuda:
pytest.skip("skipping CUDA-only test on host without CUDA")
if "rocm_only" in item.keywords and not has_rocm:
pytest.skip("skipping ROCm-only test on host without ROCm")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")
if "npu_only" in item.keywords and not has_npu:
pytest.skip("skipping NPU-only test on host without NPU")
if "token" in item.keywords and not item.config.getoption("--token"):
pytest.skip("need --token option to run this test")

View File

@ -1,54 +1,70 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"sha": "83046852be158d525114f68513cd79fd88911b37",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-e34965c814c4c092fcb634ebadefe82ea9a05b98343f8ebdefa7305dcc05359e",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-5f92b35922b37224a416398a39a29b7e5f1aca1df17d5c69f1b9e9cdb7033561",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-125967cb23bacd2cec443799f184ac08247dfff33f5027e54ee16d3779ca5986",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-496a84c99d7035a1b6f0ea1c026b751c3a2677956f4c1be546d3cc1505a5fdbb",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
"torch28-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f0775a30ffa290c90aba3a41037e3ca91edb15b4a9367561fafd5f25455e117a",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
"torch28-cxx11-cu126-x86_64-linux": {
"hash": "sha256-081995e6230f306bdf6111186618794f2411cf0ffd9b4800330df60b4ebe1927",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
"torch28-cxx11-cu128-aarch64-linux": {
"hash": "sha256-b937fef62a0c1cd71ab98490b651c473577af209b9a3e2a6b452350283d8812c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
"torch28-cxx11-cu128-x86_64-linux": {
"hash": "sha256-a3915686cc58641a3361ece63ab77b33e9d30315dea12547e4bda008d8810a01",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
"torch28-cxx11-cu129-aarch64-linux": {
"hash": "sha256-a24dca8e998f88be42491921c9df89d88a6112ca630acd2efc2dd34a64b91fcb",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
"torch28-cxx11-cu129-x86_64-linux": {
"hash": "sha256-df6c70a70f425db2f68b86561c6f93c5675c1d5e5d058766d88ab17472229907",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
"torch29-cxx11-cu126-aarch64-linux": {
"hash": "sha256-c120011c201072b4cfd70c2ba2d45c2f05337feaf604ddec3c6c4987def33ab3",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
"torch29-cxx11-cu126-x86_64-linux": {
"hash": "sha256-765a7f3279009979be4001a23c5c70e5e6ab9553098d67886731a5275a6d4b32",
"hash_type": "git_lfs_concat"
},
"torch29-cxx11-cu128-aarch64-linux": {
"hash": "sha256-266d057a9cd82b872a0e02f09ac5e2660fcffcf9a7b7fa1fa8ff33dc19c0f5c2",
"hash_type": "git_lfs_concat"
},
"torch29-cxx11-cu128-x86_64-linux": {
"hash": "sha256-6850e594ba4588f289b5904eb88eda5a41870ee20a3bf1586f3268307caf4b53",
"hash_type": "git_lfs_concat"
},
"torch29-cxx11-cu130-aarch64-linux": {
"hash": "sha256-23741b935462b53bdf868f8d1c9c8cff5f02f71ea3b0550df41dc8b030b0b474",
"hash_type": "git_lfs_concat"
},
"torch29-cxx11-cu130-x86_64-linux": {
"hash": "sha256-b884ae792dc1eada071f31645add0c2c76d479864f25aebcdd8318b675aaaf29",
"hash_type": "git_lfs_concat"
}
}

View File

@ -0,0 +1,12 @@
[
{
"repo_id": "kernels-test/versions",
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
"variants": {
"torch-universal": {
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
"hash_type": "git_lfs_concat"
}
}
}
]

View File

@ -0,0 +1,2 @@
[tool.kernels.dependencies]
"kernels-test/versions" = ">=0.1.0,<0.2.0"

View File

@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel, has_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
@pytest.fixture
@ -9,6 +9,25 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel_path():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return package_name, path
@pytest.fixture
def local_kernel(local_kernel_path):
package_name, path = local_kernel_path
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@pytest.fixture
def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@ -21,6 +40,7 @@ def device():
return "cuda"
@pytest.mark.cuda_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -36,6 +56,64 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.cuda_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
package_name, path = local_kernel_path
# Top-level repo path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
kernel = get_local_kernel(path.parent.parent, package_name)
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
# Build directory path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
kernel = get_local_kernel(path.parent.parent / "build", package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
# Explicit package path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
kernel = get_local_kernel(path, package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
x = torch.arange(-10, 10, dtype=dtype, device="mps")
y = metal_kernel.relu(x)
assert torch.allclose(y, torch.relu(x))
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"kernel_exists",
[
@ -52,6 +130,26 @@ def test_has_kernel(kernel_exists):
assert has_kernel(repo_id, revision=revision) == kernel
def test_version():
kernel = get_kernel("kernels-test/versions")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
assert kernel.version() == "0.2.0"
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
assert kernel.version() == "0.1.1"
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
assert kernel.version() == "0.1.1"
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
get_kernel("kernels-test/versions", version=">0.2.0")
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
kernel = get_kernel(
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
)
@pytest.mark.cuda_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@ -16,18 +16,21 @@ def device():
return "cuda"
@pytest.mark.cuda_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.cuda_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.cuda_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

49
tests/test_doctest.py Normal file
View File

@ -0,0 +1,49 @@
import inspect
import pytest
from mktestdocs import check_docstring, get_codeblock_members
import kernels
def all_public_functions():
function_list = inspect.getmembers(kernels, inspect.isfunction)
return [func for _, func in function_list]
def all_public_classes():
class_list = inspect.getmembers(kernels, inspect.isclass)
return [cls for _, cls in class_list]
def all_public_class_members():
members = get_codeblock_members(*all_public_classes())
return members
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"func",
all_public_functions(),
ids=lambda d: d.__name__,
)
def test_func_docstring(func):
check_docstring(obj=func)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"cls",
all_public_classes(),
ids=lambda d: d.__name__,
)
def test_class_docstring(cls):
check_docstring(obj=cls)
@pytest.mark.cuda_only
@pytest.mark.parametrize(
"member", all_public_class_members(), ids=lambda d: d.__qualname__
)
def test_member_docstring(member):
check_docstring(member)

230
tests/test_interval_tree.py Normal file
View File

@ -0,0 +1,230 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar
import pytest
from kernels._interval_tree import IntervalTree, _Node
T = TypeVar("T")
class SimpleIntervalStore(Generic[T]):
"""A simple O(n) implementation that stores intervals in a list."""
def __init__(self):
self.intervals: List[Tuple[int, int, T]] = []
def insert(self, start: int, end: int, data: T) -> None:
"""Insert an interval into the store."""
# Replace data if the interval already exists.
for i, (existing_start, existing_end, existing_data) in enumerate(
self.intervals
):
if existing_start == start and existing_end == end:
self.intervals[i] = (start, end, data)
return
self.intervals.append((start, end, data))
def find_smallest_interval(self, point: int) -> Optional[T]:
"""Find the best match using linear search."""
matches = []
for start, end, data in self.intervals:
if start <= point <= end:
matches.append((start, end, data))
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# mirrors the ordering in the intervan tree.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def is_balanced(tree: IntervalTree[T]) -> bool:
"""Check if the AVL tree is properly balanced."""
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
if node is None:
return True, 0
# Left and right subtrees should be balanced.
left_balanced, left_height = check_balance(node.left)
if not left_balanced:
return False, -1
right_balanced, right_height = check_balance(node.right)
if not right_balanced:
return False, -1
# The difference in height should not exceed 1.
if abs(left_height - right_height) > 1:
return False, -1
# Check if the height is correct.
expected_height = 1 + max(left_height, right_height)
if node.height != expected_height:
return False, -1
return True, expected_height
balanced, _ = check_balance(tree.root)
return balanced
@pytest.fixture
def populated_tree() -> IntervalTree[str]:
"""Provides a pre-populated IntervalTree for testing."""
tree = IntervalTree[str]()
kernels = [
(80, 89, "Kernel_A_General_80_89"),
(86, 89, "Kernel_B_Ampere_86_89"),
(80, 86, "Kernel_C_Older_Ampere_80_86"),
(70, 75, "Kernel_D_Volta_70_75"),
(86, 87, "Kernel_E_Specific_86_87"),
]
for start, end, name in kernels:
tree.insert(start, end, name)
return tree
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
# Check that the smallest inteval is selected when there are
# multiple matching intervals.
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
def test_find_single_match(populated_tree):
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
def test_no_match_outside_all_ranges(populated_tree):
# Check that no interval is found when the value is out of range
# (too small/too large).
assert populated_tree.find_smallest_interval(65) is None
assert populated_tree.find_smallest_interval(95) is None
def test_no_match_in_gap_between_ranges(populated_tree):
# Check that no interval is found when the value is between two
# intervals.
assert populated_tree.find_smallest_interval(78) is None
def test_boundary_conditions_start_and_end(populated_tree):
# Test exact upper/lower bounds of intervals.
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
def test_empty_tree():
# Searching in an empty tree should return None.
empty_tree = IntervalTree[str]()
assert empty_tree.find_smallest_interval(100) is None
def test_multiple_equally_specific_matches():
# Check that we pick the match in a stable way when there is are
# multiple matching intervals with the same size.
tree = IntervalTree[str]()
str1 = "First_Narrow_Kernel"
str2 = "Second_Narrow_Kernel"
tree.insert(10, 20, "Wide_Kernel")
tree.insert(12, 17, str1)
tree.insert(14, 19, str2)
if id(str1) < id(str2):
assert tree.find_smallest_interval(15) == str1
else:
assert tree.find_smallest_interval(15) == str2
def test_property_based_interval_tree():
# Quick-check property-based testing:
#
# - Verify that the tree is balanced after each insertion.
# - Verify the query against a simple list-based implementation.
random.seed(42) # For reproducible tests
test_points = list(range(0, 101))
for _ in range(5):
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
intervals = []
for i in range(100):
start = random.randint(0, 90)
end = random.randint(start, 100)
data = f"interval_{i}_s{start}_e{end}"
intervals.append((start, end, data))
for i, (start, end, data) in enumerate(intervals):
tree.insert(start, end, data)
simple.insert(start, end, data)
# Check that tree is still balanced
assert is_balanced(
tree
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
for point in test_points:
tree_result = tree.find_smallest_interval(point)
simple_result = simple.find_smallest_interval(point)
assert tree_result == simple_result, (
f"Mismatch for point {point} after inserting {i+1} intervals. "
f"Tree: {tree_result}, Simple: {simple_result}. "
f"Last inserted: ({start}, {end})"
)
def test_property_based_edge_cases():
random.seed(123)
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
# Single-point intervals.
for i in range(10):
point = random.randint(0, 100)
data = f"single_point_{i}_{point}"
tree.insert(point, point, data)
simple.insert(point, point, data)
assert is_balanced(
tree
), f"Tree unbalanced after inserting single point {point}"
# Test the exact point and neighbors
for test_point in [point - 1, point, point + 1]:
if 0 <= test_point <= 100:
tree_result = tree.find_smallest_interval(test_point)
simple_result = simple.find_smallest_interval(test_point)
assert tree_result == simple_result
def test_unique_intervals_override():
"""Test that inserting an interval with the same start/end overrides the previous value."""
tree = IntervalTree[str]()
tree.insert(10, 20, "original_value")
assert tree.find_smallest_interval(15) == "original_value"
tree.insert(10, 20, "new_value")
assert tree.find_smallest_interval(15) == "new_value"
tree.insert(10, 25, "different_interval")
results = tree.search(15)
assert "new_value" in results
assert "different_interval" in results
assert len(results) == 2
tree.insert(10, 20, "final_value")
assert tree.find_smallest_interval(15) == "final_value"
assert is_balanced(tree)

View File

@ -1,8 +1,18 @@
from dataclasses import dataclass
from pathlib import Path
import pytest
import torch.nn as nn
from kernels import load_kernel
from kernels.cli import download_kernels
from kernels.layer import (
LockedLayerRepository,
Mode,
kernelize,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
# Mock download arguments class.
@ -17,8 +27,35 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.cuda_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
@pytest.mark.cuda_only
def test_layer_locked():
project_dir = Path(__file__).parent / "layer_locking"
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
"cuda": LockedLayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
lockfile=project_dir / "kernels.lock",
)
},
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"

122
tests/test_kernel_upload.py Normal file
View File

@ -0,0 +1,122 @@
import logging
import os
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List
import pytest
from huggingface_hub import delete_repo, model_info, list_repo_refs
from kernels.cli import upload_kernels
REPO_ID = "valid_org/kernels-upload-test"
PY_CONTENT = """\
#!/usr/bin/env python3
def main():
print("Hello from torch-universal!")
if __name__ == "__main__":
main()
"""
@dataclass
class UploadArgs:
kernel_dir: None
repo_id: None
private: False
branch: None
def next_filename(path: Path) -> Path:
"""
Given a path like foo_2050.py, return foo_2051.py.
"""
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
if not m:
raise ValueError(
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
)
prefix, number, suffix = m.groups()
new_number = str(int(number) + 1).zfill(len(number))
return path.with_name(f"{prefix}{new_number}{suffix}")
def get_filename_to_change(repo_filenames):
for f in repo_filenames:
if "foo" in f and f.endswith(".py"):
filename_to_change = os.path.basename(f)
break
assert filename_to_change
return filename_to_change
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
try:
repo_info = model_info(repo_id=repo_id, files_metadata=True)
repo_siblings = repo_info.siblings
if repo_siblings is not None:
return [f.rfilename for f in repo_siblings]
else:
raise ValueError("No repo siblings found.")
except Exception as e:
logging.error(f"Error connecting to the Hub: {e}.")
@pytest.mark.token
@pytest.mark.is_staging_test
@pytest.mark.parametrize("branch", (None, "foo"))
def test_kernel_upload_works_as_expected(branch):
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
script_path = build_dir / "foo.py"
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, branch))
repo_filenames = get_filenames_from_a_repo(REPO_ID)
assert any(str(script_path.name) for f in repo_filenames)
if branch is not None:
refs = list_repo_refs(repo_id=REPO_ID)
assert any(ref_branch.name == branch for ref_branch in refs.branches)
delete_repo(repo_id=REPO_ID)
@pytest.mark.token
@pytest.mark.is_staging_test
def test_kernel_upload_deletes_as_expected():
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
script_path = build_dir / "foo_2025.py"
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
repo_filenames = get_filenames_from_a_repo(REPO_ID)
filename_to_change = get_filename_to_change(repo_filenames)
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
changed_filename = next_filename(Path(filename_to_change))
script_path = build_dir / changed_filename
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
repo_filenames = get_filenames_from_a_repo(REPO_ID)
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
assert not any(
str(filename_to_change) in k for k in repo_filenames
), f"{repo_filenames=}"
delete_repo(repo_id=REPO_ID)

File diff suppressed because it is too large Load Diff