Compare commits

...

83 Commits

Author SHA1 Message Date
b7b5f40143 Set version to 0.7.0 2025-07-07 13:09:01 +00:00
b87e6fadbe Set version to 0.7.0.dev0 (#104) 2025-07-07 14:56:43 +02:00
fc935d9874 Support registering inference/training-specific layers (#103)
* Support registering inference/training-specific layers

This change makes it possible to register kernels specialized for
inference, training, and/or `torch.compile`. To do so, the mapping
notation is extended to support registering specialized kernels
for a specific 'mode'. For instance, the following mapping,

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
          Mode.DEFAULT: LayerRepository(
              repo_id="kernels-community/activation",
              layer_name="SiluAndMul",
          ),
          Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
              repo_id="kernels-community/activation-training-optimized",
              layer_name="SiluAndMul",
          ),
      }
    }
}
```

uses `kernels-community/activation` by default, but will switch to
using `kernels-community/activation-training-optimized` if a model
is kernelized for training and `torch.compile`.

To make it easier to add more modes in the future and to unify the
`register_kernel_mapping` and `kernelize` signatures, the `training`
and `needs_torch_compile` arguments of `kernelize` are replaced by
a single `mode` argument:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

* Documentation fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Add note on when the fallback is used

* Tighten up some Mode checks

* Fix ruff check

* Attempt to fix mypy errors

* More typing fixes

* Ignore Python < 3.11 type check SNAFU

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-07-04 19:57:14 +02:00
3622e1f8dd Add get_local_kernel function (#102)
This function loads a kernel from a local repository (e.g. the output
of kernel-builder), which can be handy for testing.
2025-07-01 13:58:47 +02:00
a7f3b2e8ed Set version to 0.6.2.dev0 (#100) 2025-06-25 09:48:09 +02:00
a6ab5d83ba Make the flake work on Darwin (#98) 2025-06-24 20:35:21 +02:00
4f9f1abfb9 darwin: fix variant CPU for aarch64 (#97) 2025-06-24 20:35:07 +02:00
f94b7780a6 CI: main triton-layer-norm has docs, branch is gone (#99) 2025-06-24 16:40:36 +02:00
bd28883775 Set version to 0.6.1.dev1 (#96) 2025-06-20 11:43:26 +02:00
498429e322 Add README generation for layers (#94) 2025-06-20 10:16:50 +02:00
09c991af4b Add macOS requirements (#95) 2025-06-16 17:20:47 +02:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
848c6db87b Add support for Metal builds (#89)
* Add support for Metal builds

* Add Metal test, gate tests by OS where necessary
2025-05-30 15:54:28 +02:00
fabb8c52d1 Add generate-readme subcommand for generating a README (#88)
* Add `generate-readme` subcommand for generating a README

This README includes all the top-level functions with docs (if
docstrings are available).

* CI: attempt README generation

* Add PyYAML dependencies

* Typing fixes
2025-05-21 15:43:53 +02:00
d66260dd83 kernels: add the to-wheel subcommand (#84)
* kernels: add the `to-wheel` subcommand

This subcommand accepts a kernel repo and version as arguments:

    kernels to-wheel kernels-community/activation 0.0.3

Wheels will then be generated for every build variant.

* CI: check kernel -> wheel conversion

* No typing for wheel.wheelfile
2025-05-08 17:30:06 +02:00
daac8078fc CI: fix some stubs (#83) 2025-05-07 14:43:57 +02:00
fcb9a80ce6 Set version to 0.5.0 (#82) 2025-05-06 11:45:26 +02:00
c25bb32e6e Add publishing workflow (#81) 2025-05-06 09:29:08 +00:00
2036892762 Allow layers to opt in to torch.compile (#79)
* Allow layers to opt in to `torch.compile`

This change allows a layer to set the `can_torch_compile` class
variable to indicate that the layer is compatible with `torch.compile`.
When enabled, the layer does not fall back to the original
implementation when `torch.compile` is used.

* Comment fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-06 09:36:33 +02:00
0f0de049cf docs: link to autogenerated build variant list (#77) 2025-04-16 17:25:51 +02:00
59597df03e Specify required aarch64 build variants (#76) 2025-04-15 16:09:44 +02:00
5e938ede40 locking docs: fix command name (kernel -> kernels) (#74) 2025-04-14 16:02:00 +02:00
cf530c283a Set version to 0.4.4 (#73) 2025-04-11 10:23:26 +02:00
437f910336 Add has_kernel function (#69)
* Add `has_kernel` function

This function checks whether a kernel build exists for the current
environment (Torch version and compute framework).

* Test kernel repo that only contains Torch 2.4
2025-04-11 10:12:37 +02:00
6f1a6067c8 feat: add logo and shields (#72) 2025-04-11 10:07:24 +02:00
1d14abcef0 Do not use kernels without backward when training (#68)
* Do not use kernels without backward when training

* Update repo for backwards marker test
2025-04-11 10:05:57 +02:00
6fd2112e22 Set version to 0.4.3 (#71) 2025-04-10 11:57:15 +02:00
70f56ff856 Support DISABLE_KERNEL_MAPPING env var for completely disabling kernel mappings (#70)
* Disable kernel mappings with `DISABLE_KERNEL_MAPPING=1`

* Rename HF_KERNELS_CACHE to KERNELS_CACHE

But still recognize the old variant for compatibility.

* Add documentation for environment variables
2025-04-10 11:37:54 +02:00
7178b0b86c Add Apache License version 2.0 (#66)
Fixes #64
2025-04-04 20:35:29 +02:00
0bbf90a564 Update ABI requirement to manylinux_2_28 (#65) 2025-04-04 19:38:15 +02:00
27d6ffcb80 Add more details about the ABI requirements (#63) 2025-03-31 14:29:30 +02:00
f7bd21438b Set version to 0.4.2 (#62) 2025-03-27 16:57:28 +01:00
6174febb4b Add warning when layer_name not present in _KERNEL_MAPPING (#61)
* add warning

* fix import order
2025-03-27 16:22:58 +01:00
ff55bc201b Add support for fetching ROCm kernels (#59) 2025-03-25 15:11:03 +01:00
3808108d62 doc: add versioning (#58) 2025-03-24 16:48:20 +01:00
c4a16ef462 Actually export use_kernel_mapping at the top-level (#57)
* Actually export `use_kernel_mapping` at the top-level

* Set version to 0.4.1
2025-03-24 12:44:00 +01:00
9762794dd2 Set version to 0.4.0 (#56) 2025-03-21 20:49:01 +01:00
b7d6867c52 use_kernel_mapping: add inherit_mapping option (#55)
`inherit_mapping` is the default and extends the existing mapping
with the given mapping. If `inherit_mapping` is `False`, existing
mappings are not inherited.
2025-03-21 17:28:45 +01:00
fbcd0f2ebd Set version to 0.3.3 (#54) 2025-03-20 16:09:11 +01:00
5af46eca94 Align dependency versions with transformers (#53) 2025-03-20 15:13:45 +01:00
747dd66876 Set version to 0.3.2 (#51) 2025-03-20 11:46:36 +01:00
920590a592 Also export replace_kernel_forward_from_hub (#52) 2025-03-20 11:46:18 +01:00
5208ac4be5 Make torch an extra/dev dependency (#50)
To support use of this package when Torch is optional.
2025-03-20 10:18:19 +01:00
22eaba2826 Set version to 0.3.1 (#49) 2025-03-19 16:35:10 +01:00
9521ba79a0 Fix forward positional argument handling (#48) 2025-03-19 15:54:51 +01:00
9861a5bdef Fix forward positional argument handling (#48) 2025-03-19 15:34:35 +01:00
1c7c87c960 Set version to 0.3.0 (#47) 2025-03-19 12:02:02 +01:00
df45cf2795 Add use_kernel_forward_from_hub decorator (#46)
* Add `use_kernel_forward_from_hub` decorator

This decorator replaces a layer's `forward` with the `forward` of
a layer on the hub.

* Add support for registering a mapping for the duration of a context

This change makes `_KERNEL_MAPPING` a context variable and adds a
`use_kernel_mapping` context manager. This allows users to register
a mapping for the duration of a context.

* Update layer docs

* ruff fix

* Remove an old bit from the docs

* Extend layer mapping example

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Support stringly-typed device type

* Forward-reference `register_kernel_mapping` in monkeypatching section

* Use stringly-typed device name in layer mapping example

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-03-19 11:03:18 +01:00
cf0413efe5 Add Nix flake devshell (#44) 2025-03-11 10:59:12 +01:00
851c13f666 Set version to 0.2.1 (#43) 2025-03-10 15:20:34 +01:00
b6a393612f Pass through locked sha again when loading locked kernels (#42)
This bit got removed accidentally when adding support for universal
kernels. Also add a test to ensure that we'd catch this in the future.
2025-03-10 15:10:47 +01:00
18ecd0ce69 Set version to 0.2.0 (#41) 2025-03-10 10:24:02 +01:00
b4ef1d60e5 Update torch dependency to 2.5 (#40)
Fixes #37.
2025-03-07 20:32:54 +01:00
a40756f306 Configure ruff lints and add to CI (#39) 2025-03-07 20:32:44 +01:00
3671158f47 Rename noarch to universal (#38)
Also update docs to mention this variant.
2025-03-07 15:12:44 +01:00
2ddd473cf7 Add a bunch of cleanups (#36)
* Remove old build backend

* Add types, use `Path` where possible

* Remove unused `get_metadata` function

This function is also problematic, because it assumes that `build.toml`
is always present.
2025-03-07 14:41:08 +01:00
497dffb89e Support kernels that are not pre-compiled (#35)
* Support kernels that are not pre-compiled

This change add support for kernels that are not precompiled (such as
Triton-based kernels). For Torch, these kernels are assumed to be in
`build/torch-noarch`. Kernel download functions will filter on both
the expected (CUDA) build variant and the `noarch` variant. If a binary
variant exists, it is used. Otherwise the `noarch` variant is used
when present.

We don't append a Torch version, since in most cases the output for
every `ver` in `build/torch<ver>-noarch` would be the same. If some
kernel needs features that are only available in a specific Torch
version, the capabilities can be checked by the kernel itself at
runtime.

* CI: system Python does not have headers installed
2025-03-05 14:05:46 +01:00
f036fd09cb Clean up the README, link out to docs (#34) 2025-02-28 16:08:47 +01:00
3e4c83c798 Describe what to do when kernel is not locked (#33)
Especially the second step (reinstalling the project) is easy to
forget.
2025-02-28 10:47:10 +01:00
4116d6019e hf-kernels -> kernels (#32)
* hf-kernels -> kernels

* Set version to 0.1.7

* hf-kernels.lock -> kernels.lock
2025-02-25 16:13:37 +01:00
bd166b348a Revert "hf-kernels -> kernels"
This reverts commit 386c2a104ef4c251912e63bfcdbfaa588dc09605.
2025-02-25 15:06:35 +01:00
386c2a104e hf-kernels -> kernels 2025-02-25 15:05:38 +01:00
c7516b9e50 Use per-build variant hashes in the lockfile (#29)
This makes the lock file a fair bit shorter than per-file hashes. The
hash is computed from filenames + SHA-1 hash for git objects/SHA-256
hash for LFS files.
2025-02-25 14:58:03 +01:00
a8dcd1f6bc Describe requirements for Hub kernels (#31) 2025-02-25 13:15:23 +01:00
af7fdf9202 Add more info, installation details, to the README (#30)
* Improve readme

* Update README.md

Co-authored-by: Daniël de Kok <me@danieldk.eu>

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-02-25 09:47:23 +01:00
9426e7e290 Fix package name & add CUDA shield (#27)
* package_name should not depend on build.toml

* Raise when CUDA not installed
2025-02-24 14:10:54 +01:00
df2c165d61 hf-kernels: error out when no build is available (#25) 2025-02-14 20:20:44 +01:00
d89239464a Update README.md (#24) 2025-02-07 17:52:36 +01:00
3212affd9e Set version to 0.1.6 (#23) 2025-02-05 15:18:50 +01:00
7ff40a859c write_egg_lockfile: bail out if the project does not have pyproject.toml (#22) 2025-02-05 14:55:51 +01:00
cf64113c8b Set version to 0.1.5 (#21) 2025-02-05 11:04:28 +01:00
ba4f88f5aa Make module names unique by path (#20)
So far we have been using the extension name in `build.toml` as
the module name. However, this can cause naming conflicts. For
instance, if a kernel named `moe` is loaded through `hf_kernels`,
it would be registered as the `moe` module. This would cause
subsequent imports of a module named `moe` from the Python path
to be resolved as the module loaded through `hf_kernels`.

Solve this issue by adding some unique material to the module
name (hex-encoded hash of the kernel path).
2025-02-05 10:55:02 +01:00
d61971ad46 Set version to 0.1.4 (#19) 2025-02-04 20:29:27 +01:00
d7f3831992 Support kernel cache directory with HF_KERNELS_CACHE env var (#18) 2025-02-04 20:18:43 +01:00
03875be8a0 Set version to 0.1.3 (#17) 2025-01-31 15:43:55 +01:00
e41ef2358e Only import torch when needed (#16)
* Only import torch when needed

This avoids the (costly) torch load when e.g. the setuptools hooks
are running in downstream packages.

* Lock Python/Torch versions

Also update to Torch 2.5.1/2.6.0.

* Set the minimum Python version to 3.9

* Change step description
2025-01-31 15:33:58 +01:00
aca3ce7dfb Merge pull request #15 from huggingface/all-variants
Download all build variants of a kernel
2025-01-23 17:10:29 +01:00
3bae6fca7d Download all build variants of a kernel
This can be useful in cases where we want to have all CUDA/Torch
versions ahead of time.
2025-01-23 14:40:19 +00:00
cffbafa61f Set version to 0.1.2 (#14) 2025-01-23 10:12:53 +01:00
37 changed files with 4154 additions and 452 deletions

10
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,10 @@
name: Lints
on: [push, pull_request]
jobs:
lint:
name: Run lints
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Run ruff
uses: astral-sh/ruff-action@v3

120
.github/workflows/publish.yml vendored Normal file
View File

@ -0,0 +1,120 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
on: push
jobs:
build:
name: Build distribution 📦
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
github-release:
name: >-
Sign the Python 🐍 distribution 📦 with Sigstore
and upload them to GitHub Release
needs:
- publish-to-pypi
runs-on: ubuntu-latest
permissions:
contents: write # IMPORTANT: mandatory for making GitHub Releases
id-token: write # IMPORTANT: mandatory for sigstore
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
./dist/*.whl
- name: Create GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
run: >-
gh release create
"$GITHUB_REF_NAME"
--repo "$GITHUB_REPOSITORY"
--notes ""
- name: Upload artifact signatures to GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
# Upload to GitHub Release using the `gh` CLI.
# `dist/` contains the built packages, and the
# sigstore-produced signatures and certificates.
run: >-
gh release upload
"$GITHUB_REF_NAME" dist/**
--repo "$GITHUB_REPOSITORY"
publish-to-testpypi:
name: Publish Python 🐍 distribution 📦 to TestPyPI
needs:
- build
runs-on: ubuntu-latest
environment:
name: testpypi
url: https://test.pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
skip-existing: true # Only upload when the version is unique.

View File

@ -1,4 +1,4 @@
name: Test hf-kernels
name: Test kernels
on:
push:
@ -23,8 +23,11 @@ jobs:
strategy:
max-parallel: 4
matrix:
python: ["3.10", "3.12"]
torch: ["2.4.0", "2.5.0"]
python-version: ["3.10", "3.12"]
torch-version: ["2.6.0", "2.7.0"]
env:
UV_PYTHON_PREFERENCE: only-managed
steps:
- name: Checkout code
@ -35,8 +38,34 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Lock Torch version
run: uv lock --upgrade-package "torch==${{ matrix.torch-version }}"
- name: Install the project
run: uv sync --all-extras --dev
- name: Install setuptools for Triton-based test
run: uv pip install setuptools
- name: Check typing
run: uv run mypy src/kernels
- name: Run tests
run: uv run pytest tests
- name: Check kernel conversion
run: |
uv pip install wheel
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
uv pip install triton_layer_norm-0.0.1*.whl
uv run python -c "import triton_layer_norm"
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm
- name: Import check without torch
run: |
uv pip uninstall torch
python -c "import kernels"

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,11 +1,42 @@
# hf-kernels
# kernels
Make sure you have `torch==2.5.1+cu124` installed.
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
<p align="center">
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
</p>
</div>
<hr/>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:
- Portable: a kernel can be loaded from paths outside `PYTHONPATH`.
- Unique: multiple versions of the same kernel can be loaded in the
same Python process.
- Compatible: kernels must support all recent versions of Python and
the different PyTorch build configurations (various CUDA versions
and C++ ABIs). Furthermore, older C library versions must be supported.
## 🚀 Quick Start
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
```bash
pip install kernels
```
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
```python
import torch
from hf_kernels import get_kernel
from kernels import get_kernel
# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")
@ -20,57 +51,15 @@ activation.gelu_fast(y, x)
print(y)
```
## Docker Reference
You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.
build and run the reference [example/basic.py](example/basic.py) in a Docker container with the following commands:
## 📚 Documentation
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```
## Locking kernel versions
Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
sure that `hf-kernels` is a build dependency:
```toml
[build-system]
requires = ["hf-kernels", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.1"
```
Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
```python
from hf_kernels import get_locked_kernel
activation = get_locked_kernel("kernels-community/activation")
```
**Note:** the lock file is included in the package metadata, so it will only be visible
to `hf-kernels` after doing an (editable or regular) installation of your project.
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `hf-kernel download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.
The pre-downloaded kernels are used by the `get_locked_kernel` function.
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
want kernel loading to error when a kernel is not pre-downloaded, you can use
the `load_kernel` function instead:
````python
```python
from hf_kernels import load_kernel
activation = load_kernel("kernels-community/activation")
````
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

