Compare commits


51 Commits

SHA1 Message Date
b7b5f40143 Set version to 0.7.0 2025-07-07 13:09:01 +00:00
b87e6fadbe Set version to 0.7.0.dev0 (#104) 2025-07-07 14:56:43 +02:00
fc935d9874 Support registering inference/training-specific layers (#103)
* Support registering inference/training-specific layers

This change makes it possible to register kernels specialized for
inference, training, and/or `torch.compile`. To do so, the mapping
notation is extended to support registering specialized kernels
for a specific 'mode'. For instance, the following mapping,

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
          Mode.DEFAULT: LayerRepository(
              repo_id="kernels-community/activation",
              layer_name="SiluAndMul",
          ),
          Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
              repo_id="kernels-community/activation-training-optimized",
              layer_name="SiluAndMul",
          ),
      }
    }
}
```

uses `kernels-community/activation` by default, but will switch to
using `kernels-community/activation-training-optimized` if a model
is kernelized for training and `torch.compile`.

To make it easier to add more modes in the future and to unify the
`register_kernel_mapping` and `kernelize` signatures, the `training`
and `needs_torch_compile` arguments of `kernelize` are replaced by
a single `mode` argument:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

* Documentation fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Add note on when the fallback is used

* Tighten up some Mode checks

* Fix ruff check

* Attempt to fix mypy errors

* More typing fixes

* Ignore Python < 3.11 type check SNAFU

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-07-04 19:57:14 +02:00
3622e1f8dd Add get_local_kernel function (#102)
This function loads a kernel from a local repository (e.g. the output
of kernel-builder), which can be handy for testing.
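
A minimal usage sketch (the exact signature and the `result` directory
and `activation` package names are assumptions):

```python
from pathlib import Path

from kernels import get_local_kernel

# Load a kernel from a local kernel-builder output directory.
activation = get_local_kernel(Path("result"), package_name="activation")
```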
2025-07-01 13:58:47 +02:00
a7f3b2e8ed Set version to 0.6.2.dev0 (#100) 2025-06-25 09:48:09 +02:00
a6ab5d83ba Make the flake work on Darwin (#98) 2025-06-24 20:35:21 +02:00
4f9f1abfb9 darwin: fix variant CPU for aarch64 (#97) 2025-06-24 20:35:07 +02:00
f94b7780a6 CI: main triton-layer-norm has docs, branch is gone (#99) 2025-06-24 16:40:36 +02:00
bd28883775 Set version to 0.6.1.dev1 (#96) 2025-06-20 11:43:26 +02:00
498429e322 Add README generation for layers (#94) 2025-06-20 10:16:50 +02:00
09c991af4b Add macOS requirements (#95) 2025-06-16 17:20:47 +02:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
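
A sketch of the resulting behavior, in the style of the docs' examples
(`MyModel` is a placeholder):

```python
from kernels import kernelize

model = MyModel(...)
model = kernelize(model)
# Layers whose kernels lack a backward pass now automatically fall
# back to their original forward when training is enabled.
model.train(True)
```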
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
848c6db87b Add support for Metal builds (#89)
* Add support for Metal builds

* Add Metal test, gate tests by OS where necessary
2025-05-30 15:54:28 +02:00
fabb8c52d1 Add generate-readme subcommand for generating a README (#88)
* Add `generate-readme` subcommand for generating a README

This README includes all the top-level functions with docs (if
docstrings are available).

* CI: attempt README generation

* Add PyYAML dependencies

* Typing fixes
2025-05-21 15:43:53 +02:00
d66260dd83 kernels: add the to-wheel subcommand (#84)
* kernels: add the `to-wheel` subcommand

This subcommand accepts a kernel repo and version as arguments:

    kernels to-wheel kernels-community/activation 0.0.3

Wheels will then be generated for every build variant.

* CI: check kernel -> wheel conversion

* No typing for wheel.wheelfile
2025-05-08 17:30:06 +02:00
daac8078fc CI: fix some stubs (#83) 2025-05-07 14:43:57 +02:00
fcb9a80ce6 Set version to 0.5.0 (#82) 2025-05-06 11:45:26 +02:00
c25bb32e6e Add publishing workflow (#81) 2025-05-06 09:29:08 +00:00
2036892762 Allow layers to opt in to torch.compile (#79)
* Allow layers to opt in to `torch.compile`

This change allows a layer to set the `can_torch_compile` class
variable to indicate that the layer is compatible with `torch.compile`.
When enabled, the layer does not fall back to the original
implementation when `torch.compile` is used.
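
A sketch of what opting in might look like in a kernel layer (the body
mirrors the `SiluAndMul` example used elsewhere in the docs):

```python
import torch
import torch.nn.functional as F
from torch import nn

class SiluAndMul(nn.Module):
    # Declare this implementation compatible with torch.compile.
    can_torch_compile: bool = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```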

* Comment fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-06 09:36:33 +02:00
0f0de049cf docs: link to autogenerated build variant list (#77) 2025-04-16 17:25:51 +02:00
59597df03e Specify required aarch64 build variants (#76) 2025-04-15 16:09:44 +02:00
5e938ede40 locking docs: fix command name (kernel -> kernels) (#74) 2025-04-14 16:02:00 +02:00
cf530c283a Set version to 0.4.4 (#73) 2025-04-11 10:23:26 +02:00
437f910336 Add has_kernel function (#69)
* Add `has_kernel` function

This function checks whether a kernel build exists for the current
environment (Torch version and compute framework).
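
A minimal usage sketch:

```python
from kernels import get_kernel, has_kernel

# Probe for a compatible build before attempting to load it.
if has_kernel("kernels-community/activation"):
    activation = get_kernel("kernels-community/activation")
```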

* Test kernel repo that only contains Torch 2.4
2025-04-11 10:12:37 +02:00
6f1a6067c8 feat: add logo and shields (#72) 2025-04-11 10:07:24 +02:00
1d14abcef0 Do not use kernels without backward when training (#68)
* Do not use kernels without backward when training

* Update repo for backwards marker test
2025-04-11 10:05:57 +02:00
6fd2112e22 Set version to 0.4.3 (#71) 2025-04-10 11:57:15 +02:00
70f56ff856 Support DISABLE_KERNEL_MAPPING env var for completely disabling kernel mappings (#70)
* Disable kernel mappings with `DISABLE_KERNEL_MAPPING=1`

* Rename HF_KERNELS_CACHE to KERNELS_CACHE

But still recognize the old variant for compatibility.

* Add documentation for environment variables
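
A sketch of setting the variables from Python (they may be read at
import time, so exporting them in the shell before the process starts
is the safer option):

```python
import os

# Must be set before the kernels package reads them.
os.environ["DISABLE_KERNEL_MAPPING"] = "1"
os.environ["KERNELS_CACHE"] = "/tmp/kernels-cache"
```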
2025-04-10 11:37:54 +02:00
7178b0b86c Add Apache License version 2.0 (#66)
Fixes #64
2025-04-04 20:35:29 +02:00
0bbf90a564 Update ABI requirement to manylinux_2_28 (#65) 2025-04-04 19:38:15 +02:00
27d6ffcb80 Add more details about the ABI requirements (#63) 2025-03-31 14:29:30 +02:00
f7bd21438b Set version to 0.4.2 (#62) 2025-03-27 16:57:28 +01:00
6174febb4b Add warning when layer_name not present in _KERNEL_MAPPING (#61)
* add warning

* fix import order
2025-03-27 16:22:58 +01:00
ff55bc201b Add support for fetching ROCm kernels (#59) 2025-03-25 15:11:03 +01:00
3808108d62 doc: add versioning (#58) 2025-03-24 16:48:20 +01:00
c4a16ef462 Actually export use_kernel_mapping at the top-level (#57)
* Actually export `use_kernel_mapping` at the top-level

* Set version to 0.4.1
2025-03-24 12:44:00 +01:00
9762794dd2 Set version to 0.4.0 (#56) 2025-03-21 20:49:01 +01:00
b7d6867c52 use_kernel_mapping: add inherit_mapping option (#55)
`inherit_mapping` is the default and extends the existing mapping
with the given mapping. If `inherit_mapping` is `False`, existing
mappings are not inherited.
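
A sketch, assuming a `kernel_layer_mapping` as defined elsewhere in the
docs:

```python
from kernels import use_kernel_mapping

# Apply the mapping without inheriting previously registered mappings.
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
    ...
```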
2025-03-21 17:28:45 +01:00
fbcd0f2ebd Set version to 0.3.3 (#54) 2025-03-20 16:09:11 +01:00
5af46eca94 Align dependency versions with transformers (#53) 2025-03-20 15:13:45 +01:00
747dd66876 Set version to 0.3.2 (#51) 2025-03-20 11:46:36 +01:00
920590a592 Also export replace_kernel_forward_from_hub (#52) 2025-03-20 11:46:18 +01:00
5208ac4be5 Make torch an extra/dev dependency (#50)
To support use of this package when Torch is optional.
2025-03-20 10:18:19 +01:00
22eaba2826 Set version to 0.3.1 (#49) 2025-03-19 16:35:10 +01:00
9521ba79a0 Fix forward positional argument handling (#48) 2025-03-19 15:54:51 +01:00
9861a5bdef Fix forward positional argument handling (#48) 2025-03-19 15:34:35 +01:00
1c7c87c960 Set version to 0.3.0 (#47) 2025-03-19 12:02:02 +01:00
26 changed files with 2757 additions and 189 deletions

.github/workflows/publish.yml (new vendored file, 120 lines)

@@ -0,0 +1,120 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
on: push
jobs:
build:
name: Build distribution 📦
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
github-release:
name: >-
Sign the Python 🐍 distribution 📦 with Sigstore
and upload them to GitHub Release
needs:
- publish-to-pypi
runs-on: ubuntu-latest
permissions:
contents: write # IMPORTANT: mandatory for making GitHub Releases
id-token: write # IMPORTANT: mandatory for sigstore
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
./dist/*.whl
- name: Create GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
run: >-
gh release create
"$GITHUB_REF_NAME"
--repo "$GITHUB_REPOSITORY"
--notes ""
- name: Upload artifact signatures to GitHub Release
env:
GITHUB_TOKEN: ${{ github.token }}
# Upload to GitHub Release using the `gh` CLI.
# `dist/` contains the built packages, and the
# sigstore-produced signatures and certificates.
run: >-
gh release upload
"$GITHUB_REF_NAME" dist/**
--repo "$GITHUB_REPOSITORY"
publish-to-testpypi:
name: Publish Python 🐍 distribution 📦 to TestPyPI
needs:
- build
runs-on: ubuntu-latest
environment:
name: testpypi
url: https://test.pypi.org/p/kernels
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
skip-existing: true # Only upload when the version is unique.

@@ -24,7 +24,7 @@ jobs:
max-parallel: 4
matrix:
python-version: ["3.10", "3.12"]
torch-version: ["2.5.1", "2.6.0"]
torch-version: ["2.6.0", "2.7.0"]
env:
UV_PYTHON_PREFERENCE: only-managed
@@ -52,3 +52,20 @@ jobs:
- name: Run tests
run: uv run pytest tests
- name: Check kernel conversion
run: |
uv pip install wheel
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
uv pip install triton_layer_norm-0.0.1*.whl
uv run python -c "import triton_layer_norm"
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm
- name: Import check without torch
run: |
uv pip uninstall torch
python -c "import kernels"

LICENSE (new file, 201 lines)

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@@ -1,5 +1,16 @@
# kernels
<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
<p align="center">
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
</p>
</div>
<hr/>
The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
@@ -47,6 +58,8 @@ the Hub.
- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

docs/env.md (new file, 10 lines)

@@ -0,0 +1,10 @@
# Environment variables
## `KERNELS_CACHE`
The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.
## `DISABLE_KERNEL_MAPPING`
Disables kernel mappings for [`layers`](layers.md).

docs/faq.md (new file, 13 lines)

@@ -0,0 +1,13 @@
# FAQ
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.

@@ -1,8 +1,11 @@
# Kernel requirements
Kernels on the Hub must fulfill the requirements outlined on this page.
Kernels on the Hub must fulfill the requirements outlined on this page. By
ensuring kernels are compliant, they can be used on a wide range of Linux
systems and Torch builds.
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build conforming kernels.
to build compliant kernels.
## Directory layout
@@ -10,52 +13,72 @@ A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
recommended build variants are:
For example `build/torch26-cxx98-cu118-x86_64-linux`.
- `torch25-cxx11-cu118-x86_64-linux`
- `torch25-cxx11-cu121-x86_64-linux`
- `torch25-cxx11-cu124-x86_64-linux`
- `torch25-cxx98-cu118-x86_64-linux`
- `torch25-cxx98-cu121-x86_64-linux`
- `torch25-cxx98-cu124-x86_64-linux`
- `torch26-cxx11-cu118-x86_64-linux`
- `torch26-cxx11-cu124-x86_64-linux`
- `torch26-cxx11-cu126-x86_64-linux`
- `torch26-cxx98-cu118-x86_64-linux`
- `torch26-cxx98-cu124-x86_64-linux`
- `torch26-cxx98-cu126-x86_64-linux`
This list will be updated as new PyTorch versions are released. Kernels
that are in pure Python (e.g. Triton kernels) only need to provide a
single build variant:
- `torch-universal`
Each variant directory should contain a single directory with the same name
Each variant directory must contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have directories like
`build/<variant>/activation`. This directory
must be a Python package with an `__init__.py` file.
## Build variants
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
must be available for that combination.
## Versioning
Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve the version constraints.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- Compatible with glibc 2.27 or later. This means that no symbols
from later versions must be used. To archive this, the module should
be built against this glibc version. **Warning:** libgcc must also be
built against glibc 2.27 to avoid leaking symbols.
- No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
must be static.
- No dynamic library dependencies outside Torch or CUDA libraries
installed as dependencies of Torch.
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
This means that the extension **must not** use symbol versions higher than:
(These requirements will be updated as new PyTorch versions are released.)
- GLIBC 2.28
- GLIBCXX 3.4.24
- CXXABI 1.3.11
- GCC 7.0.0
These requirements can be checked with the ABI checker (see below).
### macOS
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).
The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
```bash
$ cargo install kernel-abi-check
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
✅ No compatibility issues found
```
## Torch extension
@@ -98,10 +121,20 @@ requirements:
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
There are two exceptions to the _no class variables rule_:
1. The `has_backward` variable can be used to indicate whether the layer has
a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
supports `torch.compile` (`False` when absent).
This is an example of a pure layer:
```python
class SiluAndMul(nn.Module):
# This layer does not implement backward.
has_backward: bool = False
def forward(self, x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)

@@ -23,33 +23,93 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
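
For instance, a sketch mirroring the `SiluAndMul` example above:

```python
import torch
import torch.nn.functional as F
from torch import nn

from kernels import use_kernel_forward_from_hub

# The decorator only registers the name "SiluAndMul" for kernel lookup;
# the layer behaves as before until the model is kernelized.
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```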
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, you need to add this to the mode. You can use the
set union (`|`) operator to add `TORCH_COMPILE` to the mode:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
@@ -57,11 +117,14 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
@@ -72,8 +135,63 @@ used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
...
model = kernelize(model)
```
This ensures that the mapping is no longer active outside the
`with` scope.
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
Kernels are matched exactly on the mode. In the example above, for instance,
no kernel layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. To use a kernel
when the mode does not match any mode in the mapping, register it under the
special `Mode.DEFAULT` mode. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.

@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
"kernels-community/activation" = ">=0.0.1"
```
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:
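
A minimal sketch, assuming the `kernels-community/activation` lock entry
from the example above:

```python
from kernels import get_locked_kernel

# Resolves to the revision pinned in kernels.lock.
activation = get_locked_kernel("kernels-community/activation")
```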
@@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
## Pre-downloading locked kernels
Locked kernels can be pre-downloaded by running `kernel download .` in your
Locked kernels can be pre-downloaded by running `kernels download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.

flake.lock (generated, 55 lines changed)

@@ -51,18 +51,38 @@
"type": "github"
}
},
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
@@ -70,11 +90,11 @@
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"tgi-nix",
"hf-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
]
}
},
"systems": {
@@ -106,27 +126,6 @@
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",

@@ -1,7 +1,7 @@
{
inputs = {
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
@@ -9,21 +9,21 @@
self,
nixpkgs,
flake-utils,
tgi-nix,
hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
config = hf-nix.lib.config system;
overlays = [
tgi-nix.overlays.default
hf-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-rfc-style;
formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec {
default = mkShell {
buildInputs =
@@ -34,10 +34,13 @@
ruff
]
++ (with python3.pkgs; [
docutils
huggingface-hub
pytest
pytest-benchmark
pyyaml
torch
types-pyyaml
venvShellHook
]);

@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.2.1"
version = "0.7.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -8,13 +8,14 @@ authors = [
{ name = "David Holtz", email = "david@huggingface.co" },
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
]
license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">= 3.9"
dependencies = [
"huggingface-hub>=0.26.3",
"packaging>=24.2",
"tomli>=2.0.1; python_version<'3.11'",
"torch>=2.5",
"huggingface_hub>=0.26.0,<1.0",
"packaging>=20.0",
"pyyaml>=6",
"tomli>=2.0; python_version<'3.11'",
]
[build-system]
@@ -23,12 +24,17 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mypy == 1.14.1",
"mypy >= 1.15.0",
"pytest >=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",
"torch >=2.5",
"types-pyyaml"
]
[project.optional-dependencies]
torch = ["torch"]
[project.scripts]
kernels = "kernels.cli:main"

pytest.ini (new file, 4 lines)

@@ -0,0 +1,4 @@
[pytest]
markers =
darwin_only: marks tests that should only run on macOS
linux_only: marks tests that should only run on Linux
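
A sketch of how such a marker might be applied in the test suite (the
test name is hypothetical):

```python
import pytest

# Selection happens on the CI side, e.g. `pytest -m "not linux_only"`
# on macOS, or via a conftest hook.
@pytest.mark.linux_only
def test_linux_specific_behavior():
    ...
```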

@@ -1,23 +1,35 @@
from kernels.layer import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
use_kernel_mapping,
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
load_kernel,
)
__all__ = [
"get_kernel",
"get_locked_kernel",
"load_kernel",
"install_kernel",
"use_kernel_forward_from_hub",
"register_kernel_mapping",
"LayerRepository",
"Device",
"LayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
"install_kernel",
"kernelize",
"load_kernel",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
]

@@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
import re
# Re pattern to catch things inside ` ` in :obj:`thing`.
_re_obj = re.compile(r":obj:`([^`]+)`")
# Re pattern to catch things inside ` ` in :math:`thing`.
_re_math = re.compile(r":math:`([^`]+)`")
# Re pattern to catch things between single backquotes.
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
def convert_rst_formatting(text):
"""
Convert rst syntax for formatting to markdown in a given text.
"""
# Remove :class:, :func: and :meth: markers, converting them to code links in double backquotes
# (to not be caught by the italic conversion).
text = _re_func_class.sub(r"[``\1``]", text)
# Remove :obj: markers. What's inside is in single backquotes, so we put it in double backquotes
# (to not be caught by the italic conversion).
text = _re_obj.sub(r"``\1``", text)
# Remove :math: markers.
text = _re_math.sub(r"\\\\(\1\\\\)", text)
# Convert content in single backquotes to italic.
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
# Convert content in double backquotes to single backquotes.
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
# Remove remaining ::
text = re.sub(r"::\n", "", text)
# Remove new lines inside blocks in backticks as they will be kept.
lines = text.split("\n")
in_code = False
text = None
for line in lines:
if in_code:
splits = line.split("`")
in_code = len(splits) > 1 and len(splits) % 2 == 1
if len(splits) == 1:
# Some forgotten lone backtick
text += "\n" + line
else:
text += " " + line.lstrip()
else:
if text is not None:
text += "\n" + line
else:
text = line
splits = line.split("`")
in_code = len(splits) % 2 == 0
return text
# Re pattern to catch description and url in links of the form `description <url>`_.
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :doc:`reference`.
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :ref:`reference`.
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
def convert_rst_links(text, page_info):
"""
Convert the rst links in text to markdown.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
no_prefix = page_info.get("no_prefix", False)
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
# Links of the form :doc:`page`
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
# Links of the form :doc:`text <page>`
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
if "page" in page_info and not no_prefix:
page = str(page_info["page"])
if page.endswith(".html"):
page = page[:-5]
prefix = f"{prefix}{page}"
else:
prefix = ""
# Refs of the form :ref:`page`
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
# Refs of the form :ref:`text <page>`
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
# Links with a prefix
# TODO: when it exists, use the API to deal with prefix links properly.
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
# Other links
text = _re_links.sub(r"[\1](\2)", text)
# Relative links or Transformers links need to remove the .html
if (
"(https://https://huggingface.co/" in text
or re.search(r"\(\.+/", text) is not None
):
text = text.replace(".html", "")
return text
# Re pattern that catches examples blocks of the form `Example::`.
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
# Re pattern that catches rst blocks of the form `.. block_name::`.
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
def is_empty_line(line):
return len(line) == 0 or line.isspace()
def find_indent(line):
"""
Returns the number of spaces that start a line indent.
"""
search = re.search(r"^(\s*)(?:\S|$)", line)
if search is None:
return 0
return len(search.groups()[0])
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
def convert_special_chars(text):
"""
Converts { and < that have special meanings in MDX.
"""
text = text.replace("{", "&amp;lcub;")
# We don't want to replace those by the HTML code, so we temporarily mark them as LTHTML
text = re.sub(
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
) # html void elements with no closing counterpart
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
while _re_lt_html.search(text):
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&amp;lt;\2", text)
text = text.replace("LTHTML", "<")
return text
def parse_options(block_content):
"""
Parses the option in some rst block content.
"""
block_lines = block_content.split("\n")
block_indent = find_indent(block_lines[0])
current_option = None
result = {}
for line in block_lines:
if _re_rst_option.search(line) is not None:
current_option, value = _re_rst_option.search(line).groups()
result[current_option] = value.lstrip()
elif find_indent(line) > block_indent:
result[current_option] += " " + line.lstrip()
return result
def apply_min_indent(text, min_indent):
"""
Make sure all lines in a text have a minimum indentation.
Args:
text (`str`): The text to treat.
min_indent (`int`): The minimal indentation.
Returns:
`str`: The processed text.
"""
lines = text.split("\n")
idx = 0
while idx < len(lines):
if is_empty_line(lines[idx]):
idx += 1
continue
indent = find_indent(lines[idx])
if indent < min_indent:
while idx < len(lines) and (
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
):
if not is_empty_line(lines[idx]):
lines[idx] = " " * (min_indent - indent) + lines[idx]
idx += 1
else:
idx += 1
return "\n".join(lines)
def convert_rst_blocks(text, page_info):
"""
Converts rst special blocks (examples, notes) into MDX.
"""
if "package_name" not in page_info:
raise ValueError("`page_info` must contain at least the package_name.")
package_name = page_info["package_name"]
version = page_info.get("version", "main")
language = page_info.get("language", "en")
lines = text.split("\n")
idx = 0
new_lines = []
while idx < len(lines):
block_type = None
block_info = None
if _re_block.search(lines[idx]) is not None:
block_type = _re_block.search(lines[idx]).groups()[0]
if _re_block_info.search(lines[idx]) is not None:
block_info = _re_block_info.search(lines[idx]).groups()[0]
elif _re_example.search(lines[idx]) is not None:
block_type = "code-block-example"
block_info = "python"
example_name = _re_example.search(lines[idx]).groups()[0]
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
elif lines[idx].strip() == "..":
block_type = "comment"
elif lines[idx].strip() == "::":
block_type = "code-block"
if block_type is not None:
block_indent = find_indent(lines[idx])
# Find the next nonempty line
idx += 1
while idx < len(lines) and is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the first nonempty line; this block will stop when we unindent under it (or already have)
example_indent = (
find_indent(lines[idx]) if idx < len(lines) else block_indent
)
if example_indent == block_indent:
block_content = ""
else:
block_lines = []
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= example_indent
):
block_lines.append(lines[idx][example_indent:])
idx += 1
block_content = "\n".join(block_lines)
if block_type in ["code", "code-block"]:
prefix = "```" if block_info is None else f"```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
elif block_type == "code-block-example":
prefix = f"<example>```{block_info}"
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
elif block_type == "note":
new_lines.append(
apply_min_indent(
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
)
)
elif block_type == "warning":
new_lines.append(
apply_min_indent(
"<Tip warning={true}>\n\n"
+ f"{block_content.strip()}\n\n</Tip>\n",
block_indent,
)
)
elif block_type == "raw":
new_lines.append(block_content.strip() + "\n")
elif block_type == "math":
new_lines.append(f"$${block_content.strip()}$$\n")
elif block_type == "comment":
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
elif block_type == "autofunction":
if block_info is not None:
new_lines.append(f"[[autodoc]] {block_info}\n")
elif block_type == "autoclass":
if block_info is not None:
block = f"[[autodoc]] {block_info}\n"
options = parse_options(block_content)
if "special-members" in options:
special_members = options["special-members"].split(", ")
for special_member in special_members:
block += f" - {special_member}\n"
if "members" in options:
members = options["members"]
if len(members) == 0:
block += " - all\n"
else:
for member in members.split(", "):
block += f" - {member}\n"
new_lines.append(block)
elif block_type == "image":
options = parse_options(block_content)
target = options.pop("target", None)
if block_info is not None:
options["src"] = block_info
else:
if target is None:
raise ValueError("Image source not defined.")
options["src"] = target
# Adapt path
options["src"] = options["src"].replace(
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
)
html_code = " ".join(
[f'{key}="{value}"' for key, value in options.items()]
)
new_lines.append(f"<img {html_code}/>\n")
else:
new_lines.append(
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
)
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
def split_return_line(line):
"""
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
return line, ""
return None, line
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
def split_raise_line(line):
"""
Split the raise line with format `SomeError some doc`.
"""
splits_on_colon = line.strip().split(" ")
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
if error_type and error_type[-1] == ":":
error_type = error_type[:-1]
return error_type, doc
def split_arg_line(line):
"""
Split the arg line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
"""
splits_on_colon = line.split(":")
idx = 1
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
idx += 2
if idx >= len(splits_on_colon):
return line, ""
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
class InvalidRstDocstringError(ValueError):
pass
_re_parameters = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
def parse_rst_docstring(docstring):
"""
Parses a docstring written in rst, in particular the list of arguments and the return type.
"""
lines = docstring.split("\n")
idx = 0
while idx < len(lines):
# Parameters section
if _re_args.search(lines[idx]) is not None:
# Title of the section.
lines[idx] = "<parameters>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
# Returns or Raises block.
param_indent = find_indent(lines[idx])
while (
idx < len(lines)
and find_indent(lines[idx]) == param_indent
and _re_returns.search(lines[idx]) is None
):
intro, doc = split_arg_line(lines[idx])
# Line starting with a > after indent indicate a "section title" in the parameters.
if intro.lstrip().startswith(">"):
lines[idx] = intro.lstrip()
else:
lines[idx] = (
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
):
idx += 1
lines.insert(idx, "</parameters>\n")
idx += 1
# Returns section
elif _re_returns.search(lines[idx]) is not None:
# tag is either `return` or `yield`
tag = _re_returns.match(lines[idx]).group(1).lower()
# Title of the section.
lines[idx] = f"<{tag}s>\n"
# Find the next nonempty line
idx += 1
while is_empty_line(lines[idx]):
idx += 1
# Grab the indent of the return line, this block will stop when we unindent under it.
return_indent = find_indent(lines[idx])
raised_errors = []
# The line may contain the return type.
if tag in ["return", "yield"]:
return_type, return_description = split_return_line(lines[idx])
lines[idx] = return_description
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) >= return_indent
):
idx += 1
else:
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
return_type, return_description = split_raise_line(lines[idx])
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
lines[idx] = "- " + raised_error + " -- " + return_description
md_link = _re_md_link.match(raised_error)
if md_link:
raised_error = md_link[1]
raised_error = re.sub(
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
)
if raised_error not in raised_errors:
raised_errors.append(raised_error)
idx += 1
while idx < len(lines) and (
is_empty_line(lines[idx])
or find_indent(lines[idx]) > return_indent
):
idx += 1
lines.insert(idx, f"</{tag}s>\n")
idx += 1
# Return block finished, we insert the return type if one was specified
if tag in ["return", "yield"] and return_type is not None:
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
elif len(raised_errors) > 0:
# raised errors
lines[
idx - 1
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
else:
idx += 1
result = "\n".join(lines)
# combine multiple <parameters> blocks into one block
if result.count("<parameters>") > 1:
parameters_blocks = _re_parameters.findall(result)
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
parameters_str = "\n".join(parameters_blocks)
result = _re_parameters.sub("", result)
result += f"\n<parameters>{parameters_str}</parameters>\n"
return result
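For orientation, a hedged sketch of the transformation (the `_re_args`/`_re_returns` patterns are defined earlier in this vendored module and match the `Args:`/`Returns:` headers; output is schematic, not byte-exact):

```python
docstring = """Do something.

Args:
    x (`int`): A number.

Returns:
    `int`: The same number.
"""
# parse_rst_docstring(docstring) wraps the argument list in
# <parameters>...</parameters>, rewriting each entry to
# "- **x** (`int`) -- A number.", wraps the return description in
# <returns>...</returns>, and appends <returntype>`int`</returntype>.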
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
def remove_indent(text):
"""
    Remove indentation in the text, except indentation that belongs to lists (or sublists).
"""
lines = text.split("\n")
# List of indents to remember for nested lists
current_indents = []
# List of new indents to remember for nested lists
new_indents = []
is_inside_code = False
code_indent = 0
for idx, line in enumerate(lines):
# Line is an item in a list.
if _re_list.search(line) is not None:
indent = find_indent(line)
            # Is it a new list / a new level of nesting?
if len(current_indents) == 0 or indent > current_indents[-1]:
current_indents.append(indent)
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
# Otherwise it's an existing level of list (current one, or previous one)
else:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] != indent:
level -= 1
current_indents = current_indents[: level + 1]
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indent += len(_re_list.search(line).groups()[0]) + 1
new_indents.append(new_indent)
        # Line is an autodoc; keep the indent for a list that may directly follow.
elif _re_autodoc.search(line) is not None:
indent = find_indent(line)
current_indents = [indent]
new_indents = [4]
lines[idx] = line.strip()
# Deal with empty lines separately
elif is_empty_line(line):
lines[idx] = ""
# Code blocks
elif line.lstrip().startswith("```"):
is_inside_code = not is_inside_code
if is_inside_code:
code_indent = find_indent(line)
lines[idx] = line[code_indent:]
elif is_inside_code:
lines[idx] = line[code_indent:]
else:
indent = find_indent(line)
if len(current_indents) > 0 and indent > current_indents[-1]:
lines[idx] = " " * new_indents[-1] + line[indent:]
elif len(current_indents) > 0:
# Let's find the proper level of indentation
level = len(current_indents) - 1
while level >= 0 and current_indents[level] > indent:
level -= 1
current_indents = current_indents[: level + 1]
if level >= 0:
if current_indents[level] < indent:
new_indents = new_indents[: level + 1]
else:
new_indents = new_indents[:level]
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
lines[idx] = " " * new_indent + line[indent:]
new_indents.append(new_indent)
else:
new_indents = []
lines[idx] = line[indent:]
else:
lines[idx] = line[indent:]
return "\n".join(lines)
def base_rst_to_mdx(text, page_info, unindent=True):
"""
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
"""
text = convert_rst_links(text, page_info)
text = convert_special_chars(text)
text = convert_rst_blocks(text, page_info)
    # Convert * in lists to - to prevent the formatting conversion from treating them as bold.
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
text = convert_rst_formatting(text)
return remove_indent(text) if unindent else text
def convert_rst_docstring_to_mdx(docstring, page_info):
"""
Convert a docstring written in rst to mdx.
"""
text = parse_rst_docstring(docstring)
return base_rst_to_mdx(text, page_info)
def process_titles(lines):
"""Converts rst titles to markdown titles."""
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
title_levels = {}
new_lines = []
for line in lines:
if (
len(new_lines) > 0
and len(line) >= len(new_lines[-1])
and len(set(line)) == 1
and line[0] in title_chars
and line != "::"
):
char = line[0]
level = title_levels.get(char, len(title_levels) + 1)
if level not in title_levels:
title_levels[char] = level
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
else:
new_lines.append(line)
return new_lines
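A hedged example (the heading level follows the order in which underline characters first appear):

```python
from kernels._vendored.convert_rst_to_mdx import process_titles

lines = ["Title", "=====", "", "Section", "-------"]
assert process_titles(lines) == ["# Title", "", "## Section"]
```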
# Matches row-separator lines of an rst table.
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
# Matches row-separator lines of an rst table whose first column is empty.
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
# Matches the header-separator line of an rst table.
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
# Regex pattern that catches anchors of the form `.. _reference:`
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
def split_pt_tf_code_blocks(text):
"""
    Split code blocks into PyTorch-specific and TensorFlow-specific versions.
"""
lines = text.split("\n")
new_lines = []
idx = 0
while idx < len(lines):
if lines[idx].startswith("```"):
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
is_pytorch = False
is_tensorflow = False
idx += 1
while idx < len(lines) and lines[idx].strip() != "```":
if "## PYTORCH CODE" in lines[idx]:
is_pytorch = True
is_tensorflow = False
elif "## TENSORFLOW CODE" in lines[idx]:
is_tensorflow = True
is_pytorch = False
elif is_pytorch:
code_lines["pytorch"].append(lines[idx])
elif is_tensorflow:
code_lines["tensorflow"].append(lines[idx])
else:
code_lines["common"].append(lines[idx])
idx += 1
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
block_lines = ["<frameworkcontent>", "<pt>"]
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
block_lines.extend(["```", "</pt>", "<tf>"])
block_lines.extend(
code_lines["common"].copy() + code_lines["tensorflow"]
)
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
new_lines.extend(block_lines)
else:
block_lines = code_lines["common"] + ["```"]
new_lines.extend(block_lines)
idx += 1
else:
new_lines.append(lines[idx])
idx += 1
return "\n".join(new_lines)
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
"""
Convert a document written in rst to mdx.
"""
lines = rst_text.split("\n")
lines = process_titles(lines)
if add_imports:
new_lines = [
'<script lang="ts">',
' import Tip from "$lib/Tip.svelte";',
' import Youtube from "$lib/Youtube.svelte";',
' import Docstring from "$lib/Docstring.svelte";',
' import CodeBlock from "$lib/CodeBlock.svelte";',
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
' import Markdown from "$lib/Markdown.svelte";',
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
' import Added from "$lib/Added.svelte";',
' import Changed from "$lib/Changed.svelte";',
' import Deprecated from "$lib/Deprecated.svelte";',
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
' import PipelineTag from "$lib/PipelineTag.svelte";',
" ",
' export let fw: "pt" | "tf"',
"</script>",
"<svelte:head>",
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
"</svelte:head>",
"",
]
else:
new_lines = []
for line in lines:
if _re_ignore_line_table.search(line) is not None:
continue
elif _re_ignore_line_table1.search(line) is not None:
continue
elif _re_sep_line_table.search(line) is not None:
line = line.replace("=", "-").replace("+", "|")
elif _re_anchor_section.search(line) is not None:
anchor_name = _re_anchor_section.search(line).groups()[0]
line = f"<a id='{anchor_name}'></a>"
new_lines.append(line)
text = "\n".join(new_lines)
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))
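A small, hedged illustration of the table handling above: header separators are rewritten in place, while plain row separators (`+-----+-----+`) are dropped entirely, since markdown tables only mark the header row.

```python
# Header-separator lines are converted character by character:
assert "+=====+".replace("=", "-").replace("+", "|") == "|-----|"
```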

View File

@ -8,6 +8,9 @@ from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
from .doc import generate_readme_for_kernel
from .wheel import build_variant_to_wheel
def main():
parser = argparse.ArgumentParser(
@ -36,6 +39,47 @@ def main():
)
lock_parser.set_defaults(func=lock_kernels)
to_wheel_parser = subparsers.add_parser(
"to-wheel", help="Convert a kernel to a wheel file"
)
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
to_wheel_parser.add_argument(
"--python-version",
type=str,
default="3.9",
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
)
to_wheel_parser.add_argument(
"--manylinux-version",
type=str,
default="2.28",
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
)
to_wheel_parser.set_defaults(func=kernels_to_wheel)
# Add generate-readme subcommand parser
generate_readme_parser = subparsers.add_parser(
"generate-readme",
help="Generate README snippets for a kernel's public functions",
)
generate_readme_parser.add_argument(
"repo_id",
type=str,
help="The kernel repo ID (e.g., kernels-community/activation)",
)
generate_readme_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
generate_readme_parser.set_defaults(
func=lambda args: generate_readme_for_kernel(
repo_id=args.repo_id, revision=args.revision
)
)
args = parser.parse_args()
args.func(args)
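For reference, a hedged sketch of how the new subcommands are invoked (the version `0.0.1` is an illustrative placeholder):

```python
# CLI sketch (invocations shown as comments):
#
#   kernels to-wheel kernels-community/activation 0.0.1
#   kernels generate-readme kernels-community/activation --revision main
#
# `to-wheel` downloads every build variant at tag v0.0.1 and emits one wheel
# per variant; `generate-readme` prints README snippets for the kernel's
# public functions and layers to stdout.
```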
@ -77,6 +121,24 @@ def download_kernels(args):
sys.exit(1)
def kernels_to_wheel(args):
variants_path = install_kernel_all_variants(
repo_id=args.repo_id, revision=f"v{args.version}"
)
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
continue
wheel_path = build_variant_to_wheel(
manylinux_version=args.manylinux_version,
python_version=args.python_version,
repo_id=args.repo_id,
version=args.version,
variant_path=variant_path,
wheel_dir=Path("."),
)
print(f"☸️ {wheel_path.name}", file=sys.stderr)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)

src/kernels/doc.py Normal file
View File

@ -0,0 +1,242 @@
import inspect
import re
import sys
from types import ModuleType
import yaml
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
from .utils import get_kernel
_RE_PARAMETERS = re.compile(
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
_RE_RETURNTYPE = re.compile(
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
else:
if "kernel" not in metadata["tags"]:
metadata["tags"].append("kernel")
print("---")
print(yaml.dump(metadata), end="")
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
print(f"# {title.strip()}")
print(
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
if not found_functions:
print("\nNo public top-level functions.")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
print("\n## Layers")
# Track if we found any classes
found_classes = False
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")

View File

@ -1,15 +1,67 @@
from __future__ import annotations
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Dict, Union
from enum import Flag, auto
from types import MethodType
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Tuple,
Type,
Union,
)
from .utils import get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
class Mode(Flag):
"""
Kernelize mode
The `Mode` flag is used by `kernelize` to select kernels for the given
mode. Mappings can be registered for specific modes.
* `INFERENCE`: The kernel is used for inference.
* `TRAINING`: The kernel is used for training.
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
matches.
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
should be used for layers that are used for inference *with* `torch.compile`.
"""
_NONE = 0
DEFAULT = auto()
TRAINING = auto()
INFERENCE = auto()
TORCH_COMPILE = auto()
def __or__(self, other: Mode) -> Mode:
union = super().__or__(other)
if Mode.INFERENCE in union and Mode.TRAINING in union:
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
if Mode.DEFAULT in union and union != Mode.DEFAULT:
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
return union
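A hedged sketch of the composition rules enforced above:

```python
from kernels import Mode

mode = Mode.INFERENCE | Mode.TORCH_COMPILE  # valid: inference with torch.compile

# Rejected by Mode.__or__:
#   Mode.INFERENCE | Mode.TRAINING    -> ValueError (mutually exclusive)
#   Mode.DEFAULT | Mode.TORCH_COMPILE -> ValueError (DEFAULT stands alone)
```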
@dataclass(frozen=True)
class Device:
type: str
@ -49,16 +101,37 @@ class LayerRepository:
return hash((self.layer_name, self.repo_id, self.revision))
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
ContextVar("_KERNEL_MAPPING", default={})
)
def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
def use_kernel_mapping(
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
*,
inherit_mapping: bool = True,
):
"""
    Context manager that sets a mapping for the duration of the context.
    When `inherit_mapping` is set to `True`, the current mapping will be
    extended by `mapping` inside the context. If it is `False`, only
    `mapping` is used inside the context.
"""
class ContextManager:
def __enter__(self):
# Mappings always stack on previous mappings.
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
if inherit_mapping:
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
else:
self.token = _KERNEL_MAPPING.set({})
register_kernel_mapping(mapping)
def __exit__(self, exc_type, exc_value, traceback):
@ -68,12 +141,17 @@ def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerReposito
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
):
"""
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
Exemple usage:
Allows one to register a mapping between a layer name and the corresponding
kernel(s) to use, depending on the device. This should be used in conjunction
with `kernelize`.
Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
@ -94,90 +172,179 @@ def register_kernel_mapping(
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
device = (
Device(type=new_device) if isinstance(new_device, str) else new_device
)
if isinstance(new_repo, LayerRepository):
kernel_options = {Mode.DEFAULT: new_repo}
else:
device_repo[new_device] = new_repo
kernel_options = new_repo
device_repo[device] = kernel_options
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
Args:
cls: The layer class to decorate
layer_name: The name of the layer to use for kernel lookup
"""
cls.kernel_layer_name = layer_name
fallback_forward = cls.forward
cached_forward: Dict[LayerRepository, Callable] = {}
def _select_repository(
repositories: Dict[Mode, LayerRepository],
*,
mode: Mode,
) -> Optional[Tuple[LayerRepository, Mode]]:
if mode in repositories:
return (repositories[mode], mode)
elif Mode.DEFAULT in repositories:
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
else:
return None
def kernelize(
model: "nn.Module",
*,
mode: Mode,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
        mode: The mode that the kernel is going to be used in (e.g.
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
and `torch.compile`).
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
"""
import torch
if mode == Mode.DEFAULT:
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
# Type check ignored because this causes a false negative on Python < 3.11.
# Looks similar to: https://github.com/python/mypy/issues/9642
# Remove once we start doing typing checks on >= 3.11.
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
if _DISABLE_KERNEL_MAPPING:
_replace_forward(module, module_class)
continue
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
def forward(self, x, **args):
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
warnings.warn(
"\n"
f"No kernel mapping found for layer `{layer_name}`. "
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
f"you want to use to the mapping. Defaulting to original forward implementation."
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, **args)
_replace_forward(module, module_class)
continue
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, **args)
# Get kernel options for the device
repos = kernel.get(device_type)
repo = kernel.get(Device(type=device.type))
if repo is None:
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
return fallback_forward(self, x, **args)
_replace_forward(module, module_class)
continue
# Short-circuit if we already loaded the layer.
layer_forward = cached_forward.get(repo, None)
if layer_forward is not None:
return layer_forward(self, x, **args)
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
repo_with_mode = _select_repository(
repos,
mode=mode,
)
# We have to validate against the original signature.
orig_forward = cls.forward
try:
cls.forward = fallback_forward
_validate_layer(check_cls=cls, cls=layer)
finally:
cls.forward = orig_forward
if repo_with_mode is None:
if not use_fallback:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
continue
layer_forward = layer.forward
cached_forward[repo] = layer_forward
repo, repo_mode = repo_with_mode
return layer_forward(self, x, **args)
layer = _get_layer_memoize(repo, module_class)
cls.forward = forward
# Ideally we would do validation on the mapping where we check that
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
# the actual layer is compatible with that. Unfortunately, this would
# mean that we have to pre-download everything.
_validate_layer_has_mode(
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
)
_conditionally_replace_forward(
module=module,
layer=layer,
mode=mode,
use_fallback=use_fallback,
)
return model
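A hedged usage sketch (`MyModel` is a hypothetical placeholder for any module tree containing extensible layers):

```python
from kernels import Mode, kernelize

model = MyModel().to("cuda")  # MyModel is hypothetical
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```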
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
def use_kernel_forward_from_hub(layer_name: str):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Make a layer extensible using the name `layer_name`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
@ -194,13 +361,13 @@ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Mo
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
# The layer must have at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
from torch import nn
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
@ -212,7 +379,9 @@ def _validate_layer(*, check_cls, cls):
# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
if cls_members - torch_module_members != set():
difference = cls_members - torch_module_members
    # Allow only the `can_torch_compile` and `has_backward` markers as extra members.
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
@ -229,3 +398,92 @@ def _validate_layer(*, check_cls, cls):
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)
def _find_device(model: "nn.Module") -> Device:
try:
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)
return Device(type=param.device.type)
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback if the mode is not supported by the layer.
# Note that this is useful even after _validate_layer_has_mode because
# layers registered with the DEFAULT mode never get rejected by
# _validate_layer_has_mode. For such layers, we want to fall back in
# case the layer does not support the given mode.
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
layer, "can_torch_compile", False
)
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
def _validate_layer_has_mode(
*,
layer_name: str,
module: Type["nn.Module"],
repo: LayerRepository,
repo_mode: Mode,
):
"""
Check that a repository supports the mode that it was registered for.
"""
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
if Mode.TORCH_COMPILE in repo_mode and not getattr(
module, "can_torch_compile", False
):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
return True
def _get_layer_memoize(
repo: LayerRepository, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
return layer
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
return layer

View File

@ -4,6 +4,7 @@ import importlib
import importlib.metadata
import inspect
import json
import logging
import os
import platform
import sys
@ -12,29 +13,54 @@ from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from huggingface_hub import snapshot_download
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels.lockfile import KernelLock, VariantLock
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
def _get_cache_dir() -> Optional[str]:
"""Returns the kernels cache directory."""
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
if cache_dir is not None:
logging.warning(
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
)
return cache_dir
return os.environ.get("KERNELS_CACHE", None)
CACHE_DIR: Optional[str] = _get_cache_dir()
def build_variant() -> str:
import torch
if torch.version.cuda is None:
if torch.version.cuda is not None:
cuda_version = parse(torch.version.cuda)
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
elif torch.version.hip is not None:
rocm_version = parse(torch.version.hip.split("-")[0])
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
else:
raise AssertionError(
"This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
)
torch_version = parse(torch.__version__)
cuda_version = parse(torch.version.cuda)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
cpu = platform.machine()
os = platform.system().lower()
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
def universal_build_variant() -> str:
@ -84,6 +110,23 @@ def install_kernel(
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
        # Re-raise with a more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
@ -102,7 +145,7 @@ def install_kernel(
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
return package_name, variant_path
@ -140,10 +183,47 @@ def install_kernel_all_variants(
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Download and import a kernel from the Hugging Face Hub.
The kernel is downloaded from the repository `repo_id` at
branch/commit/tag `revision`.
"""
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
"""
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment
(Torch version and compute framework).
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
if file_exists(
repo_id,
revision=revision,
filename=f"build/{universal_variant}/{package_name}/__init__.py",
):
return True
return file_exists(
repo_id,
revision=revision,
filename=f"build/{variant}/{package_name}/__init__.py",
)
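A hedged sketch of guarding kernel usage on availability:

```python
from kernels import get_kernel, has_kernel

if has_kernel("kernels-community/activation"):
    activation = get_kernel("kernels-community/activation")
```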
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
"""
Get a pre-downloaded, locked kernel.

src/kernels/wheel.py Normal file
View File

@ -0,0 +1,186 @@
import email.policy
import os
from dataclasses import dataclass
from email.message import Message
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Optional
try:
KERNELS_VERSION = version("kernels")
except PackageNotFoundError:
KERNELS_VERSION = "unknown"
@dataclass
class Metadata:
name: str
version: str
cuda_version: Optional[str]
cxx_abi_version: Optional[str]
torch_version: Optional[str]
os: Optional[str]
platform: Optional[str]
@property
def is_universal(self) -> bool:
return self.platform is None
def build_variant_to_wheel(
repo_id: str,
*,
version: str,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Create a wheel file from the variant path.
"""
name = repo_id.split("/")[-1].replace("_", "-")
metadata = extract_metadata(name, version, variant_path)
return build_wheel(
metadata,
variant_path=variant_path,
wheel_dir=wheel_dir,
manylinux_version=manylinux_version,
python_version=python_version,
)
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
"""
Extract metadata from the variant path.
"""
if variant_path.name == "torch-universal":
return Metadata(
name=name,
version=version,
cuda_version=None,
cxx_abi_version=None,
torch_version=None,
os=None,
platform=None,
)
if not variant_path.name.startswith("torch"):
raise ValueError("Currently only conversion of Torch kernels is supported.")
variant_parts = variant_path.name.removeprefix("torch").split("-")
if len(variant_parts) != 5:
raise ValueError(f"Invalid variant name: {variant_path.name}")
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
cpp_abi_version = variant_parts[1].removeprefix("cxx")
cuda_version = variant_parts[2].removeprefix("cu")
platform = variant_parts[3].replace("-", "_")
os = variant_parts[4]
return Metadata(
name=name,
version=version,
cuda_version=cuda_version,
cxx_abi_version=cpp_abi_version,
torch_version=torch_version,
os=os,
platform=platform,
)
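A hedged example of the parsing above (version `0.0.1` is illustrative):

```python
from pathlib import Path

meta = extract_metadata(
    "activation", "0.0.1", Path("build/torch26-cxx11-cu124-x86_64-linux")
)
# -> Metadata(name="activation", version="0.0.1", cuda_version="124",
#             cxx_abi_version="11", torch_version="2.6", os="linux",
#             platform="x86_64")
```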
def build_wheel(
metadata: Metadata,
*,
variant_path: Path,
wheel_dir: Path,
manylinux_version: str = "2.28",
python_version: str = "3.9",
) -> Path:
"""
Build the wheel file.
"""
try:
from wheel.wheelfile import WheelFile # type: ignore
except ImportError:
raise ImportError(
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
)
name = metadata.name.replace("-", "_")
python_version_flat = python_version.replace(".", "")
if metadata.is_universal:
python_tag = f"py{python_version_flat}"
abi_tag = "none"
platform_tag = "any"
wheel_filename = (
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
)
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
root_is_purelib = "true"
requires_dist_torch = "torch"
else:
python_tag = f"cp{python_version_flat}"
abi_tag = "abi3"
if (
metadata.torch_version is None
or metadata.cuda_version is None
or metadata.cxx_abi_version is None
or metadata.os is None
or metadata.platform is None
):
raise ValueError(
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
)
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
if metadata.os == "linux":
platform_tag = (
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
)
else:
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
root_is_purelib = "false"
requires_dist_torch = f"torch=={metadata.torch_version}.*"
wheel_path = wheel_dir / wheel_filename
wheel_msg = Message(email.policy.compat32)
wheel_msg.add_header("Wheel-Version", "1.0")
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
metadata_msg = Message(email.policy.compat32)
metadata_msg.add_header("Metadata-Version", "2.1")
metadata_msg.add_header("Name", name)
metadata_msg.add_header("Version", metadata.version)
metadata_msg.add_header("Summary", f"{name} kernel")
metadata_msg.add_header("Requires-Python", ">=3.9")
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
source_pkg_dir = variant_path / name
with WheelFile(wheel_path, "w") as wheel_file:
for root, dirnames, filenames in os.walk(source_pkg_dir):
for filename in filenames:
if filename.endswith(".pyc"):
continue
abs_filepath = os.path.join(root, filename)
entry_name = os.path.relpath(abs_filepath, variant_path)
wheel_file.write(abs_filepath, entry_name)
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
return wheel_path
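For the metadata from the example above (illustrative version `0.0.1`), the resulting file names would be:

```python
# Non-universal: activation-0.0.1+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
# Universal:     activation-0.0.1-py39-none-any.whl
```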

tests/conftest.py Normal file
View File

@ -0,0 +1,10 @@
import sys
import pytest
def pytest_runtest_setup(item):
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
pytest.skip("skipping Linux-only test on non-Linux platform")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")

View File

@ -1,54 +1,82 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
"hash_type": "git_lfs_concat"
}
}

View File

@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
@pytest.fixture
@ -9,6 +9,19 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@pytest.fixture
def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@ -21,6 +34,7 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -36,6 +50,48 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
x = torch.arange(-10, 10, dtype=dtype, device="mps")
y = metal_kernel.relu(x)
assert torch.allclose(y, torch.relu(x))
@pytest.mark.linux_only
@pytest.mark.parametrize(
"kernel_exists",
[
("kernels-community/activation", "main", True),
("kernels-community/triton-layer-norm", "main", True),
# Repo only contains Torch 2.4 kernels (and we don't
# support/test against this version).
("kernels-test/only-torch-2.4", "main", False),
("google-bert/bert-base-uncased", "87565a309", False),
],
)
def test_has_kernel(kernel_exists):
repo_id, revision, kernel = kernel_exists
assert has_kernel(repo_id, revision=revision) == kernel
@pytest.mark.linux_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@ -16,18 +16,21 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

View File

@ -1,6 +1,8 @@
from dataclasses import dataclass
from pathlib import Path
import pytest
from kernels import load_kernel
from kernels.cli import download_kernels
@ -17,6 +19,7 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.linux_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.

View File

@ -1,3 +1,5 @@
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
@ -6,6 +8,8 @@ from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
@ -16,14 +20,18 @@ kernel_layer_mapping = {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
"SiluAndMulNoCompile": {
"cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
}
@ -43,6 +51,11 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
class SiluAndMulNoCompileKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
pass
@ -53,6 +66,37 @@ class SiluAndMulStringDevice(SiluAndMul):
pass
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Used to check whether the hub kernel was used: the counter only
        # increments when the original forward runs.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
def forward(
self,
arg1,
arg2,
*,
kwarg1,
kwarg2=42,
):
return (arg1, arg2, kwarg1, kwarg2)
arg_kind = ArgKind()
assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
@ -62,7 +106,7 @@ def test_hub_forward(cls, device):
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
@ -80,11 +124,70 @@ def test_layer_fallback_works():
pass
# Check that we don't raise an exception for a non-existing kernel.
SiluAndMulWithKernelFallback()
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
ctx = (
pytest.raises(ValueError, match="does not support mode")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
with ctx:
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
extra_mapping1 = {
"TestKernel": {
@ -100,6 +203,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
@ -117,26 +221,57 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
@ -166,3 +301,208 @@ def test_validate_kernel_layer():
with pytest.raises(TypeError, match="different kind of arguments"):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
    # Case 1: a layer registered without further specification becomes
    # the base (DEFAULT) layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
        # Training has a kernel, so no fallback.
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)