View File

@ -31,13 +31,13 @@ WORKDIR /app/kernel-test
# install python depdencies
RUN uv add torch==2.5.0 numpy
# copy hf-kernels lib
COPY src ./hf-kernels/src
COPY pyproject.toml ./hf-kernels/pyproject.toml
COPY README.md ./hf-kernels/README.md
# copy kernels lib
COPY src ./kernels/src
COPY pyproject.toml ./kernels/pyproject.toml
COPY README.md ./kernels/README.md
# install library
RUN uv pip install -e hf-kernels
RUN uv pip install -e kernels
# copy examples
COPY examples ./examples
@ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
# command to run the script
CMD ["uv", "run", "examples/basic.py"]
# CMD ["ls", "hf-kernels"]
# CMD ["ls", "kernels"]

8
docs/docker.md Normal file
View File

@ -0,0 +1,8 @@
# Using kernels in a Docker container
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```

10
docs/env.md Normal file
View File

@ -0,0 +1,10 @@
# Environment variables
## `KERNELS_CACHE`
The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.
## `DISABLE_KERNEL_MAPPING`
Disables kernel mappings for [`layers`](layers.md).

13
docs/faq.md Normal file
View File

@ -0,0 +1,13 @@
# FAQ
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.

210
docs/kernel-requirements.md Normal file
View File

@ -0,0 +1,210 @@
# Kernel requirements
Kernels on the Hub must fulfill the requirements outlined on this page. By
ensuring kernels are compliant, they can be used on a wide range of Linux
systems and Torch builds.
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build compliant kernels.
## Directory layout
A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`.
Each variant directory must contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have a directories like
`build/<variant>/activation`. This directory
must be a Python package with an `__init__.py` file.
## Build variants
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
must be available for that combination.
## Versioning
Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve the version constraints.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
This means that the extension **must not** use symbols versions higher than:
- GLIBC 2.28
- GLIBCXX 3.4.24
- CXXABI 1.3.11
- GCC 7.0.0
These requirement can be checked with the ABI checker (see below).
### macOS
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).
The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
```bash
$ cargo install kernel-abi-check
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
✅ No compatibility issues found
```
## Torch extension
Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
in `torch.ops.<namespace>`. Since we allow loading of multiple versions of
a module in the same Python process, `namespace` must be unique for each
version of a kernel. Failing to do so will create clashes when different
versions of the same kernel are loaded. Two suggested ways of doing this
are:
- Appending a truncated SHA-1 hash of the git commit that the kernel was
built from to the name of the extension.
- Appending random material to the name of the extension.
**Note:** we recommend against appending a version number or git tag.
Version numbers are typically not bumped on each commit, so users
might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
## Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers. See the [layers documentation](layers.md) for more information
on how to use layers.
### Writing layers
To make the extension of layers safe, the layers must fulfill the following
requirements:
- The layers are subclasses of `torch.nn.Module`.
- The layers are pure, meaning that they do not have their own state. This
means that:
- The layer must not define its own constructor.
- The layer must not use class variables.
- No other methods must be defined than `forward`.
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
There are two exceptions to the _no class variables rule_:
1. The `has_backward` variable can be used to indicate whether the layer has
a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
supports `torch.compile` (`False` when absent).
This is an example of a pure layer:
```python
class SiluAndMul(nn.Module):
# This layer does not implement backward.
has_backward: bool = False
def forward(self, x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out
```
For some layers, the `forward` method has to use state from the adopting class.
In these cases, we recommend to use type annotations to indicate what member
variables are expected. For instance:
```python
class LlamaRMSNorm(nn.Module):
weight: torch.Tensor
variance_epsilon: float
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return rms_norm_fn(
hidden_states,
self.weight,
bias=None,
residual=None,
eps=self.variance_epsilon,
dropout_p=0.0,
prenorm=False,
residual_in_fp32=False,
)
```
This layer expects the adopting layer to have `weight` and `variance_epsilon`
member variables and uses them in the `forward` method.
### Exporting layers
To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:
```python
from . import layers
__all__ = [
# ...
"layers"
# ...
]
```
## Python requirements
- Python code must be compatible with Python 3.9 and later.
- All Python code imports from the kernel itself must be relative. So,
for instance if in the example kernel `example`,
`module_b` needs a function from `module_a`, import as:
```python
from .module_a import foo
```
**Never use:**
```python
# DO NOT DO THIS!
from example.module_a import foo
```
The latter would import from the module `example` that is in Python's
global module dict. However, since we allow loading multiple versions
of a module, we uniquely name the module.
- Only modules from the Python standard library, Torch, or the kernel itself
can be imported.

197
docs/layers.md Normal file
View File

@ -0,0 +1,197 @@
# Layers
A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.
See [Kernel requirements](kernel-requirements.md) for more information the
requirements of Hub layers.
## Making a layer extensible with kernels from the hub
### Using a decorator
A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:
```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
```
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
**Note:** the `kernelize` function modifies the model in-place, the model
itself is returned as a convenience.
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, you need to add this to the mode. You can use the
set union (`|`) operator to add `TORCH_COMPILE` to the mode:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernenize` will fall back to the original, non-kernelized, layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
model = kernelize(model)
```
This ensures that the mapping is not active anymore outside the
`with`-scope.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
The kernels will match exactly on the mode. So, for instance in the example above, no kernel
layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
register a kernel to be used when the mode does not match any of the
modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
so. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.

44
docs/locking.md Normal file
View File

@ -0,0 +1,44 @@
# Locking kernel versions
Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
sure that `kernels` is a build dependency:
```toml
[build-system]
requires = ["kernels", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.1"
```
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
```python
from kernels import get_locked_kernel
activation = get_locked_kernel("kernels-community/activation")
```
**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernels download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.
The pre-downloaded kernels are used by the `get_locked_kernel` function.
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
want kernel loading to error when a kernel is not pre-downloaded, you can use
the `load_kernel` function instead:
```python
from kernels import load_kernel
activation = load_kernel("kernels-community/activation")
```

View File

@ -1,6 +1,6 @@
import torch
from hf_kernels import get_kernel
from kernels import get_kernel
print("Starting examples/basic.py demo")

133
flake.lock generated Normal file
View File

@ -0,0 +1,133 @@
{
"nodes": {
"flake-compat": {
"locked": {
"lastModified": 1733328505,
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_2": {
"inputs": {
"systems": "systems_2"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"hf-nix",
"nixpkgs"
]
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_2": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

57
flake.nix Normal file
View File

@ -0,0 +1,57 @@
{
inputs = {
hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
{
self,
nixpkgs,
flake-utils,
hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
config = hf-nix.lib.config system;
overlays = [
hf-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec {
default = mkShell {
buildInputs =
[
black
mypy
pyright
ruff
]
++ (with python3.pkgs; [
docutils
huggingface-hub
pytest
pytest-benchmark
pyyaml
torch
types-pyyaml
venvShellHook
]);
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( python -m pip install --no-build-isolation --no-dependencies -e . )
'';
};
};
}
);
}

View File

@ -1,20 +1,21 @@
[project]
name = "hf-kernels"
version = "0.1.1"
description = "Download cuda kernels"
name = "kernels"
version = "0.7.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },
{ name = "David Holtz", email = "david@huggingface.co" },
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
]
license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">= 3.8"
requires-python = ">= 3.9"
dependencies = [
"huggingface-hub>=0.26.3",
"packaging>=24.2",
"tomli>=2.0.1; python_version<'3.11'",
"torch>=2.4",
"huggingface_hub>=0.26.0,<1.0",
"packaging>=20.0",
"pyyaml>=6",
"tomli>=2.0; python_version<'3.11'",
]
[build-system]
@ -23,18 +24,47 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mypy >= 1.15.0",
"pytest >=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",
"torch >=2.5",
"types-pyyaml"
]
[project.optional-dependencies]
torch = ["torch"]
[project.scripts]
hf-kernels = "hf_kernels.cli:main"
kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"]
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
#[build-system]
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
#build-backend = "hf_kernels.build"
#backend-path = ["src"]
[tool.ruff]
exclude = [
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".venv*",
"__pypackages__",
"_build",
"build",
"dist",
"venv",
]
line-length = 119
# Ignored rules:
# "E501" -> line length violation
lint.ignore = ["E501"]
lint.select = ["E", "F", "I", "W"]

4
pytest.ini Normal file
View File

@ -0,0 +1,4 @@
[pytest]
markers =
darwin_only: marks tests that should only run on macOS
linux_only: marks tests that should only run on Linux

View File

@ -1,3 +0,0 @@
from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]

View File

@ -1,144 +0,0 @@
"""
Python shims for the PEP 517 and PEP 660 build backend.
Major imports in this module are required to be lazy:
```
$ hyperfine \
"/usr/bin/python3 -c \"print('hi')\"" \
"/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\""
Base: Time (mean ± σ): 11.0 ms ± 1.7 ms [User: 8.5 ms, System: 2.5 ms]
With import: Time (mean ± σ): 15.2 ms ± 2.0 ms [User: 12.3 ms, System: 2.9 ms]
Base 1.38 ± 0.28 times faster than with import
```
The same thing goes for the typing module, so we use Python 3.10 type annotations that
don't require importing typing but then quote them so earlier Python version ignore
them while IDEs and type checker can see through the quotes.
"""
from hf_kernels.compat import tomllib
TYPE_CHECKING = False
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence # noqa:I001
from typing import Any # noqa:I001
def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
import sys
if config_settings:
print("Warning: Config settings are not supported", file=sys.stderr)
def call(
args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None
) -> str:
"""Invoke a uv subprocess and return the filename from stdout."""
import shutil
import subprocess
import sys
warn_config_settings(config_settings)
# Unlike `find_uv_bin`, this mechanism must work according to PEP 517
import os
cwd = os.getcwd()
filename = os.path.join(cwd, "pyproject.toml")
with open(filename, "rb") as f:
data = tomllib.load(f)
for kernel, _ in (
data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items()
):
from hf_kernels.utils import install_kernel
install_kernel(kernel, revision="main")
uv_bin = shutil.which("uv")
if uv_bin is None:
raise RuntimeError("uv was not properly installed")
# Forward stderr, capture stdout for the filename
result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE)
if result.returncode != 0:
sys.exit(result.returncode)
# If there was extra stdout, forward it (there should not be extra stdout)
stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True)
sys.stdout.writelines(stdout[:-1])
# Fail explicitly instead of an irrelevant stacktrace
if not stdout:
print("uv subprocess did not return a filename on stdout", file=sys.stderr)
sys.exit(1)
return stdout[-1].strip()
def build_sdist(
sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
"""PEP 517 hook `build_sdist`."""
args = ["build-backend", "build-sdist", sdist_directory]
return call(args, config_settings)
def build_wheel(
wheel_directory: str,
config_settings: "Mapping[Any, Any] | None" = None,
metadata_directory: "str | None" = None,
) -> str:
"""PEP 517 hook `build_wheel`."""
args = ["build-backend", "build-wheel", wheel_directory]
if metadata_directory:
args.extend(["--metadata-directory", metadata_directory])
return call(args, config_settings)
def get_requires_for_build_sdist(
config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
"""PEP 517 hook `get_requires_for_build_sdist`."""
warn_config_settings(config_settings)
return []
def get_requires_for_build_wheel(
config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
"""PEP 517 hook `get_requires_for_build_wheel`."""
warn_config_settings(config_settings)
return []
def prepare_metadata_for_build_wheel(
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
"""PEP 517 hook `prepare_metadata_for_build_wheel`."""
args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory]
return call(args, config_settings)
def build_editable(
wheel_directory: str,
config_settings: "Mapping[Any, Any] | None" = None,
metadata_directory: "str | None" = None,
) -> str:
"""PEP 660 hook `build_editable`."""
args = ["build-backend", "build-editable", wheel_directory]
if metadata_directory:
args.extend(["--metadata-directory", metadata_directory])
return call(args, config_settings)
def get_requires_for_build_editable(
config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
"""PEP 660 hook `get_requires_for_build_editable`."""
warn_config_settings(config_settings)
return []
def prepare_metadata_for_build_editable(
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
"""PEP 660 hook `prepare_metadata_for_build_editable`."""
args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory]
return call(args, config_settings)

View File

@ -1,75 +0,0 @@
import argparse
import dataclasses
import json
import sys
from pathlib import Path
from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock, get_kernel_locks
from hf_kernels.utils import install_kernel
def main():
parser = argparse.ArgumentParser(
prog="hf-kernel", description="Manage compute kernels"
)
subparsers = parser.add_subparsers()
download_parser = subparsers.add_parser("download", help="Download locked kernels")
download_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
download_parser.set_defaults(func=download_kernels)
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
lock_parser.set_defaults(func=lock_kernels)
args = parser.parse_args()
args.func(args)
def download_kernels(args):
lock_path = args.project_dir / "hf-kernels.lock"
if not lock_path.exists():
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
sys.exit(1)
with open(args.project_dir / "hf-kernels.lock", "r") as f:
lock_json = json.load(f)
for kernel_lock_json in lock_json:
kernel_lock = KernelLock.from_json(kernel_lock_json)
print(
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
file=sys.stderr,
)
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
all_locks = []
for kernel, version in kernel_versions.items():
all_locks.append(get_kernel_locks(kernel, version))
with open(args.project_dir / "hf-kernels.lock", "w") as f:
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)

View File

@ -1,132 +0,0 @@
import importlib
import importlib.metadata
import inspect
import json
import os
import platform
import sys
from importlib.metadata import Distribution
from types import ModuleType
from typing import List, Optional
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from packaging.version import parse
from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock
def build_variant():
torch_version = parse(torch.__version__)
cuda_version = parse(torch.version.cuda)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
cpu = platform.machine()
os = platform.system().lower()
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
def import_from_path(module_name: str, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
"torch"
]["name"]
repo_path = snapshot_download(
repo_id,
allow_patterns=f"build/{build_variant()}/*",
revision=revision,
local_files_only=local_files_only,
)
return package_name, f"{repo_path}/build/{build_variant()}"
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
with open(
hf_hub_download(
repo_id, "build.toml", revision=revision, local_files_only=local_files_only
),
"rb",
) as f:
return tomllib.load(f)
def get_kernel(repo_id: str, revision: str = "main"):
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
def load_kernel(repo_id: str):
"""Get a pre-downloaded, locked kernel."""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")
filename = hf_hub_download(
repo_id, "build.toml", local_files_only=True, revision=locked_sha
)
with open(filename, "rb") as f:
metadata = tomllib.load(f)
package_name = metadata["torch"]["name"]
repo_path = os.path.dirname(filename)
package_path = f"{repo_path}/build/{build_variant()}"
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
def get_locked_kernel(repo_id: str, local_files_only: bool = False):
"""Get a kernel using a lock file."""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")
package_name, package_path = install_kernel(
repo_id, locked_sha, local_files_only=local_files_only
)
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
for dist in _get_caller_distributions():
lock_json = dist.read_text("hf-kernels.lock")
if lock_json is not None:
for kernel_lock_json in json.loads(lock_json):
kernel_lock = KernelLock.from_json(kernel_lock_json)
if kernel_lock.repo_id == repo_id:
return kernel_lock.sha
return None
def _get_caller_distributions() -> List[Distribution]:
module = _get_caller_module()
if module is None:
return []
# Look up all possible distributions that this module could be from.
package = module.__name__.split(".")[0]
dist_names = importlib.metadata.packages_distributions().get(package)
if dist_names is None:
return []
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
def _get_caller_module() -> Optional[ModuleType]:
stack = inspect.stack()
# Get first module in the stack that is not the current module.
first_module = inspect.getmodule(stack[0][0])
for frame in stack[1:]:
module = inspect.getmodule(frame[0])
if module is not None and module != first_module:
return module
return first_module

35
src/kernels/__init__.py Normal file
View File

@ -0,0 +1,35 @@
from kernels.layer import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
load_kernel,
)
__all__ = [
"Device",
"LayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
"install_kernel",
"kernelize",
"load_kernel",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
]

View File

@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
import re
# Re pattern to catch things inside ` ` in :obj:`thing`.
_re_obj = re.compile(r":obj:`([^`]+)`")
# Re pattern to catch things inside ` ` in :math:`thing`.
_re_math = re.compile(r":math:`([^`]+)`")
# Re pattern to catch things between single backquotes.
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
def convert_rst_formatting(text):
"""
Convert rst syntax for formatting to markdown in a given text.
"""
# Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
# (to not be caught by the italic conversion).
text = _re_func_class.sub(r"[``\1``]", text)
# Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
# (to not be caught by the italic conversion).
text = _re_obj.sub(r"``\1``", text)
# Remove :math: markers.
text = _re_math.sub(r"\\\\(\1\\\\)", text)
# Convert content in single backquotes to italic.
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
# Convert content in double backquotes to single backquotes.
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
# Remove remaining ::
text = re.sub(r"::\n", "", text)
# Remove new lines inside blocks in backsticks as they will be kept.
lines = text.split("\n")
in_code = False
text = None
for line in lines:
if in_code:
splits = line.split("`")
in_code = len(splits) > 1 and len(splits) % 2 == 1
if len(splits) == 1:
# Some forgotten lone backstick
text += "\n" + line
else:
text += " " + line.lstrip()
else:
if text is not None:
text += "\n" + line
else:
text = line
splits = line.split("`")
in_code = len(splits) % 2 == 0
return text
# Re pattern to catch description and url in links of the form `description <url>`_.
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :doc:`reference`.
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :ref:`reference`.
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
def convert_rst_links(text, page_info):
"""
Convert the rst links in text to markdown.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
no_prefix = page_info.get("no_prefix", False)
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
# Links of the form :doc:`page`
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
# Links of the form :doc:`text <page>`
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
if "page" in page_info and not no_prefix:
page = str(page_info["page"])
if page.endswith(".html"):
page = page[:-5]
prefix = f"{prefix}{page}"
else:
prefix = ""
# Refs of the form :ref:`page`
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
# Refs of the form :ref:`text <page>`
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
# Links with a prefix
# TODO: when it exists, use the API to deal with prefix links properly.
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
# Other links
text = _re_links.sub(r"[\1](\2)", text)
# Relative links or Transformers links need to remove the .html
if (
"(https://https://huggingface.co/" in text
or re.search(r"\(\.+/", text) is not None
):
text = text.replace(".html", "")
return text
# Re pattern that catches examples blocks of the form `Example::`.
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
# Re pattern that catches rst blocks of the form `.. block_name::`.
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
def is_empty_line(line):
return len(line) == 0 or line.isspace()
def find_indent(line):
"""
Returns the number of spaces that start a line indent.
"""
search = re.search(r"^(\s*)(?:\S|$)", line)
if search is None:
return 0
return len(search.groups()[0])
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
def convert_special_chars(text):
"""
Converts { and < that have special meanings in MDX.
"""
text = text.replace("{", "&amp;lcub;")
# We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
text = re.sub(
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
) # html void elements with no closing counterpart
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
while _re_lt_html.search(text):
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&amp;lt;\2", text)
text = text.replace("LTHTML", "<")
return text
def parse_options(block_content):
"""
Parses the option in some rst block content.
"""
block_lines = block_content.split("\n")
block_indent = find_indent(block_lines[0])
current_option = None
result = {}
for line in block_lines:
if _re_rst_option.search(line) is not None:
current_option, value = _re_rst_option.search(line).groups()
result[current_option] = value.lstrip()
elif find_indent(line) > block_indent:
result[current_option] += " " + line.lstrip()
return result
def apply_min_indent(text, min_indent):
"""
Make sure all lines in a text are have a minimum indentation.
Args:
text (`str`): The text to treat.
min_indent (`int`): The minimal indentation.
Returns:
`str`: The processed text.
"""
lines = text.split("\n")
idx = 0
while idx < len(lines):
if is_empty_line(lines[idx]):
idx += 1
continue
indent = find_indent(lines[idx])
if indent < min_indent:
while idx < len(lines) and (
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
):
if not is_empty_line(lines[idx]):
lines[idx] = " " * (min_indent - indent) + lines[idx]
idx += 1
else:
idx += 1
return "\n".join(lines)
def convert_rst_blocks(text, page_info):
"""
Converts rst special blocks (examples, notes) into MDX.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
lines = text.split("\n")
idx = 0
new_lines = []
while idx < len(lines):
block_type = None
block_info = None
if _re_block.search(lines[idx]) is not None:
block_type = _re_block.search(lines[idx]).groups()[0]
if _re_block_info.search(lines[idx]) is not None:
block_info = _re_block_info.search(lines[idx]).groups()[0]
elif _re_example.search(lines[idx]) is not None:
block_type = "code-block-example"
block_info = "python"
example_name = _re_example.search(lines[idx]).groups()[0]
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
elif lines[idx].strip() == "..":
block_type = "comment"
elif lines[idx].strip() == "::":
block_type = "code-block"
if block_type is not None:
block_indent = find_indent(lines[idx])
# Find the next nonempty line
idx += 1
while idx < len(lines) and is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it (or has already)
example_indent = (
find_indent(lines[idx]) if idx < len(lines) else block_indent
)
if example_indent == block_indent:
block_content = ""
else:
block_lines = []
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= example_indent
):
block_lines.append(lines[idx][example_indent:])
idx += 1
block_content = "\n".join(block_lines)
if block_type in ["code", "code-block"]:
prefix = "```" if block_info is None else f"```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
elif block_type == "code-block-example":
prefix = f"<example>```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
elif block_type == "note":
new_lines.append(
apply_min_indent(
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
)
)
elif block_type == "warning":
new_lines.append(
apply_min_indent(
"<Tip warning={true}>\n\n"
+ f"{block_content.strip()}\n\n</Tip>\n",
block_indent,
)
)
elif block_type == "raw":
new_lines.append(block_content.strip() + "\n")
elif block_type == "math":
new_lines.append(f"$${block_content.strip()}$$\n")
elif block_type == "comment":
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
elif block_type == "autofunction":
if block_info is not None:
new_lines.append(f"[[autodoc]] {block_info}\n")
elif block_type == "autoclass":
if block_info is not None:
block = f"[[autodoc]] {block_info}\n"
options = parse_options(block_content)
if "special-members" in options:
special_members = options["special-members"].split(", ")
for special_member in special_members:
block += f" - {special_member}\n"
if "members" in options:
members = options["members"]
if len(members) == 0:
block += " - all\n"
else:
for member in members.split(", "):
block += f" - {member}\n"
new_lines.append(block)
elif block_type == "image":
options = parse_options(block_content)
target = options.pop("target", None)
if block_info is not None:
options["src"] = block_info
else:
if target is None:
raise ValueError("Image source not defined.")
options["src"] = target
# Adapt path
options["src"] = options["src"].replace(
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
)
html_code = " ".join(
[f'{key}="{value}"' for key, value in options.items()]
)
new_lines.append(f"<img {html_code}/>\n")
else:
new_lines.append(
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
)
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
def split_return_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
return line, ""
return None, line
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
def split_raise_line(line):
"""
Split the raise line with format `SomeError some doc`.
"""
splits_on_colon = line.strip().split(" ")
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
if error_type and error_type[-1] == ":":
error_type = error_type[:-1]
return error_type, doc
def split_arg_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
return line, ""
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
class InvalidRstDocstringError(ValueError):
pass
_re_parameters = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
def parse_rst_docstring(docstring):
"""
Parses a docstring written in rst, in particular the list of arguments and the return type.
"""
lines = docstring.split("\n")
idx = 0
while idx < len(lines):
# Parameters section
if _re_args.search(lines[idx]) is not None:
# Title of the section.
lines[idx] = "<parameters>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
# Returns or Raises block.
param_indent = find_indent(lines[idx])
while (
idx < len(lines)
and find_indent(lines[idx]) == param_indent
and _re_returns.search(lines[idx]) is None
):
intro, doc = split_arg_line(lines[idx])
# Line starting with a > after indent indicate a "section title" in the parameters.
if intro.lstrip().startswith(">"):
lines[idx] = intro.lstrip()
else:
lines[idx] = (
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
):
idx += 1
lines.insert(idx, "</parameters>\n")
idx += 1
# Returns section
elif _re_returns.search(lines[idx]) is not None:
# tag is either `return` or `yield`
tag = _re_returns.match(lines[idx]).group(1).lower()
# Title of the section.
lines[idx] = f"<{tag}s>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it.
return_indent = find_indent(lines[idx])
raised_errors = []
# The line may contain the return type.
if tag in ["return", "yield"]:
return_type, return_description = split_return_line(lines[idx])
lines[idx] = return_description
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= return_indent
):
idx += 1
else:
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
return_type, return_description = split_raise_line(lines[idx])
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
lines[idx] = "- " + raised_error + " -- " + return_description
md_link = _re_md_link.match(raised_error)
if md_link:
raised_error = md_link[1]
raised_error = re.sub(
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
)
if raised_error not in raised_errors:
raised_errors.append(raised_error)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) > return_indent
):
idx += 1
lines.insert(idx, f"</{tag}s>\n")
idx += 1
# Return block finished, we insert the return type if one was specified
if tag in ["return", "yield"] and return_type is not None:
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
elif len(raised_errors) > 0:
# raised errors
lines[
idx - 1
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
else:
idx += 1
result = "\n".join(lines)
# combine multiple <parameters> blocks into one block
if result.count("<parameters>") > 1:
parameters_blocks = _re_parameters.findall(result)
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
parameters_str = "\n".join(parameters_blocks)
result = _re_parameters.sub("", result)
result += f"\n<parameters>{parameters_str}</parameters>\n"
return result
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
def remove_indent(text):
"""
Remove indents in text, except the one linked to lists (or sublists).
"""
lines = text.split("\n")
# List of indents to remember for nested lists
current_indents = []
# List of new indents to remember for nested lists
new_indents = []
is_inside_code = False
code_indent = 0
for idx, line in enumerate(lines):
# Line is an item in a list.
if _re_list.search(line) is not None:
indent = find_indent(line)
# Is it a new list / new level of nestedness?
if len(current_indents) == 0 or indent > current_indents[-1]:
current_indents.append(indent)
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Otherwise it's an existing level of list (current one, or previous one)
else:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] != indent:
level -= 1
current_indents = current_indents[: level + 1]
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Line is an autodoc, we keep the indent for the list just after if there is one.
elif _re_autodoc.search(line) is not None:
indent = find_indent(line)
current_indents = [indent]
new_indents = [4]
lines[idx] = line.strip()
# Deal with empty lines separately
elif is_empty_line(line):
lines[idx] = ""
# Code blocks
elif line.lstrip().startswith("```"):
is_inside_code = not is_inside_code
if is_inside_code:
code_indent = find_indent(line)
lines[idx] = line[code_indent:]
elif is_inside_code:
lines[idx] = line[code_indent:]
else:
indent = find_indent(line)
if len(current_indents) > 0 and indent > current_indents[-1]:
lines[idx] = " " * new_indents[-1] + line[indent:]
elif len(current_indents) > 0:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] > indent:
level -= 1
current_indents = current_indents[: level + 1]
if level >= 0:
if current_indents[level] < indent:
new_indents = new_indents[: level + 1]
else:
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indents.append(new_indent)
else:
new_indents = []
lines[idx] = line[indent:]
else:
lines[idx] = line[indent:]
return "\n".join(lines)
def base_rst_to_mdx(text, page_info, unindent=True):
"""
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
"""
text = convert_rst_links(text, page_info)
text = convert_special_chars(text)
text = convert_rst_blocks(text, page_info)
# Convert * in lists to - to avoid the formatting conversion treat them as bold.
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
text = convert_rst_formatting(text)
return remove_indent(text) if unindent else text
def convert_rst_docstring_to_mdx(docstring, page_info):
"""
Convert a docstring written in rst to mdx.
"""
text = parse_rst_docstring(docstring)
return base_rst_to_mdx(text, page_info)
def process_titles(lines):
"""Converts rst titles to markdown titles."""
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
title_levels = {}
new_lines = []
for line in lines:
if (
len(new_lines) > 0
and len(line) >= len(new_lines[-1])
and len(set(line)) == 1
and line[0] in title_chars
and line != "::"
):
char = line[0]
level = title_levels.get(char, len(title_levels) + 1)
if level not in title_levels:
title_levels[char] = level
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
else:
new_lines.append(line)
return new_lines
# Matches lines with a pattern of a table new line in rst.
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a table new line in rst, with a first column empty.
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a first table line in rst.
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
# Re pattern that catches anchors of the type .. reference:
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
def split_pt_tf_code_blocks(text):
"""
Split PyTorch and TensorFlow specific block codes.
"""
lines = text.split("\n")
new_lines = []
idx = 0
while idx < len(lines):
if lines[idx].startswith("```"):
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
is_pytorch = False
is_tensorflow = False
idx += 1
while idx < len(lines) and lines[idx].strip() != "```":
if "## PYTORCH CODE" in lines[idx]:
is_pytorch = True
is_tensorflow = False
elif "## TENSORFLOW CODE" in lines[idx]:
is_tensorflow = True
is_pytorch = False
elif is_pytorch:
code_lines["pytorch"].append(lines[idx])
elif is_tensorflow:
code_lines["tensorflow"].append(lines[idx])
else:
code_lines["common"].append(lines[idx])
idx += 1
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
block_lines = ["<frameworkcontent>", "<pt>"]
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
block_lines.extend(["```", "</pt>", "<tf>"])
block_lines.extend(
code_lines["common"].copy() + code_lines["tensorflow"]
)
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
new_lines.extend(block_lines)
else:
block_lines = code_lines["common"] + ["```"]
new_lines.extend(block_lines)
idx += 1
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
"""
Convert a document written in rst to mdx.
"""
lines = rst_text.split("\n")
lines = process_titles(lines)
if add_imports:
new_lines = [
'<script lang="ts">',
' import Tip from "$lib/Tip.svelte";',
' import Youtube from "$lib/Youtube.svelte";',
' import Docstring from "$lib/Docstring.svelte";',
' import CodeBlock from "$lib/CodeBlock.svelte";',
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
' import Markdown from "$lib/Markdown.svelte";',
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
' import Added from "$lib/Added.svelte";',
' import Changed from "$lib/Changed.svelte";',
' import Deprecated from "$lib/Deprecated.svelte";',
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
' import PipelineTag from "$lib/PipelineTag.svelte";',
" ",
' export let fw: "pt" | "tf"',
"</script>",
"<svelte:head>",
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
"</svelte:head>",
"",
]
else:
new_lines = []
for line in lines:
if _re_ignore_line_table.search(line) is not None:
continue
elif _re_ignore_line_table1.search(line) is not None:
continue
elif _re_sep_line_table.search(line) is not None:
line = line.replace("=", "-").replace("+", "|")
elif _re_anchor_section.search(line) is not None:
anchor_name = _re_anchor_section.search(line).groups()[0]
line = f"<a id='{anchor_name}'></a>"
new_lines.append(line)
text = "\n".join(new_lines)
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))

160
src/kernels/cli.py Normal file
View File

@ -0,0 +1,160 @@
import argparse
import dataclasses
import json
import sys
from pathlib import Path
from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
from .doc import generate_readme_for_kernel
from .wheel import build_variant_to_wheel
def main():
parser = argparse.ArgumentParser(
prog="kernel", description="Manage compute kernels"
)
subparsers = parser.add_subparsers(required=True)
download_parser = subparsers.add_parser("download", help="Download locked kernels")
download_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
download_parser.add_argument(
"--all-variants",
action="store_true",
help="Download all build variants of the kernel",
)
download_parser.set_defaults(func=download_kernels)
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
lock_parser.set_defaults(func=lock_kernels)
to_wheel_parser = subparsers.add_parser(
"to-wheel", help="Convert a kernel to a wheel file"
)
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
to_wheel_parser.add_argument(
"--python-version",
type=str,
default="3.9",
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
)
to_wheel_parser.add_argument(
"--manylinux-version",
type=str,
default="2.28",
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
)
to_wheel_parser.set_defaults(func=kernels_to_wheel)
# Add generate-readme subcommand parser
generate_readme_parser = subparsers.add_parser(
"generate-readme",
help="Generate README snippets for a kernel's public functions",
)
generate_readme_parser.add_argument(
"repo_id",
type=str,
help="The kernel repo ID (e.g., kernels-community/activation)",
)
generate_readme_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
generate_readme_parser.set_defaults(
func=lambda args: generate_readme_for_kernel(
repo_id=args.repo_id, revision=args.revision
)
)
args = parser.parse_args()
args.func(args)
def download_kernels(args):
lock_path = args.project_dir / "kernels.lock"
if not lock_path.exists():
print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr)
sys.exit(1)
with open(args.project_dir / "kernels.lock", "r") as f:
lock_json = json.load(f)
all_successful = True
for kernel_lock_json in lock_json:
kernel_lock = KernelLock.from_json(kernel_lock_json)
print(
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
file=sys.stderr,
)
if args.all_variants:
install_kernel_all_variants(
kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
)
else:
try:
install_kernel(
kernel_lock.repo_id,
kernel_lock.sha,
variant_locks=kernel_lock.variants,
)
except FileNotFoundError as e:
print(e, file=sys.stderr)
all_successful = False
if not all_successful:
sys.exit(1)
def kernels_to_wheel(args):
variants_path = install_kernel_all_variants(
repo_id=args.repo_id, revision=f"v{args.version}"
)
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
continue
wheel_path = build_variant_to_wheel(
manylinux_version=args.manylinux_version,
python_version=args.python_version,
repo_id=args.repo_id,
version=args.version,
variant_path=variant_path,
wheel_dir=Path("."),
)
print(f"☸️ {wheel_path.name}", file=sys.stderr)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
all_locks = []
for kernel, version in kernel_versions.items():
all_locks.append(get_kernel_locks(kernel, version))
with open(args.project_dir / "kernels.lock", "w") as f:
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)

242
src/kernels/doc.py Normal file
View File

@ -0,0 +1,242 @@
import inspect
import re
import sys
from types import ModuleType
import yaml
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
from .utils import get_kernel
_RE_PARAMETERS = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
_RE_RETURNTYPE = re.compile(
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
else:
if "kernel" not in metadata["tags"]:
metadata["tags"].append("kernel")
print("---")
print(yaml.dump(metadata), end="")
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
print(f"# {title.strip()}")
print(
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
if not found_functions:
print("\nNo public top-level functions.")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
print("\n## Layers")
# Track if we found any classes
found_classes = False
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")

489
src/kernels/layer.py Normal file
View File

@ -0,0 +1,489 @@
from __future__ import annotations
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Flag, auto
from types import MethodType
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Tuple,
Type,
Union,
)
from .utils import get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
class Mode(Flag):
"""
Kernelize mode
The `Mode` flag is used by `kernelize` to select kernels for the given
mode. Mappings can be registered for specific modes.
* `INFERENCE`: The kernel is used for inference.
* `TRAINING`: The kernel is used for training.
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
matches.
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
should be used for layers that are used for inference *with* `torch.compile`.
"""
_NONE = 0
DEFAULT = auto()
TRAINING = auto()
INFERENCE = auto()
TORCH_COMPILE = auto()
def __or__(self, other: Mode) -> Mode:
union = super().__or__(other)
if Mode.INFERENCE in union and Mode.TRAINING in union:
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
if Mode.DEFAULT in union and union != Mode.DEFAULT:
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
return union
@dataclass(frozen=True)
class Device:
type: str
# In the future we might add compute capabilities, etc.
def __eq__(self, other):
return isinstance(other, Device) and self.type == other.type
def __hash__(self):
return hash(self.type)
@dataclass
class LayerRepository:
"""
Repository and name of a layer.
"""
layer_name: str = field(
metadata={"help": "The name of the layer in the kernel repository."}
)
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
revision: str = field(
default="main", metadata={"help": "The revision of the layer."}
)
def __eq__(self, other):
return (
isinstance(other, LayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
and self.revision == other.revision
)
def __hash__(self):
return hash((self.layer_name, self.repo_id, self.revision))
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
ContextVar("_KERNEL_MAPPING", default={})
)
def use_kernel_mapping(
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
*,
inherit_mapping: bool = True,
):
"""
Context manager that sets a mapping for a duration of the context.
When `inherit_mapping` is set to `True` the current mapping will be
extended by `mapping` inside the context. If it is `False`, only
`mapping` is used inside the context.
"""
class ContextManager:
def __enter__(self):
# Mappings always stack on previous mappings.
if inherit_mapping:
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
else:
self.token = _KERNEL_MAPPING.set({})
register_kernel_mapping(mapping)
def __exit__(self, exc_type, exc_value, traceback):
_KERNEL_MAPPING.reset(self.token)
return ContextManager()
def register_kernel_mapping(
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
):
"""
Allows one to register a mapping between a layer name and the corresponding
kernel(s) to use, depending on the device. This should be used in conjunction
with `kernelize`.
Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
kernel_layer_mapping = {
"LlamaRMSNorm": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="RmsNorm",
revision="layers",
),
},
}
register_kernel_mapping(kernel_layer_mapping)
```
"""
# Merge with existing mappings.
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
device = (
Device(type=new_device) if isinstance(new_device, str) else new_device
)
if isinstance(new_repo, LayerRepository):
kernel_options = {Mode.DEFAULT: new_repo}
else:
kernel_options = new_repo
device_repo[device] = kernel_options
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
Args:
cls: The layer class to decorate
layer_name: The name of the layer to use for kernel lookup
"""
cls.kernel_layer_name = layer_name
def _select_repository(
repositories: Dict[Mode, LayerRepository],
*,
mode: Mode,
) -> Optional[Tuple[LayerRepository, Mode]]:
if mode in repositories:
return (repositories[mode], mode)
elif Mode.DEFAULT in repositories:
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
else:
return None
def kernelize(
model: "nn.Module",
*,
mode: Mode,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
mode: the mode that the kernel is going to be used in (e.g.
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
and `torch.compile`).
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
"""
import torch
if mode == Mode.DEFAULT:
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
# Type check ignored because this causes a false negative on Python < 3.11.
# Looks similar to: https://github.com/python/mypy/issues/9642
# Remove once we start doing typing checks on >= 3.11.
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
if _DISABLE_KERNEL_MAPPING:
_replace_forward(module, module_class)
continue
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
if kernel is None:
warnings.warn(
"\n"
f"No kernel mapping found for layer `{layer_name}`. "
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
f"you want to use to the mapping. Defaulting to original forward implementation."
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
_replace_forward(module, module_class)
continue
# Get kernel options for the device
repos = kernel.get(device_type)
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
_replace_forward(module, module_class)
continue
repo_with_mode = _select_repository(
repos,
mode=mode,
)
if repo_with_mode is None:
if not use_fallback:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
continue
repo, repo_mode = repo_with_mode
layer = _get_layer_memoize(repo, module_class)
# Ideally we would do validation on the mapping where we check that
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
# the actual layer is compatible with that. Unfortunately, this would
# mean that we have to pre-download everything.
_validate_layer_has_mode(
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
)
_conditionally_replace_forward(
module=module,
layer=layer,
mode=mode,
use_fallback=use_fallback,
)
return model
def use_kernel_forward_from_hub(layer_name: str):
"""
Make a layer extensible using the name `layer_name`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
if getattr(kernel, "layers", None) is None:
raise ValueError(
f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
)
layer = getattr(kernel.layers, layer_name, None)
if layer is None:
raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
return layer
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
# The layer must have at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
# We verify statelessness by checking that the does not have its own
# constructor (since the constructor could add member variables)...
if cls.__init__ is not nn.Module.__init__:
raise TypeError("Layer must not override nn.Module constructor.")
# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
params = inspect.signature(cls.forward).parameters
ref_params = inspect.signature(check_cls.forward).parameters
if len(params) != len(ref_params):
raise TypeError(
"Forward signature does not match: different number of arguments."
)
for param, ref_param in zip(params.values(), ref_params.values()):
if param.kind != ref_param.kind:
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)
def _find_device(model: "nn.Module") -> Device:
try:
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)
return Device(type=param.device.type)
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback if the mode is not supported by the layer.
# Note that this is useful even after _validate_layer_has_mode because
# layers registered with the DEFAULT mode never get rejected by
# _validate_layer_has_mode. For such layers, we want to fall back in
# case the layer does not support the given mode.
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
layer, "can_torch_compile", False
)
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
def _validate_layer_has_mode(
*,
layer_name: str,
module: Type["nn.Module"],
repo: LayerRepository,
repo_mode: Mode,
):
"""
Check that a repository supports the mode that it was registered for.
"""
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
if Mode.TORCH_COMPILE in repo_mode and not getattr(
module, "can_torch_compile", False
):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
return True
def _get_layer_memoize(
repo: LayerRepository, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
return layer
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
return layer

View File

@ -1,33 +1,37 @@
import hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from hf_kernels.compat import tomllib
from kernels.compat import tomllib
@dataclass
class FileLock:
filename: str
blob_id: str
class VariantLock:
hash: str
hash_type: str = "git_lfs_concat"
@dataclass
class KernelLock:
repo_id: str
sha: str
files: List[FileLock]
variants: Dict[str, VariantLock]
@classmethod
def from_json(cls, o: Dict):
files = [FileLock(**f) for f in o["files"]]
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
variants = {
variant: VariantLock(**lock) for variant, lock in o["variants"].items()
}
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
def _get_available_versions(repo_id: str):
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
@ -41,7 +45,7 @@ def _get_available_versions(repo_id: str):
return versions
def get_kernel_locks(repo_id: str, version_spec: str):
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
"""
Get the locks for a kernel with the given version spec.
@ -72,31 +76,55 @@ def get_kernel_locks(repo_id: str, version_spec: str):
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
)
file_locks = []
variant_files: Dict[str, List[Tuple[bytes, str]]] = {}
for sibling in r.siblings:
if sibling.rfilename.startswith("build/torch"):
if sibling.blob_id is None:
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
file_locks.append(
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
)
path = Path(sibling.rfilename)
variant = path.parts[1]
filename = Path(*path.parts[2:])
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
hash = sibling.lfs.sha256 if sibling.lfs is not None else sibling.blob_id
files = variant_files.setdefault(variant, [])
# Encode as posix for consistent slash handling, then encode
# as utf-8 for byte-wise sorting later.
files.append((filename.as_posix().encode("utf-8"), hash))
variant_locks = {}
for variant, files in variant_files.items():
m = hashlib.sha256()
for filename_bytes, hash in sorted(files):
# Filename as bytes.
m.update(filename_bytes)
# Git blob or LFS file hash as bytes.
m.update(bytes.fromhex(hash))
variant_locks[variant] = VariantLock(hash=f"sha256-{m.hexdigest()}")
return KernelLock(repo_id=repo_id, sha=r.sha, variants=variant_locks)
def write_egg_lockfile(cmd, basename, filename):
import logging
cwd = Path.cwd()
with open(cwd / "pyproject.toml", "rb") as f:
pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists():
# Nothing to do if the project doesn't have pyproject.toml.
return
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
if kernel_versions is None:
return
lock_path = cwd / "hf-kernels.lock"
lock_path = cwd / "kernels.lock"
if not lock_path.exists():
logging.warning(f"Lock file {lock_path} does not exist")
# Ensure that the file gets deleted in editable installs.

388
src/kernels/utils.py Normal file
View File

@ -0,0 +1,388 @@
import ctypes
import hashlib
import importlib
import importlib.metadata
import inspect
import json
import logging
import os
import platform
import sys
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels.lockfile import KernelLock, VariantLock
def _get_cache_dir() -> Optional[str]:
"""Returns the kernels cache directory."""
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
if cache_dir is not None:
logging.warning(
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
)
return cache_dir
return os.environ.get("KERNELS_CACHE", None)
CACHE_DIR: Optional[str] = _get_cache_dir()
def build_variant() -> str:
import torch
if torch.version.cuda is not None:
cuda_version = parse(torch.version.cuda)
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
elif torch.version.hip is not None:
rocm_version = parse(torch.version.hip.split("-")[0])
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
else:
raise AssertionError(
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
)
torch_version = parse(torch.__version__)
cpu = platform.machine()
os = platform.system().lower()
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
def universal_build_variant() -> str:
# Once we support other frameworks, detection goes here.
return "torch-universal"
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
# We cannot use the module name as-is, after adding it to `sys.modules`,
# it would also be used for other imports. So, we make a module name that
# depends on the path for it to be unique using the hex-encoded hash of
# the path.
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
module_name = f"{module_name}_{path_hash}"
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
module = importlib.util.module_from_spec(spec)
if module is None:
raise ImportError(f"Cannot load module {module_name} from spec")
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore
return module
def install_kernel(
repo_id: str,
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
"""
Download a kernel for the current environment to the cache.
The output path is validated againt `hash` when set.
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Redo with more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
if not variant_path.exists() and universal_variant_path.exists():
# Fall back to universal variant.
variant = universal_variant
variant_path = universal_variant_path
if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
raise ValueError(f"No lock found for build variant: {variant}")
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
module_init_path = variant_path / package_name / "__init__.py"
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
return package_name, variant_path
def install_kernel_all_variants(
repo_id: str,
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Path:
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
)
if variant_locks is not None:
for entry in (repo_path / "build").iterdir():
variant = entry.parts[-1]
variant_lock = variant_locks.get(variant)
if variant_lock is None:
raise ValueError(f"No lock found for build variant: {variant}")
validate_kernel(
repo_path=repo_path, variant=variant, hash=variant_lock.hash
)
return repo_path / "build"
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Download and import a kernel from the Hugging Face Hub.
The kernel is downloaded from the repository `repo_id` at
branch/commit/tag `revision`.
"""
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
"""
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
if file_exists(
repo_id,
revision=revision,
filename=f"build/{universal_variant}/{package_name}/__init__.py",
):
return True
return file_exists(
repo_id,
revision=revision,
filename=f"build/{variant}/{package_name}/__init__.py",
)
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
"""
Get a pre-downloaded, locked kernel.
If `lockfile` is not specified, the lockfile will be loaded from the
caller's package metadata.
"""
if lockfile is None:
locked_sha = _get_caller_locked_kernel(repo_id)
else:
with open(lockfile, "r") as f:
locked_sha = _get_locked_kernel(repo_id, f.read())
if locked_sha is None:
raise ValueError(
f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project."
)
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
cache_dir=CACHE_DIR,
revision=locked_sha,
local_files_only=True,
)
)
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
if not variant_path.exists() and universal_variant_path.exists():
# Fall back to universal variant.
variant = universal_variant
variant_path = universal_variant_path
module_init_path = variant_path / package_name / "__init__.py"
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
)
return import_from_path(package_name, variant_path / package_name / "__init__.py")
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""Get a kernel using a lock file."""
locked_sha = _get_caller_locked_kernel(repo_id)
if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")
package_name, package_path = install_kernel(
repo_id, locked_sha, local_files_only=local_files_only
)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
for dist in _get_caller_distributions():
lock_json = dist.read_text("kernels.lock")
if lock_json is None:
continue
locked_sha = _get_locked_kernel(repo_id, lock_json)
if locked_sha is not None:
return locked_sha
return None
def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
for kernel_lock_json in json.loads(lock_json):
kernel_lock = KernelLock.from_json(kernel_lock_json)
if kernel_lock.repo_id == repo_id:
return kernel_lock.sha
return None
def _get_caller_distributions() -> List[Distribution]:
module = _get_caller_module()
if module is None:
return []
# Look up all possible distributions that this module could be from.
package = module.__name__.split(".")[0]
dist_names = importlib.metadata.packages_distributions().get(package)
if dist_names is None:
return []
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
def _get_caller_module() -> Optional[ModuleType]:
stack = inspect.stack()
# Get first module in the stack that is not the current module.
first_module = inspect.getmodule(stack[0][0])
for frame in stack[1:]:
module = inspect.getmodule(frame[0])
if module is not None and module != first_module:
return module
return first_module
def validate_kernel(*, repo_path: Path, variant: str, hash: str):
"""Validate the given build variant of a kernel against a hasht."""
variant_path = repo_path / "build" / variant
# Get the file paths. The first element is a byte-encoded relative path
# used for sorting. The second element is the absolute path.
files: List[Tuple[bytes, Path]] = []
# Ideally we'd use Path.walk, but it's only available in Python 3.12.
for dirpath, _, filenames in os.walk(variant_path):
for filename in filenames:
file_abs = Path(dirpath) / filename
# Python likes to create files when importing modules from the
# cache, only hash files that are symlinked blobs.
if file_abs.is_symlink():
files.append(
(
file_abs.relative_to(variant_path).as_posix().encode("utf-8"),
file_abs,
)
)
m = hashlib.sha256()
for filename_bytes, full_path in sorted(files):
m.update(filename_bytes)
blob_filename = full_path.resolve().name
if len(blob_filename) == 40:
# SHA-1 hashed, so a Git blob.
m.update(git_hash_object(full_path.read_bytes()))
elif len(blob_filename) == 64:
# SHA-256 hashed, so a Git LFS blob.
m.update(hashlib.sha256(full_path.read_bytes()).digest())
else:
raise ValueError(f"Unexpected blob filename length: {len(blob_filename)}")
computedHash = f"sha256-{m.hexdigest()}"
if computedHash != hash:
raise ValueError(
f"Lock file specifies kernel with hash {hash}, but downloaded kernel has hash: {computedHash}"
)
def git_hash_object(data: bytes, object_type: str = "blob"):
"""Calculate git SHA1 of data."""
header = f"{object_type} {len(data)}\0".encode()
m = hashlib.sha1()
m.update(header)
m.update(data)
return m.digest()
def package_name_from_repo_id(repo_id: str) -> str:
return repo_id.split("/")[-1].replace("-", "_")

186
src/kernels/wheel.py Normal file
View File

@ -0,0 +1,186 @@
import email.policy
import os
from dataclasses import dataclass
from email.message import Message
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Optional
try:
KERNELS_VERSION = version("kernels")
except PackageNotFoundError:
KERNELS_VERSION = "unknown"
@dataclass
class Metadata:
name: str
version: str
cuda_version: Optional[str]
cxx_abi_version: Optional[str]
torch_version: Optional[str]
os: Optional[str]
platform: Optional[str]
@property
def is_universal(self) -> bool:
return self.platform is None
def build_variant_to_wheel(
repo_id: str,
*,
version: str,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Create a wheel file from the variant path.
"""
name = repo_id.split("/")[-1].replace("_", "-")
metadata = extract_metadata(name, version, variant_path)
return build_wheel(
metadata,
variant_path=variant_path,
wheel_dir=wheel_dir,
manylinux_version=manylinux_version,
python_version=python_version,
)
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
"""
Extract metadata from the variant path.
"""
if variant_path.name == "torch-universal":
return Metadata(
name=name,
version=version,
cuda_version=None,
cxx_abi_version=None,
torch_version=None,
os=None,
platform=None,
)
if not variant_path.name.startswith("torch"):
raise ValueError("Currently only conversion of Torch kernels is supported.")
variant_parts = variant_path.name.removeprefix("torch").split("-")
if len(variant_parts) != 5:
raise ValueError(f"Invalid variant name: {variant_path.name}")
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
cpp_abi_version = variant_parts[1].removeprefix("cxx")
cuda_version = variant_parts[2].removeprefix("cu")
platform = variant_parts[3].replace("-", "_")
os = variant_parts[4]
return Metadata(
name=name,
version=version,
cuda_version=cuda_version,
cxx_abi_version=cpp_abi_version,
torch_version=torch_version,
os=os,
platform=platform,
)
def build_wheel(
metadata: Metadata,
*,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Build the wheel file.
"""
try:
from wheel.wheelfile import WheelFile # type: ignore
except ImportError:
raise ImportError(
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
)
name = metadata.name.replace("-", "_")
python_version_flat = python_version.replace(".", "")
if metadata.is_universal:
python_tag = f"py{python_version_flat}"
abi_tag = "none"
platform_tag = "any"
wheel_filename = (
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
)
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
root_is_purelib = "true"
requires_dist_torch = "torch"
else:
python_tag = f"cp{python_version_flat}"
abi_tag = "abi3"
if (
metadata.torch_version is None
or metadata.cuda_version is None
or metadata.cxx_abi_version is None
or metadata.os is None
or metadata.platform is None
):
raise ValueError(
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
)
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
if metadata.os == "linux":
platform_tag = (
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
)
else:
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
root_is_purelib = "false"
requires_dist_torch = f"torch=={metadata.torch_version}.*"
wheel_path = wheel_dir / wheel_filename
wheel_msg = Message(email.policy.compat32)
wheel_msg.add_header("Wheel-Version", "1.0")
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
metadata_msg = Message(email.policy.compat32)
metadata_msg.add_header("Metadata-Version", "2.1")
metadata_msg.add_header("Name", name)
metadata_msg.add_header("Version", metadata.version)
metadata_msg.add_header("Summary", f"{name} kernel")
metadata_msg.add_header("Requires-Python", ">=3.9")
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
source_pkg_dir = variant_path / name
with WheelFile(wheel_path, "w") as wheel_file:
for root, dirnames, filenames in os.walk(source_pkg_dir):
for filename in filenames:
if filename.endswith(".pyc"):
continue
abs_filepath = os.path.join(root, filename)
entry_name = os.path.relpath(abs_filepath, variant_path)
wheel_file.write(abs_filepath, entry_name)
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
return wheel_path

10
tests/conftest.py Normal file
View File

@ -0,0 +1,10 @@
import sys
import pytest
def pytest_runtest_setup(item):
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
pytest.skip("skipping Linux-only test on non-Linux platform")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")

View File

@ -0,0 +1,94 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/triton-scaled-mm",
"sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
"variants": {
"torch-universal": {
"hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
"hash_type": "git_lfs_concat"
}
}
}
]

View File

@ -0,0 +1,3 @@
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.2"
"kernels-community/triton-scaled-mm" = ">=0.0.2"

View File

@ -1,6 +1,7 @@
import pytest
import torch
from hf_kernels import get_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
@pytest.fixture
@ -8,6 +9,24 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@pytest.fixture
def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@pytest.fixture
def device():
if not torch.cuda.is_available():
@ -15,6 +34,7 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -28,3 +48,59 @@ def test_gelu_fast(kernel, device):
)
assert torch.allclose(y, expected)
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
x = torch.arange(-10, 10, dtype=dtype, device="mps")
y = metal_kernel.relu(x)
assert torch.allclose(y, torch.relu(x))
@pytest.mark.linux_only
@pytest.mark.parametrize(
"kernel_exists",
[
("kernels-community/activation", "main", True),
("kernels-community/triton-layer-norm", "main", True),
# Repo only contains Torch 2.4 kernels (and we don't
# support/test against this version).
("kernels-test/only-torch-2.4", "main", False),
("google-bert/bert-base-uncased", "87565a309", False),
],
)
def test_has_kernel(kernel_exists):
repo_id, revision, kernel = kernel_exists
assert has_kernel(repo_id, revision=revision) == kernel
@pytest.mark.linux_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")
out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
out_check = (A * scale_a) @ (B * scale_b)
out_check = out_check.to(torch.float16)
torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)

View File

@ -1,6 +1,7 @@
import pytest
import torch
from hf_kernels import get_kernel
from kernels import get_kernel
@pytest.fixture
@ -15,18 +16,21 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

View File

@ -0,0 +1,27 @@
from dataclasses import dataclass
from pathlib import Path
import pytest
from kernels import load_kernel
from kernels.cli import download_kernels
# Mock download arguments class.
@dataclass
class DownloadArgs:
all_variants: bool
project_dir: Path
def test_download_all_hash_validation():
project_dir = Path(__file__).parent / "kernel_locking"
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.linux_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")

508
tests/test_layer.py Normal file
View File

@ -0,0 +1,508 @@
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
kernel_layer_mapping = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
},
"SiluAndMulNoCompile": {
"cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
},
}
register_kernel_mapping(kernel_layer_mapping)
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
d = input.shape[-1] // 2
return F.silu(input[..., :d]) * input[..., d:]
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
class SiluAndMulNoCompileKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMulStringDevice")
class SiluAndMulStringDevice(SiluAndMul):
pass
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
def forward(
self,
arg1,
arg2,
*,
kwarg1,
kwarg2=42,
):
return (arg1, arg2, kwarg1, kwarg2)
arg_kind = ArgKind()
assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
torch.random.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
if device == "cuda":
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
def test_layer_fallback_works():
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
class SiluAndMulWithKernelFallback(SiluAndMul):
pass
# Check that we don't raise an exception for a non-existing kernel.
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
ctx = (
pytest.raises(ValueError, match="does not support mode")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
with ctx:
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
extra_mapping1 = {
"TestKernel": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
with use_kernel_mapping(extra_mapping1):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
extra_mapping2 = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/non-existing",
layer_name="SiluAndMul",
revision="layers",
)
}
}
with use_kernel_mapping(extra_mapping2):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
def test_validate_kernel_layer():
class BadLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.foo = 42
with pytest.raises(TypeError, match="not override"):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
class BadLayer2(nn.Module):
foo: int = 42
with pytest.raises(TypeError, match="not contain additional members"):
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
class BadLayer3(nn.Module):
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different number of arguments"):
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
class BadLayer4(nn.Module):
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
with pytest.raises(TypeError, match="different kind of arguments"):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: layer without further specification, becomes the
# base layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Training has a kernel, so fallback.
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)