mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 13:33:48 +08:00
Compare commits
79 Commits
improve-re
...
v0.8.1
Author | SHA1 | Date | |
---|---|---|---|
0429131630 | |||
967ac581b8 | |||
81088d44e8 | |||
4a04c005e3 | |||
6d3c6daf20 | |||
071900fd69 | |||
2d2c6b14e0 | |||
03edc573b1 | |||
c841a6c90d | |||
c7a343f195 | |||
8d838f947d | |||
b87e6fadbe | |||
fc935d9874 | |||
3622e1f8dd | |||
a7f3b2e8ed | |||
a6ab5d83ba | |||
4f9f1abfb9 | |||
f94b7780a6 | |||
bd28883775 | |||
498429e322 | |||
09c991af4b | |||
bcf8df5875 | |||
239afff6f5 | |||
c5ec6b900a | |||
3a635eaeea | |||
32ec496c5a | |||
848c6db87b | |||
fabb8c52d1 | |||
d66260dd83 | |||
daac8078fc | |||
fcb9a80ce6 | |||
c25bb32e6e | |||
2036892762 | |||
0f0de049cf | |||
59597df03e | |||
5e938ede40 | |||
cf530c283a | |||
437f910336 | |||
6f1a6067c8 | |||
1d14abcef0 | |||
6fd2112e22 | |||
70f56ff856 | |||
7178b0b86c | |||
0bbf90a564 | |||
27d6ffcb80 | |||
f7bd21438b | |||
6174febb4b | |||
ff55bc201b | |||
3808108d62 | |||
c4a16ef462 | |||
9762794dd2 | |||
b7d6867c52 | |||
fbcd0f2ebd | |||
5af46eca94 | |||
747dd66876 | |||
920590a592 | |||
5208ac4be5 | |||
22eaba2826 | |||
9521ba79a0 | |||
9861a5bdef | |||
1c7c87c960 | |||
df45cf2795 | |||
cf0413efe5 | |||
851c13f666 | |||
b6a393612f | |||
18ecd0ce69 | |||
b4ef1d60e5 | |||
a40756f306 | |||
3671158f47 | |||
2ddd473cf7 | |||
497dffb89e | |||
f036fd09cb | |||
3e4c83c798 | |||
4116d6019e | |||
bd166b348a | |||
386c2a104e | |||
c7516b9e50 | |||
a8dcd1f6bc | |||
af7fdf9202 |
10
.github/workflows/lint.yml
vendored
Normal file
10
.github/workflows/lint.yml
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
name: Lints
|
||||||
|
on: [push, pull_request]
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Run lints
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Run ruff
|
||||||
|
uses: astral-sh/ruff-action@v3
|
120
.github/workflows/publish.yml
vendored
Normal file
120
.github/workflows/publish.yml
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
||||||
|
|
||||||
|
on: push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build distribution 📦
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.9"
|
||||||
|
- name: Install pypa/build
|
||||||
|
run: >-
|
||||||
|
python3 -m
|
||||||
|
pip install
|
||||||
|
build
|
||||||
|
--user
|
||||||
|
- name: Build a binary wheel and a source tarball
|
||||||
|
run: python3 -m build
|
||||||
|
- name: Store the distribution packages
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: python-package-distributions
|
||||||
|
path: dist/
|
||||||
|
|
||||||
|
publish-to-pypi:
|
||||||
|
name: >-
|
||||||
|
Publish Python 🐍 distribution 📦 to PyPI
|
||||||
|
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
|
||||||
|
needs:
|
||||||
|
- build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/kernels
|
||||||
|
permissions:
|
||||||
|
id-token: write # IMPORTANT: mandatory for trusted publishing
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Download all the dists
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: python-package-distributions
|
||||||
|
path: dist/
|
||||||
|
- name: Publish distribution 📦 to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
|
||||||
|
github-release:
|
||||||
|
name: >-
|
||||||
|
Sign the Python 🐍 distribution 📦 with Sigstore
|
||||||
|
and upload them to GitHub Release
|
||||||
|
needs:
|
||||||
|
- publish-to-pypi
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write # IMPORTANT: mandatory for making GitHub Releases
|
||||||
|
id-token: write # IMPORTANT: mandatory for sigstore
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Download all the dists
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: python-package-distributions
|
||||||
|
path: dist/
|
||||||
|
- name: Sign the dists with Sigstore
|
||||||
|
uses: sigstore/gh-action-sigstore-python@v3.0.0
|
||||||
|
with:
|
||||||
|
inputs: >-
|
||||||
|
./dist/*.tar.gz
|
||||||
|
./dist/*.whl
|
||||||
|
- name: Create GitHub Release
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ github.token }}
|
||||||
|
run: >-
|
||||||
|
gh release create
|
||||||
|
"$GITHUB_REF_NAME"
|
||||||
|
--repo "$GITHUB_REPOSITORY"
|
||||||
|
--notes ""
|
||||||
|
- name: Upload artifact signatures to GitHub Release
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ github.token }}
|
||||||
|
# Upload to GitHub Release using the `gh` CLI.
|
||||||
|
# `dist/` contains the built packages, and the
|
||||||
|
# sigstore-produced signatures and certificates.
|
||||||
|
run: >-
|
||||||
|
gh release upload
|
||||||
|
"$GITHUB_REF_NAME" dist/**
|
||||||
|
--repo "$GITHUB_REPOSITORY"
|
||||||
|
|
||||||
|
publish-to-testpypi:
|
||||||
|
name: Publish Python 🐍 distribution 📦 to TestPyPI
|
||||||
|
needs:
|
||||||
|
- build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
environment:
|
||||||
|
name: testpypi
|
||||||
|
url: https://test.pypi.org/p/kernels
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
id-token: write # IMPORTANT: mandatory for trusted publishing
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Download all the dists
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: python-package-distributions
|
||||||
|
path: dist/
|
||||||
|
- name: Publish distribution 📦 to TestPyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
repository-url: https://test.pypi.org/legacy/
|
||||||
|
skip-existing: true # Only upload when the version is unique.
|
30
.github/workflows/test.yml
vendored
30
.github/workflows/test.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test hf-kernels
|
name: Test kernels
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -24,7 +24,10 @@ jobs:
|
|||||||
max-parallel: 4
|
max-parallel: 4
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10", "3.12"]
|
python-version: ["3.10", "3.12"]
|
||||||
torch-version: ["2.5.1", "2.6.0"]
|
torch-version: ["2.6.0", "2.7.0"]
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_PYTHON_PREFERENCE: only-managed
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@ -41,5 +44,28 @@ jobs:
|
|||||||
- name: Install the project
|
- name: Install the project
|
||||||
run: uv sync --all-extras --dev
|
run: uv sync --all-extras --dev
|
||||||
|
|
||||||
|
- name: Install setuptools for Triton-based test
|
||||||
|
run: uv pip install setuptools
|
||||||
|
|
||||||
|
- name: Check typing
|
||||||
|
run: uv run mypy src/kernels
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: uv run pytest tests
|
run: uv run pytest tests
|
||||||
|
|
||||||
|
- name: Check kernel conversion
|
||||||
|
run: |
|
||||||
|
uv pip install wheel
|
||||||
|
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
|
||||||
|
uv pip install triton_layer_norm-0.0.1*.whl
|
||||||
|
uv run python -c "import triton_layer_norm"
|
||||||
|
|
||||||
|
- name: Check README generation
|
||||||
|
# For now, just checks that generation doesn't fail.
|
||||||
|
run: |
|
||||||
|
uv run kernels generate-readme kernels-community/triton-layer-norm
|
||||||
|
|
||||||
|
- name: Import check without torch
|
||||||
|
run: |
|
||||||
|
uv pip uninstall torch
|
||||||
|
python -c "import kernels"
|
||||||
|
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
99
README.md
99
README.md
@ -1,11 +1,42 @@
|
|||||||
# hf-kernels
|
# kernels
|
||||||
|
|
||||||
Make sure you have `torch==2.5.1+cu124` installed.
|
<div align="center">
|
||||||
|
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
||||||
|
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
||||||
|
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
||||||
|
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<hr/>
|
||||||
|
|
||||||
|
The Kernel Hub allows Python libraries and applications to load compute
|
||||||
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||||
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||||
|
packages in that they are made to be:
|
||||||
|
|
||||||
|
- Portable: a kernel can be loaded from paths outside `PYTHONPATH`.
|
||||||
|
- Unique: multiple versions of the same kernel can be loaded in the
|
||||||
|
same Python process.
|
||||||
|
- Compatible: kernels must support all recent versions of Python and
|
||||||
|
the different PyTorch build configurations (various CUDA versions
|
||||||
|
and C++ ABIs). Furthermore, older C library versions must be supported.
|
||||||
|
|
||||||
|
## 🚀 Quick Start
|
||||||
|
|
||||||
|
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install kernels
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hf_kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
# Download optimized kernels from the Hugging Face hub
|
# Download optimized kernels from the Hugging Face hub
|
||||||
activation = get_kernel("kernels-community/activation")
|
activation = get_kernel("kernels-community/activation")
|
||||||
@ -20,57 +51,15 @@ activation.gelu_fast(y, x)
|
|||||||
print(y)
|
print(y)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Docker Reference
|
You can [search for kernels](https://huggingface.co/models?other=kernel) on
|
||||||
|
the Hub.
|
||||||
|
|
||||||
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
## 📚 Documentation
|
||||||
|
|
||||||
```bash
|
- [Using layers](docs/layers.md)
|
||||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
- [Locking kernel/layer versions](docs/locking.md)
|
||||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
- [Environment variables](docs/env.md)
|
||||||
```
|
- [Using kernels in a Docker container](docs/docker.md)
|
||||||
|
- [Kernel requirements](docs/kernel-requirements.md)
|
||||||
## Locking kernel versions
|
- [Frequently Asked Questions](docs/faq.md)
|
||||||
|
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
Projects that use `setuptools` can lock the kernel versions that should be
|
|
||||||
used. First specify the accepted versions in `pyproject.toml` and make
|
|
||||||
sure that `hf-kernels` is a build dependency:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[build-system]
|
|
||||||
requires = ["hf-kernels", "setuptools"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[tool.kernels.dependencies]
|
|
||||||
"kernels-community/activation" = ">=0.0.1"
|
|
||||||
```
|
|
||||||
|
|
||||||
Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
|
||||||
the locked revisions. The locked revision will be used when loading a kernel with
|
|
||||||
`get_locked_kernel`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from hf_kernels import get_locked_kernel
|
|
||||||
|
|
||||||
activation = get_locked_kernel("kernels-community/activation")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note:** the lock file is included in the package metadata, so it will only be visible
|
|
||||||
to `hf-kernels` after doing an (editable or regular) installation of your project.
|
|
||||||
|
|
||||||
## Pre-downloading locked kernels
|
|
||||||
|
|
||||||
Locked kernels can be pre-downloaded by running `hf-kernel download .` in your
|
|
||||||
project directory. This will download the kernels to your local Hugging Face
|
|
||||||
Hub cache.
|
|
||||||
|
|
||||||
The pre-downloaded kernels are used by the `get_locked_kernel` function.
|
|
||||||
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
|
|
||||||
want kernel loading to error when a kernel is not pre-downloaded, you can use
|
|
||||||
the `load_kernel` function instead:
|
|
||||||
|
|
||||||
````python
|
|
||||||
```python
|
|
||||||
from hf_kernels import load_kernel
|
|
||||||
|
|
||||||
activation = load_kernel("kernels-community/activation")
|
|
||||||
````
|
|
||||||
|
@ -31,13 +31,13 @@ WORKDIR /app/kernel-test
|
|||||||
# install python depdencies
|
# install python depdencies
|
||||||
RUN uv add torch==2.5.0 numpy
|
RUN uv add torch==2.5.0 numpy
|
||||||
|
|
||||||
# copy hf-kernels lib
|
# copy kernels lib
|
||||||
COPY src ./hf-kernels/src
|
COPY src ./kernels/src
|
||||||
COPY pyproject.toml ./hf-kernels/pyproject.toml
|
COPY pyproject.toml ./kernels/pyproject.toml
|
||||||
COPY README.md ./hf-kernels/README.md
|
COPY README.md ./kernels/README.md
|
||||||
|
|
||||||
# install library
|
# install library
|
||||||
RUN uv pip install -e hf-kernels
|
RUN uv pip install -e kernels
|
||||||
|
|
||||||
# copy examples
|
# copy examples
|
||||||
COPY examples ./examples
|
COPY examples ./examples
|
||||||
@ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
|||||||
|
|
||||||
# command to run the script
|
# command to run the script
|
||||||
CMD ["uv", "run", "examples/basic.py"]
|
CMD ["uv", "run", "examples/basic.py"]
|
||||||
# CMD ["ls", "hf-kernels"]
|
# CMD ["ls", "kernels"]
|
||||||
|
8
docs/docker.md
Normal file
8
docs/docker.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Using kernels in a Docker container
|
||||||
|
|
||||||
|
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||||
|
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||||
|
```
|
10
docs/env.md
Normal file
10
docs/env.md
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Environment variables
|
||||||
|
|
||||||
|
## `KERNELS_CACHE`
|
||||||
|
|
||||||
|
The directory to use as the local kernel cache. If not set, the cache
|
||||||
|
of the `huggingface_hub` package is used.
|
||||||
|
|
||||||
|
## `DISABLE_KERNEL_MAPPING`
|
||||||
|
|
||||||
|
Disables kernel mappings for [`layers`](layers.md).
|
13
docs/faq.md
Normal file
13
docs/faq.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## Why is the kernelization step needed?
|
||||||
|
|
||||||
|
In earlier versions of `kernels`, a layer's `forward` was replaced by
|
||||||
|
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
|
||||||
|
new `forward` would dispatch to a kernel based on the device type,
|
||||||
|
whether a model was training, etc. However, this approach was
|
||||||
|
fundamentally incompatible with `torch.compile` since it relied
|
||||||
|
on data-dependent branching.
|
||||||
|
|
||||||
|
To avoid branching, we have to make dispatch decisions ahead of time,
|
||||||
|
which is what the `kernelize` function does.
|
210
docs/kernel-requirements.md
Normal file
210
docs/kernel-requirements.md
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
# Kernel requirements
|
||||||
|
|
||||||
|
Kernels on the Hub must fulfill the requirements outlined on this page. By
|
||||||
|
ensuring kernels are compliant, they can be used on a wide range of Linux
|
||||||
|
systems and Torch builds.
|
||||||
|
|
||||||
|
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
to build compliant kernels.
|
||||||
|
|
||||||
|
## Directory layout
|
||||||
|
|
||||||
|
A kernel repository on the Hub must contain a `build` directory. This
|
||||||
|
directory contains build variants of a kernel in the form of directories
|
||||||
|
following the template
|
||||||
|
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
|
||||||
|
For example `build/torch26-cxx98-cu118-x86_64-linux`.
|
||||||
|
|
||||||
|
Each variant directory must contain a single directory with the same name
|
||||||
|
as the repository (replacing `-` by `_`). For instance, kernels in the
|
||||||
|
`kernels-community/activation` repository have a directories like
|
||||||
|
`build/<variant>/activation`. This directory
|
||||||
|
must be a Python package with an `__init__.py` file.
|
||||||
|
|
||||||
|
## Build variants
|
||||||
|
|
||||||
|
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
|
||||||
|
architecture (e.g. x86_64). For compliance with a compute framework and
|
||||||
|
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
|
||||||
|
must be available for that combination.
|
||||||
|
|
||||||
|
## Versioning
|
||||||
|
|
||||||
|
Kernels are versioned on the Hub using Git tags. Version tags must be of
|
||||||
|
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||||
|
to resolve the version constraints.
|
||||||
|
|
||||||
|
## Native Python module
|
||||||
|
|
||||||
|
Kernels will typically contain a native Python module with precompiled
|
||||||
|
compute kernels and bindings. This module must fulfill the requirements
|
||||||
|
outlined in this section. For all operating systems, a kernel must not
|
||||||
|
have dynamic library dependencies outside:
|
||||||
|
|
||||||
|
- Torch;
|
||||||
|
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||||
|
|
||||||
|
### Linux
|
||||||
|
|
||||||
|
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||||
|
for compatibility with Python 3.9 and later.
|
||||||
|
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||||
|
This means that the extension **must not** use symbols versions higher than:
|
||||||
|
|
||||||
|
- GLIBC 2.28
|
||||||
|
- GLIBCXX 3.4.24
|
||||||
|
- CXXABI 1.3.11
|
||||||
|
- GCC 7.0.0
|
||||||
|
|
||||||
|
These requirement can be checked with the ABI checker (see below).
|
||||||
|
|
||||||
|
### macOS
|
||||||
|
|
||||||
|
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||||
|
for compatibility with Python 3.9 and later.
|
||||||
|
- macOS deployment target 15.0.
|
||||||
|
- Metal 3.0 (`-std=metal3.0`).
|
||||||
|
|
||||||
|
The ABI3 requirement can be checked with the ABI checker (see below).
|
||||||
|
|
||||||
|
### ABI checker
|
||||||
|
|
||||||
|
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||||
|
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
|
||||||
|
$ cargo install kernel-abi-check
|
||||||
|
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
|
||||||
|
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
|
||||||
|
✅ No compatibility issues found
|
||||||
|
```
|
||||||
|
|
||||||
|
## Torch extension
|
||||||
|
|
||||||
|
Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
|
||||||
|
in `torch.ops.<namespace>`. Since we allow loading of multiple versions of
|
||||||
|
a module in the same Python process, `namespace` must be unique for each
|
||||||
|
version of a kernel. Failing to do so will create clashes when different
|
||||||
|
versions of the same kernel are loaded. Two suggested ways of doing this
|
||||||
|
are:
|
||||||
|
|
||||||
|
- Appending a truncated SHA-1 hash of the git commit that the kernel was
|
||||||
|
built from to the name of the extension.
|
||||||
|
- Appending random material to the name of the extension.
|
||||||
|
|
||||||
|
**Note:** we recommend against appending a version number or git tag.
|
||||||
|
Version numbers are typically not bumped on each commit, so users
|
||||||
|
might use two different commits that happen to have the same version
|
||||||
|
number. Git tags are not stable, so they do not provide a good way
|
||||||
|
of guaranteeing uniqueness of the namespace.
|
||||||
|
|
||||||
|
## Layers
|
||||||
|
|
||||||
|
A kernel can provide layers in addition to kernel functions. A layer from
|
||||||
|
the Hub can replace the `forward` method of an existing layer for a certain
|
||||||
|
device type. This makes it possible to provide more performant kernels for
|
||||||
|
existing layers. See the [layers documentation](layers.md) for more information
|
||||||
|
on how to use layers.
|
||||||
|
|
||||||
|
### Writing layers
|
||||||
|
|
||||||
|
To make the extension of layers safe, the layers must fulfill the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- The layers are subclasses of `torch.nn.Module`.
|
||||||
|
- The layers are pure, meaning that they do not have their own state. This
|
||||||
|
means that:
|
||||||
|
- The layer must not define its own constructor.
|
||||||
|
- The layer must not use class variables.
|
||||||
|
- No other methods must be defined than `forward`.
|
||||||
|
- The `forward` method has a signature that is compatible with the
|
||||||
|
`forward` method that it is extending.
|
||||||
|
|
||||||
|
There are two exceptions to the _no class variables rule_:
|
||||||
|
|
||||||
|
1. The `has_backward` variable can be used to indicate whether the layer has
|
||||||
|
a backward pass implemented (`True` when absent).
|
||||||
|
2. The `can_torch_compile` variable can be used to indicate whether the layer
|
||||||
|
supports `torch.compile` (`False` when absent).
|
||||||
|
|
||||||
|
This is an example of a pure layer:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SiluAndMul(nn.Module):
|
||||||
|
# This layer does not implement backward.
|
||||||
|
has_backward: bool = False
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
d = x.shape[-1] // 2
|
||||||
|
output_shape = x.shape[:-1] + (d,)
|
||||||
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
|
ops.silu_and_mul(out, x)
|
||||||
|
return out
|
||||||
|
```
|
||||||
|
|
||||||
|
For some layers, the `forward` method has to use state from the adopting class.
|
||||||
|
In these cases, we recommend to use type annotations to indicate what member
|
||||||
|
variables are expected. For instance:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class LlamaRMSNorm(nn.Module):
|
||||||
|
weight: torch.Tensor
|
||||||
|
variance_epsilon: float
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
return rms_norm_fn(
|
||||||
|
hidden_states,
|
||||||
|
self.weight,
|
||||||
|
bias=None,
|
||||||
|
residual=None,
|
||||||
|
eps=self.variance_epsilon,
|
||||||
|
dropout_p=0.0,
|
||||||
|
prenorm=False,
|
||||||
|
residual_in_fp32=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This layer expects the adopting layer to have `weight` and `variance_epsilon`
|
||||||
|
member variables and uses them in the `forward` method.
|
||||||
|
|
||||||
|
### Exporting layers
|
||||||
|
|
||||||
|
To accommodate portable loading, `layers` must be defined in the main
|
||||||
|
`__init__.py` file. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from . import layers
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# ...
|
||||||
|
"layers"
|
||||||
|
# ...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Python requirements
|
||||||
|
|
||||||
|
- Python code must be compatible with Python 3.9 and later.
|
||||||
|
- All Python code imports from the kernel itself must be relative. So,
|
||||||
|
for instance if in the example kernel `example`,
|
||||||
|
`module_b` needs a function from `module_a`, import as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .module_a import foo
|
||||||
|
```
|
||||||
|
|
||||||
|
**Never use:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# DO NOT DO THIS!
|
||||||
|
|
||||||
|
from example.module_a import foo
|
||||||
|
```
|
||||||
|
|
||||||
|
The latter would import from the module `example` that is in Python's
|
||||||
|
global module dict. However, since we allow loading multiple versions
|
||||||
|
of a module, we uniquely name the module.
|
||||||
|
|
||||||
|
- Only modules from the Python standard library, Torch, or the kernel itself
|
||||||
|
can be imported.
|
272
docs/layers.md
Normal file
272
docs/layers.md
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
# Layers
|
||||||
|
|
||||||
|
A kernel can provide layers in addition to kernel functions. A layer from
|
||||||
|
the Hub can replace the `forward` method of an existing layer for a certain
|
||||||
|
device type. This makes it possible to provide more performant kernels for
|
||||||
|
existing layers.
|
||||||
|
|
||||||
|
See [Kernel requirements](kernel-requirements.md) for more information the
|
||||||
|
requirements of Hub layers.
|
||||||
|
|
||||||
|
## Making a layer extensible with kernels from the hub
|
||||||
|
|
||||||
|
### Using a decorator
|
||||||
|
|
||||||
|
A layer can be made extensible with the `use_kernel_forward_from_hub`
|
||||||
|
decorator. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@use_kernel_forward_from_hub("SiluAndMul")
|
||||||
|
class SiluAndMul(nn.Module):
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
d = input.shape[-1] // 2
|
||||||
|
return F.silu(input[..., :d]) * input[..., d:]
|
||||||
|
```
|
||||||
|
|
||||||
|
The decorator does not change the behavior of the class -- it annotates
|
||||||
|
the class with the given name (here `SiluAndMul`). The `kernelize` function
|
||||||
|
described below uses this name to look up kernels for the layer.
|
||||||
|
|
||||||
|
### External layers
|
||||||
|
|
||||||
|
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
||||||
|
decorator can be made extensible using the `replace_kernel_forward_from_hub`
|
||||||
|
function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from somelibrary import SiluAndMul
|
||||||
|
|
||||||
|
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Warning:** we strongly recommend using layers with a decorator, since
|
||||||
|
it signifies that the maintainer intends to keep the `forward` signature
|
||||||
|
compatible with layers from the hub.
|
||||||
|
|
||||||
|
## Kernelizing a model
|
||||||
|
|
||||||
|
A model will not use Hub kernels by default, even if it contains extensible
|
||||||
|
layers. To enable the use of Hub kernels in the model, it needs to be
|
||||||
|
'kernelized' using the `kernelize` function. This function traverses the
|
||||||
|
model graph and replaces the `forward` methods of extensible layers for which
|
||||||
|
Hub kernels are registered. `kernelize` can be used as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MyModel(...)
|
||||||
|
model = kernelize(model, mode=Mode.INFERENCE)
|
||||||
|
```
|
||||||
|
|
||||||
|
The `kernelize` function modifies the model in-place, the model itself is
|
||||||
|
returned as a convenience. The `mode` specifies that the model will be used
|
||||||
|
in inference. Similarly, you can ask `kernelize` to prepare the model for
|
||||||
|
training:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MyModel(...)
|
||||||
|
model = kernelize(model, mode=Mode.TRAINING)
|
||||||
|
```
|
||||||
|
|
||||||
|
A model that is kernelized for training can also be used for inference, but
|
||||||
|
not the other way around. If you want to change the mode of the kernelized
|
||||||
|
model, you can just run `kernelize` on the model again with the new mode.
|
||||||
|
|
||||||
|
If you want to compile a model with `torch.compile`, this should be indicated
|
||||||
|
in the mode as well. You can do this by combining `Mode.INFERENCE` or
|
||||||
|
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MyModel(...)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
```
|
||||||
|
|
||||||
|
When the `mode` argument is not specified,
|
||||||
|
`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
|
||||||
|
aligns most closely with pure PyTorch layers which also support training
|
||||||
|
and `torch.compile`. However, to select the most performant kernels, it
|
||||||
|
is often good to make the mode specific as possible.
|
||||||
|
|
||||||
|
### Kernel device
|
||||||
|
|
||||||
|
Kernels can be registered per device type. For instance, separate `cuda` and
|
||||||
|
`metal` kernels could be registered for the name `SiluAndMul`. By default,
|
||||||
|
`kernelize` will try to infer the device type from the model's parameters.
|
||||||
|
You can pass the device type to `kernelize` if the device type cannot be
|
||||||
|
inferred (e.g. because the model has no parameters):
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MyModel(...)
|
||||||
|
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fallback `forward`
|
||||||
|
|
||||||
|
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
|
||||||
|
kernel does not support backward passes or `torch.compile` respectively,
|
||||||
|
`kernenize` will fall back to the original, non-kernelized, layer. You
|
||||||
|
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MyModel(...)
|
||||||
|
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
This can be useful if you want to guarantee that Hub kernels are used.
|
||||||
|
|
||||||
|
### Inspecting kernels which kernels are used
|
||||||
|
|
||||||
|
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
||||||
|
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
||||||
|
documentation for information on how to configure logging.
|
||||||
|
|
||||||
|
## Registering a hub kernel for a layer
|
||||||
|
|
||||||
|
`kernelize` relies on kernel mappings to find Hub kernels for layers.
|
||||||
|
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
|
||||||
|
the Hub. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can register such a mapping using `register_kernel_mapping`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will register the kernel mapping in the current context, which is
|
||||||
|
normally global. It is recommended to scope the mapping to where it is
|
||||||
|
used with the `use_kernel_mapping` context manager:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with use_kernel_mapping(kernel_layer_mapping):
|
||||||
|
# Use the layer for which the mapping is applied.
|
||||||
|
model = kernelize(model)
|
||||||
|
```
|
||||||
|
|
||||||
|
This ensures that the mapping is not active anymore outside the
|
||||||
|
`with`-scope.
|
||||||
|
|
||||||
|
### Registering kernels for specific modes
|
||||||
|
|
||||||
|
You might want to register two different kernels for a particular layer,
|
||||||
|
where one kernel is optimized for a specific mode. You can do so by
|
||||||
|
registering layer repositories for specific modes. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-community/activation-inference-optimized",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-community/activation-training-optimized",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `kernelize` function will attempt to use the following registered
|
||||||
|
kernels for a given mode:
|
||||||
|
|
||||||
|
- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
|
||||||
|
`TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||||
|
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
|
||||||
|
`TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||||
|
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||||
|
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
|
||||||
|
|
||||||
|
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
|
||||||
|
is also used when a kernel is registered without a mode, as described in the
|
||||||
|
previous section.
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.FALLBACK: LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-community/activation-inference-optimized",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-community/activation-training-optimized",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
|
||||||
|
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
|
||||||
|
since the other kernels do not support `torch.compile`.
|
||||||
|
|
||||||
|
### Registering kernels for specific CUDA capabilities
|
||||||
|
|
||||||
|
Some kernels only work with newer CUDA architectures. For instance, some
|
||||||
|
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
|
||||||
|
supports registering layers for a range of CUDA capabilities. To do so,
|
||||||
|
you need to register the layer for a `Device` with type `cuda` and
|
||||||
|
set the supported range of CUDA capabilities with using `CUDAProperties`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
Device(
|
||||||
|
type="cuda",
|
||||||
|
properties=CUDAProperties(
|
||||||
|
min_capability=75, max_capability=89
|
||||||
|
),
|
||||||
|
): LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
Device(
|
||||||
|
type="cuda",
|
||||||
|
properties=CUDAProperties(
|
||||||
|
min_capability=90, max_capability=sys.maxsize
|
||||||
|
),
|
||||||
|
): LayerRepository(
|
||||||
|
repo_id="kernels-community/activation-hopper",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Capabilities behave as follows:
|
||||||
|
|
||||||
|
- The minimum and maximum capabilities are inclusive.
|
||||||
|
- When a new kernel is registered with the same min/max capabilities as
|
||||||
|
an existing kernel, the new kernel will replace the old kernel.
|
||||||
|
- When there are multiple kernels that support a capability, the kernel
|
||||||
|
with the smaller capability interval will be used. E.g. given:
|
||||||
|
|
||||||
|
- `KernelA` with `min_capability=80` and `max_capability=89`;
|
||||||
|
- `KernelB` with `min_capability=75` and `max_capability=89`;
|
||||||
|
- `kernelize` runs on a system with capability 8.6.
|
||||||
|
|
||||||
|
Then `KernelA` will be used because the interval 80..89 is smaller
|
||||||
|
than 75..89. The motivation is that kernels with smaller ranges
|
||||||
|
tend to be more optimized for a specific set of GPUs. **This behavior
|
||||||
|
might still change in the future.**
|
62
docs/locking.md
Normal file
62
docs/locking.md
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Locking kernel/layer versions
|
||||||
|
|
||||||
|
Projects that use `setuptools` can lock the kernel versions that should be
|
||||||
|
used. First specify the accepted versions in `pyproject.toml` and make
|
||||||
|
sure that `kernels` is a build dependency:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[build-system]
|
||||||
|
requires = ["kernels", "setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.kernels.dependencies]
|
||||||
|
"kernels-community/activation" = ">=0.0.1"
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
|
||||||
|
the locked revisions. The locked revision will be used when loading a kernel with
|
||||||
|
`get_locked_kernel`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import get_locked_kernel
|
||||||
|
|
||||||
|
activation = get_locked_kernel("kernels-community/activation")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** the lock file is included in the package metadata, so it will only be visible
|
||||||
|
to `kernels` after doing an (editable or regular) installation of your project.
|
||||||
|
|
||||||
|
## Locked kernel layers
|
||||||
|
|
||||||
|
Locking is also supported for kernel layers. To use locked layers, register them
|
||||||
|
with the `LockedLayerRepository` class:
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": LockedLayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pre-downloading locked kernels
|
||||||
|
|
||||||
|
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
||||||
|
project directory. This will download the kernels to your local Hugging Face
|
||||||
|
Hub cache.
|
||||||
|
|
||||||
|
The pre-downloaded kernels are used by the `get_locked_kernel` function.
|
||||||
|
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
|
||||||
|
want kernel loading to error when a kernel is not pre-downloaded, you can use
|
||||||
|
the `load_kernel` function instead:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import load_kernel
|
||||||
|
|
||||||
|
activation = load_kernel("kernels-community/activation")
|
||||||
|
```
|
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hf_kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
print("Starting examples/basic.py demo")
|
print("Starting examples/basic.py demo")
|
||||||
|
|
||||||
|
133
flake.lock
generated
Normal file
133
flake.lock
generated
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"flake-compat": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1733328505,
|
||||||
|
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
||||||
|
"owner": "edolstra",
|
||||||
|
"repo": "flake-compat",
|
||||||
|
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "edolstra",
|
||||||
|
"repo": "flake-compat",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"flake-utils": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1731533236,
|
||||||
|
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"flake-utils_2": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems_2"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1731533236,
|
||||||
|
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"hf-nix": {
|
||||||
|
"inputs": {
|
||||||
|
"flake-compat": "flake-compat",
|
||||||
|
"flake-utils": "flake-utils_2",
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1750775451,
|
||||||
|
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
|
||||||
|
"owner": "huggingface",
|
||||||
|
"repo": "hf-nix",
|
||||||
|
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "huggingface",
|
||||||
|
"repo": "hf-nix",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1747820358,
|
||||||
|
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
||||||
|
"owner": "danieldk",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "danieldk",
|
||||||
|
"ref": "cudatoolkit-12.9-kernel-builder",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"flake-utils": "flake-utils",
|
||||||
|
"hf-nix": "hf-nix",
|
||||||
|
"nixpkgs": [
|
||||||
|
"hf-nix",
|
||||||
|
"nixpkgs"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1681028828,
|
||||||
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems_2": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1681028828,
|
||||||
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
57
flake.nix
Normal file
57
flake.nix
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
{
|
||||||
|
inputs = {
|
||||||
|
hf-nix.url = "github:huggingface/hf-nix";
|
||||||
|
nixpkgs.follows = "hf-nix/nixpkgs";
|
||||||
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
|
};
|
||||||
|
outputs =
|
||||||
|
{
|
||||||
|
self,
|
||||||
|
nixpkgs,
|
||||||
|
flake-utils,
|
||||||
|
hf-nix,
|
||||||
|
}:
|
||||||
|
flake-utils.lib.eachDefaultSystem (
|
||||||
|
system:
|
||||||
|
let
|
||||||
|
pkgs = import nixpkgs {
|
||||||
|
inherit system;
|
||||||
|
config = hf-nix.lib.config system;
|
||||||
|
overlays = [
|
||||||
|
hf-nix.overlays.default
|
||||||
|
];
|
||||||
|
};
|
||||||
|
in
|
||||||
|
{
|
||||||
|
formatter = pkgs.nixfmt-tree;
|
||||||
|
devShells = with pkgs; rec {
|
||||||
|
default = mkShell {
|
||||||
|
buildInputs =
|
||||||
|
[
|
||||||
|
black
|
||||||
|
mypy
|
||||||
|
pyright
|
||||||
|
ruff
|
||||||
|
]
|
||||||
|
++ (with python3.pkgs; [
|
||||||
|
docutils
|
||||||
|
huggingface-hub
|
||||||
|
pytest
|
||||||
|
pytest-benchmark
|
||||||
|
pyyaml
|
||||||
|
torch
|
||||||
|
types-pyyaml
|
||||||
|
venvShellHook
|
||||||
|
]);
|
||||||
|
|
||||||
|
venvDir = "./.venv";
|
||||||
|
|
||||||
|
postVenvCreation = ''
|
||||||
|
unset SOURCE_DATE_EPOCH
|
||||||
|
( python -m pip install --no-build-isolation --no-dependencies -e . )
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
@ -1,20 +1,21 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "hf-kernels"
|
name = "kernels"
|
||||||
version = "0.1.6"
|
version = "0.8.1"
|
||||||
description = "Download cuda kernels"
|
description = "Download compute kernels"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||||
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },
|
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },
|
||||||
{ name = "David Holtz", email = "david@huggingface.co" },
|
{ name = "David Holtz", email = "david@huggingface.co" },
|
||||||
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
|
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
|
||||||
]
|
]
|
||||||
|
license = { text = "Apache-2.0" }
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">= 3.9"
|
requires-python = ">= 3.9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"huggingface-hub>=0.26.3",
|
"huggingface_hub>=0.26.0,<1.0",
|
||||||
"packaging>=24.2",
|
"packaging>=20.0",
|
||||||
"tomli>=2.0.1; python_version<'3.11'",
|
"pyyaml>=6",
|
||||||
"torch>=2.4",
|
"tomli>=2.0; python_version<'3.11'",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@ -23,18 +24,47 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
"mypy >= 1.15.0",
|
||||||
"pytest >=8",
|
"pytest >=8",
|
||||||
# Whatever version is compatible with pytest.
|
# Whatever version is compatible with pytest.
|
||||||
"pytest-benchmark",
|
"pytest-benchmark",
|
||||||
|
"torch >=2.5",
|
||||||
|
"types-pyyaml"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
torch = ["torch"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
hf-kernels = "hf_kernels.cli:main"
|
kernels = "kernels.cli:main"
|
||||||
|
|
||||||
[project.entry-points."egg_info.writers"]
|
[project.entry-points."egg_info.writers"]
|
||||||
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"
|
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||||
|
|
||||||
#[build-system]
|
|
||||||
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
|
[tool.ruff]
|
||||||
#build-backend = "hf_kernels.build"
|
exclude = [
|
||||||
#backend-path = ["src"]
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".git-rewrite",
|
||||||
|
".hg",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".pytype",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
".venv*",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"venv",
|
||||||
|
]
|
||||||
|
line-length = 119
|
||||||
|
# Ignored rules:
|
||||||
|
# "E501" -> line length violation
|
||||||
|
lint.ignore = ["E501"]
|
||||||
|
lint.select = ["E", "F", "I", "W"]
|
||||||
|
4
pytest.ini
Normal file
4
pytest.ini
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[pytest]
|
||||||
|
markers =
|
||||||
|
darwin_only: marks tests that should only run on macOS
|
||||||
|
linux_only: marks tests that should only run on Linux
|
@ -1,3 +0,0 @@
|
|||||||
from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel
|
|
||||||
|
|
||||||
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
|
|
@ -1,144 +0,0 @@
|
|||||||
"""
|
|
||||||
Python shims for the PEP 517 and PEP 660 build backend.
|
|
||||||
|
|
||||||
Major imports in this module are required to be lazy:
|
|
||||||
```
|
|
||||||
$ hyperfine \
|
|
||||||
"/usr/bin/python3 -c \"print('hi')\"" \
|
|
||||||
"/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\""
|
|
||||||
Base: Time (mean ± σ): 11.0 ms ± 1.7 ms [User: 8.5 ms, System: 2.5 ms]
|
|
||||||
With import: Time (mean ± σ): 15.2 ms ± 2.0 ms [User: 12.3 ms, System: 2.9 ms]
|
|
||||||
Base 1.38 ± 0.28 times faster than with import
|
|
||||||
```
|
|
||||||
|
|
||||||
The same thing goes for the typing module, so we use Python 3.10 type annotations that
|
|
||||||
don't require importing typing but then quote them so earlier Python version ignore
|
|
||||||
them while IDEs and type checker can see through the quotes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
|
||||||
|
|
||||||
TYPE_CHECKING = False
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Mapping, Sequence # noqa:I001
|
|
||||||
from typing import Any # noqa:I001
|
|
||||||
|
|
||||||
|
|
||||||
def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if config_settings:
|
|
||||||
print("Warning: Config settings are not supported", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def call(
|
|
||||||
args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""Invoke a uv subprocess and return the filename from stdout."""
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
# Unlike `find_uv_bin`, this mechanism must work according to PEP 517
|
|
||||||
import os
|
|
||||||
|
|
||||||
cwd = os.getcwd()
|
|
||||||
filename = os.path.join(cwd, "pyproject.toml")
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
data = tomllib.load(f)
|
|
||||||
|
|
||||||
for kernel, _ in (
|
|
||||||
data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items()
|
|
||||||
):
|
|
||||||
from hf_kernels.utils import install_kernel
|
|
||||||
|
|
||||||
install_kernel(kernel, revision="main")
|
|
||||||
uv_bin = shutil.which("uv")
|
|
||||||
if uv_bin is None:
|
|
||||||
raise RuntimeError("uv was not properly installed")
|
|
||||||
# Forward stderr, capture stdout for the filename
|
|
||||||
result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE)
|
|
||||||
if result.returncode != 0:
|
|
||||||
sys.exit(result.returncode)
|
|
||||||
# If there was extra stdout, forward it (there should not be extra stdout)
|
|
||||||
stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True)
|
|
||||||
sys.stdout.writelines(stdout[:-1])
|
|
||||||
# Fail explicitly instead of an irrelevant stacktrace
|
|
||||||
if not stdout:
|
|
||||||
print("uv subprocess did not return a filename on stdout", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
return stdout[-1].strip()
|
|
||||||
|
|
||||||
|
|
||||||
def build_sdist(
|
|
||||||
sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `build_sdist`."""
|
|
||||||
args = ["build-backend", "build-sdist", sdist_directory]
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def build_wheel(
|
|
||||||
wheel_directory: str,
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
metadata_directory: "str | None" = None,
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `build_wheel`."""
|
|
||||||
args = ["build-backend", "build-wheel", wheel_directory]
|
|
||||||
if metadata_directory:
|
|
||||||
args.extend(["--metadata-directory", metadata_directory])
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_sdist(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 517 hook `get_requires_for_build_sdist`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_wheel(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 517 hook `get_requires_for_build_wheel`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_metadata_for_build_wheel(
|
|
||||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `prepare_metadata_for_build_wheel`."""
|
|
||||||
args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory]
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def build_editable(
|
|
||||||
wheel_directory: str,
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
metadata_directory: "str | None" = None,
|
|
||||||
) -> str:
|
|
||||||
"""PEP 660 hook `build_editable`."""
|
|
||||||
args = ["build-backend", "build-editable", wheel_directory]
|
|
||||||
|
|
||||||
if metadata_directory:
|
|
||||||
args.extend(["--metadata-directory", metadata_directory])
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_editable(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 660 hook `get_requires_for_build_editable`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_metadata_for_build_editable(
|
|
||||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 660 hook `prepare_metadata_for_build_editable`."""
|
|
||||||
args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory]
|
|
||||||
return call(args, config_settings)
|
|
@ -1,92 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import dataclasses
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
|
||||||
from hf_kernels.lockfile import KernelLock, get_kernel_locks
|
|
||||||
from hf_kernels.utils import install_kernel, install_kernel_all_variants
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
prog="hf-kernel", description="Manage compute kernels"
|
|
||||||
)
|
|
||||||
subparsers = parser.add_subparsers(required=True)
|
|
||||||
|
|
||||||
download_parser = subparsers.add_parser("download", help="Download locked kernels")
|
|
||||||
download_parser.add_argument(
|
|
||||||
"project_dir",
|
|
||||||
type=Path,
|
|
||||||
help="The project directory",
|
|
||||||
)
|
|
||||||
download_parser.add_argument(
|
|
||||||
"--all-variants",
|
|
||||||
action="store_true",
|
|
||||||
help="Download all build variants of the kernel",
|
|
||||||
)
|
|
||||||
download_parser.set_defaults(func=download_kernels)
|
|
||||||
|
|
||||||
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
|
||||||
lock_parser.add_argument(
|
|
||||||
"project_dir",
|
|
||||||
type=Path,
|
|
||||||
help="The project directory",
|
|
||||||
)
|
|
||||||
lock_parser.set_defaults(func=lock_kernels)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.func(args)
|
|
||||||
|
|
||||||
|
|
||||||
def download_kernels(args):
|
|
||||||
lock_path = args.project_dir / "hf-kernels.lock"
|
|
||||||
|
|
||||||
if not lock_path.exists():
|
|
||||||
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
with open(args.project_dir / "hf-kernels.lock", "r") as f:
|
|
||||||
lock_json = json.load(f)
|
|
||||||
|
|
||||||
all_successful = True
|
|
||||||
|
|
||||||
for kernel_lock_json in lock_json:
|
|
||||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
|
||||||
print(
|
|
||||||
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
if args.all_variants:
|
|
||||||
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
print(e, file=sys.stderr)
|
|
||||||
all_successful = False
|
|
||||||
|
|
||||||
if not all_successful:
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def lock_kernels(args):
|
|
||||||
with open(args.project_dir / "pyproject.toml", "rb") as f:
|
|
||||||
data = tomllib.load(f)
|
|
||||||
|
|
||||||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
|
||||||
|
|
||||||
all_locks = []
|
|
||||||
for kernel, version in kernel_versions.items():
|
|
||||||
all_locks.append(get_kernel_locks(kernel, version))
|
|
||||||
|
|
||||||
with open(args.project_dir / "hf-kernels.lock", "w") as f:
|
|
||||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
class _JSONEncoder(json.JSONEncoder):
|
|
||||||
def default(self, o):
|
|
||||||
if dataclasses.is_dataclass(o):
|
|
||||||
return dataclasses.asdict(o)
|
|
||||||
return super().default(o)
|
|
@ -1,178 +0,0 @@
|
|||||||
import ctypes
|
|
||||||
import importlib
|
|
||||||
import importlib.metadata
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
from importlib.metadata import Distribution
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
|
||||||
from packaging.version import parse
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
|
||||||
from hf_kernels.lockfile import KernelLock
|
|
||||||
|
|
||||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
|
||||||
|
|
||||||
|
|
||||||
def build_variant():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if torch.version.cuda is None:
|
|
||||||
raise AssertionError("This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled.")
|
|
||||||
|
|
||||||
torch_version = parse(torch.__version__)
|
|
||||||
cuda_version = parse(torch.version.cuda)
|
|
||||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
|
||||||
cpu = platform.machine()
|
|
||||||
os = platform.system().lower()
|
|
||||||
|
|
||||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
|
||||||
|
|
||||||
|
|
||||||
def import_from_path(module_name: str, file_path):
|
|
||||||
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
|
||||||
# it would also be used for other imports. So, we make a module name that
|
|
||||||
# depends on the path for it to be unique using the hex-encoded hash of
|
|
||||||
# the path.
|
|
||||||
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
|
||||||
module_name = f"{module_name}_{path_hash}"
|
|
||||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[module_name] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def install_kernel(
|
|
||||||
repo_id: str, revision: str, local_files_only: bool = False
|
|
||||||
) -> Tuple[str, str]:
|
|
||||||
"""Download a kernel for the current environment to the cache."""
|
|
||||||
package_name = repo_id.split('/')[-1]
|
|
||||||
package_name = package_name.replace('-', '_')
|
|
||||||
repo_path = snapshot_download(
|
|
||||||
repo_id,
|
|
||||||
allow_patterns=f"build/{build_variant()}/*",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
variant_path = f"{repo_path}/build/{build_variant()}"
|
|
||||||
module_init_path = f"{variant_path}/{package_name}/__init__.py"
|
|
||||||
|
|
||||||
if not os.path.exists(module_init_path):
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {build_variant()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return package_name, variant_path
|
|
||||||
|
|
||||||
|
|
||||||
def install_kernel_all_variants(
|
|
||||||
repo_id: str, revision: str, local_files_only: bool = False
|
|
||||||
):
|
|
||||||
snapshot_download(
|
|
||||||
repo_id,
|
|
||||||
allow_patterns="build/*",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
|
|
||||||
with open(
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id,
|
|
||||||
"build.toml",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
),
|
|
||||||
"rb",
|
|
||||||
) as f:
|
|
||||||
return tomllib.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def get_kernel(repo_id: str, revision: str = "main"):
|
|
||||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def load_kernel(repo_id: str):
|
|
||||||
"""Get a pre-downloaded, locked kernel."""
|
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
|
||||||
|
|
||||||
if locked_sha is None:
|
|
||||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
|
||||||
|
|
||||||
filename = hf_hub_download(
|
|
||||||
repo_id,
|
|
||||||
"build.toml",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
revision=locked_sha,
|
|
||||||
)
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
metadata = tomllib.load(f)
|
|
||||||
package_name = metadata["torch"]["name"]
|
|
||||||
|
|
||||||
repo_path = os.path.dirname(filename)
|
|
||||||
package_path = f"{repo_path}/build/{build_variant()}"
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False):
|
|
||||||
"""Get a kernel using a lock file."""
|
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
|
||||||
|
|
||||||
if locked_sha is None:
|
|
||||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
|
||||||
|
|
||||||
package_name, package_path = install_kernel(
|
|
||||||
repo_id, locked_sha, local_files_only=local_files_only
|
|
||||||
)
|
|
||||||
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
|
||||||
for dist in _get_caller_distributions():
|
|
||||||
lock_json = dist.read_text("hf-kernels.lock")
|
|
||||||
if lock_json is not None:
|
|
||||||
for kernel_lock_json in json.loads(lock_json):
|
|
||||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
|
||||||
if kernel_lock.repo_id == repo_id:
|
|
||||||
return kernel_lock.sha
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_distributions() -> List[Distribution]:
|
|
||||||
module = _get_caller_module()
|
|
||||||
if module is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Look up all possible distributions that this module could be from.
|
|
||||||
package = module.__name__.split(".")[0]
|
|
||||||
dist_names = importlib.metadata.packages_distributions().get(package)
|
|
||||||
if dist_names is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_module() -> Optional[ModuleType]:
|
|
||||||
stack = inspect.stack()
|
|
||||||
# Get first module in the stack that is not the current module.
|
|
||||||
first_module = inspect.getmodule(stack[0][0])
|
|
||||||
for frame in stack[1:]:
|
|
||||||
module = inspect.getmodule(frame[0])
|
|
||||||
if module is not None and module != first_module:
|
|
||||||
return module
|
|
||||||
return first_module
|
|
37
src/kernels/__init__.py
Normal file
37
src/kernels/__init__.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from kernels.layer import (
|
||||||
|
CUDAProperties,
|
||||||
|
Device,
|
||||||
|
LayerRepository,
|
||||||
|
Mode,
|
||||||
|
kernelize,
|
||||||
|
register_kernel_mapping,
|
||||||
|
replace_kernel_forward_from_hub,
|
||||||
|
use_kernel_forward_from_hub,
|
||||||
|
use_kernel_mapping,
|
||||||
|
)
|
||||||
|
from kernels.utils import (
|
||||||
|
get_kernel,
|
||||||
|
get_local_kernel,
|
||||||
|
get_locked_kernel,
|
||||||
|
has_kernel,
|
||||||
|
install_kernel,
|
||||||
|
load_kernel,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CUDAProperties",
|
||||||
|
"Device",
|
||||||
|
"LayerRepository",
|
||||||
|
"Mode",
|
||||||
|
"get_kernel",
|
||||||
|
"get_local_kernel",
|
||||||
|
"get_locked_kernel",
|
||||||
|
"has_kernel",
|
||||||
|
"install_kernel",
|
||||||
|
"kernelize",
|
||||||
|
"load_kernel",
|
||||||
|
"register_kernel_mapping",
|
||||||
|
"replace_kernel_forward_from_hub",
|
||||||
|
"use_kernel_forward_from_hub",
|
||||||
|
"use_kernel_mapping",
|
||||||
|
]
|
200
src/kernels/_interval_tree.py
Normal file
200
src/kernels/_interval_tree.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
# AVL-balanced interval trees. We could use the intervaltree
|
||||||
|
# packages, but it seems unmaintained and does not have type
|
||||||
|
# annotations.
|
||||||
|
|
||||||
|
from typing import Generic, List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class _Node(Generic[T]):
|
||||||
|
"""A node in the interval tree."""
|
||||||
|
|
||||||
|
def __init__(self, start: int, end: int, data: T):
|
||||||
|
self.start: int = start
|
||||||
|
self.end: int = end
|
||||||
|
self.data: T = data
|
||||||
|
self.max_end: int = end
|
||||||
|
self.left: Optional["_Node[T]"] = None
|
||||||
|
self.right: Optional["_Node[T]"] = None
|
||||||
|
self.height: int = 1
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Node({self.start}, {self.end})"
|
||||||
|
|
||||||
|
|
||||||
|
class IntervalTree(Generic[T]):
|
||||||
|
"""A data structure to hold and query (unique) intervals."""
|
||||||
|
|
||||||
|
root: Optional[_Node[T]]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.root = None
|
||||||
|
|
||||||
|
def insert(self, start: int, end: int, data: T) -> None:
|
||||||
|
"""
|
||||||
|
Inserts a new interval into the tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: The starting point of the interval.
|
||||||
|
end: The ending point of the interval.
|
||||||
|
data: The data associated with this interval.
|
||||||
|
"""
|
||||||
|
self.root = self._insert(self.root, start, end, data)
|
||||||
|
|
||||||
|
def _get_height(self, node: Optional[_Node[T]]) -> int:
|
||||||
|
if not node:
|
||||||
|
return 0
|
||||||
|
return node.height
|
||||||
|
|
||||||
|
def _get_balance(self, node: Optional[_Node[T]]) -> int:
|
||||||
|
if not node:
|
||||||
|
return 0
|
||||||
|
return self._get_height(node.left) - self._get_height(node.right)
|
||||||
|
|
||||||
|
def _update_node_attributes(self, node: _Node[T]) -> None:
|
||||||
|
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
|
||||||
|
node.max_end = node.end
|
||||||
|
if node.left:
|
||||||
|
node.max_end = max(node.max_end, node.left.max_end)
|
||||||
|
if node.right:
|
||||||
|
node.max_end = max(node.max_end, node.right.max_end)
|
||||||
|
|
||||||
|
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
|
||||||
|
"""Performs a right rotation."""
|
||||||
|
x = y.left
|
||||||
|
assert x is not None
|
||||||
|
T2 = x.right
|
||||||
|
|
||||||
|
x.right = y
|
||||||
|
y.left = T2
|
||||||
|
|
||||||
|
self._update_node_attributes(y)
|
||||||
|
self._update_node_attributes(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
|
||||||
|
"""Performs a left rotation."""
|
||||||
|
y = x.right
|
||||||
|
assert y is not None
|
||||||
|
T2 = y.left
|
||||||
|
|
||||||
|
y.left = x
|
||||||
|
x.right = T2
|
||||||
|
|
||||||
|
self._update_node_attributes(x)
|
||||||
|
self._update_node_attributes(y)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
def _insert(
|
||||||
|
self, node: Optional[_Node[T]], start: int, end: int, data: T
|
||||||
|
) -> _Node[T]:
|
||||||
|
"""Recursive helper to insert a new node and balance the tree."""
|
||||||
|
if not node:
|
||||||
|
return _Node(start, end, data)
|
||||||
|
|
||||||
|
# Replace the data if the interval already exists.
|
||||||
|
if start == node.start and end == node.end:
|
||||||
|
node.data = data
|
||||||
|
return node
|
||||||
|
|
||||||
|
if start < node.start:
|
||||||
|
node.left = self._insert(node.left, start, end, data)
|
||||||
|
else:
|
||||||
|
node.right = self._insert(node.right, start, end, data)
|
||||||
|
|
||||||
|
self._update_node_attributes(node)
|
||||||
|
|
||||||
|
balance = self._get_balance(node)
|
||||||
|
|
||||||
|
# Left Left Case
|
||||||
|
if balance > 1 and node.left and start < node.left.start:
|
||||||
|
return self._right_rotate(node)
|
||||||
|
|
||||||
|
# Right Right Case
|
||||||
|
if balance < -1 and node.right and start >= node.right.start:
|
||||||
|
return self._left_rotate(node)
|
||||||
|
|
||||||
|
# Left Right Case
|
||||||
|
if balance > 1 and node.left and start >= node.left.start:
|
||||||
|
node.left = self._left_rotate(node.left)
|
||||||
|
return self._right_rotate(node)
|
||||||
|
|
||||||
|
# Right Left Case
|
||||||
|
if balance < -1 and node.right and start < node.right.start:
|
||||||
|
node.right = self._right_rotate(node.right)
|
||||||
|
return self._left_rotate(node)
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
def search(self, point: int) -> List[T]:
|
||||||
|
"""
|
||||||
|
Searches for all intervals that contain the given point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
point: The point to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of data items from all matching intervals.
|
||||||
|
"""
|
||||||
|
results: List[T] = []
|
||||||
|
self._search(self.root, point, results)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
|
||||||
|
"""Recursive helper to find all overlapping intervals."""
|
||||||
|
if node is None or point > node.max_end:
|
||||||
|
return
|
||||||
|
|
||||||
|
if node.left:
|
||||||
|
self._search(node.left, point, results)
|
||||||
|
|
||||||
|
if node.start <= point <= node.end:
|
||||||
|
results.append(node.data)
|
||||||
|
|
||||||
|
if point >= node.start and node.right:
|
||||||
|
self._search(node.right, point, results)
|
||||||
|
|
||||||
|
def find_smallest_interval(self, point: int) -> Optional[T]:
|
||||||
|
"""
|
||||||
|
Finds the item with the most specific (smallest) range for a given point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
point: The capability to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The data of the best-matching item, or None if no match is found.
|
||||||
|
"""
|
||||||
|
matches: List[Tuple[int, int, T]] = []
|
||||||
|
self._find_with_intervals(self.root, point, matches)
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return the smallest interval, sort by memory location when
|
||||||
|
# there are multiple matches with the same interval size. This
|
||||||
|
# is just to ensure that we can compare against a trivial
|
||||||
|
# implementation in tests.
|
||||||
|
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
|
||||||
|
return best_match[2]
|
||||||
|
|
||||||
|
def _find_with_intervals(
|
||||||
|
self,
|
||||||
|
node: Optional[_Node[T]],
|
||||||
|
point: int,
|
||||||
|
results: List[Tuple[int, int, T]],
|
||||||
|
) -> None:
|
||||||
|
"""A modified search that collects interval ranges along with data."""
|
||||||
|
if node is None or point > node.max_end:
|
||||||
|
return
|
||||||
|
|
||||||
|
if node.left:
|
||||||
|
self._find_with_intervals(node.left, point, results)
|
||||||
|
|
||||||
|
if node.start <= point <= node.end:
|
||||||
|
results.append((node.start, node.end, node.data))
|
||||||
|
|
||||||
|
if point >= node.start and node.right:
|
||||||
|
self._find_with_intervals(node.right, point, results)
|
751
src/kernels/_vendored/convert_rst_to_mdx.py
Normal file
751
src/kernels/_vendored/convert_rst_to_mdx.py
Normal file
@ -0,0 +1,751 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Re pattern to catch things inside ` ` in :obj:`thing`.
|
||||||
|
_re_obj = re.compile(r":obj:`([^`]+)`")
|
||||||
|
# Re pattern to catch things inside ` ` in :math:`thing`.
|
||||||
|
_re_math = re.compile(r":math:`([^`]+)`")
|
||||||
|
# Re pattern to catch things between single backquotes.
|
||||||
|
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
|
||||||
|
# Re pattern to catch things between double backquotes.
|
||||||
|
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
|
||||||
|
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
|
||||||
|
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rst_formatting(text):
|
||||||
|
"""
|
||||||
|
Convert rst syntax for formatting to markdown in a given text.
|
||||||
|
"""
|
||||||
|
# Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
|
||||||
|
# (to not be caught by the italic conversion).
|
||||||
|
text = _re_func_class.sub(r"[``\1``]", text)
|
||||||
|
# Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
|
||||||
|
# (to not be caught by the italic conversion).
|
||||||
|
text = _re_obj.sub(r"``\1``", text)
|
||||||
|
# Remove :math: markers.
|
||||||
|
text = _re_math.sub(r"\\\\(\1\\\\)", text)
|
||||||
|
# Convert content in single backquotes to italic.
|
||||||
|
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
|
||||||
|
# Convert content in double backquotes to single backquotes.
|
||||||
|
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
|
||||||
|
# Remove remaining ::
|
||||||
|
text = re.sub(r"::\n", "", text)
|
||||||
|
|
||||||
|
# Remove new lines inside blocks in backsticks as they will be kept.
|
||||||
|
lines = text.split("\n")
|
||||||
|
in_code = False
|
||||||
|
text = None
|
||||||
|
for line in lines:
|
||||||
|
if in_code:
|
||||||
|
splits = line.split("`")
|
||||||
|
in_code = len(splits) > 1 and len(splits) % 2 == 1
|
||||||
|
if len(splits) == 1:
|
||||||
|
# Some forgotten lone backstick
|
||||||
|
text += "\n" + line
|
||||||
|
else:
|
||||||
|
text += " " + line.lstrip()
|
||||||
|
else:
|
||||||
|
if text is not None:
|
||||||
|
text += "\n" + line
|
||||||
|
else:
|
||||||
|
text = line
|
||||||
|
splits = line.split("`")
|
||||||
|
in_code = len(splits) % 2 == 0
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Re pattern to catch description and url in links of the form `description <url>`_.
|
||||||
|
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
|
||||||
|
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
|
||||||
|
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
|
||||||
|
# Re pattern to catch reference in links of the form :doc:`reference`.
|
||||||
|
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
|
||||||
|
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
|
||||||
|
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
|
||||||
|
# Re pattern to catch reference in links of the form :ref:`reference`.
|
||||||
|
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
|
||||||
|
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
|
||||||
|
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rst_links(text, page_info):
|
||||||
|
"""
|
||||||
|
Convert the rst links in text to markdown.
|
||||||
|
"""
|
||||||
|
if "package_name" not in page_info:
|
||||||
|
raise ValueError("`page_info` must contain at least the package_name.")
|
||||||
|
package_name = page_info["package_name"]
|
||||||
|
version = page_info.get("version", "main")
|
||||||
|
language = page_info.get("language", "en")
|
||||||
|
no_prefix = page_info.get("no_prefix", False)
|
||||||
|
|
||||||
|
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
|
||||||
|
# Links of the form :doc:`page`
|
||||||
|
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
|
||||||
|
# Links of the form :doc:`text <page>`
|
||||||
|
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
|
||||||
|
|
||||||
|
if "page" in page_info and not no_prefix:
|
||||||
|
page = str(page_info["page"])
|
||||||
|
if page.endswith(".html"):
|
||||||
|
page = page[:-5]
|
||||||
|
prefix = f"{prefix}{page}"
|
||||||
|
else:
|
||||||
|
prefix = ""
|
||||||
|
# Refs of the form :ref:`page`
|
||||||
|
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
|
||||||
|
# Refs of the form :ref:`text <page>`
|
||||||
|
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
|
||||||
|
|
||||||
|
# Links with a prefix
|
||||||
|
# TODO: when it exists, use the API to deal with prefix links properly.
|
||||||
|
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
|
||||||
|
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
|
||||||
|
# Other links
|
||||||
|
text = _re_links.sub(r"[\1](\2)", text)
|
||||||
|
# Relative links or Transformers links need to remove the .html
|
||||||
|
if (
|
||||||
|
"(https://https://huggingface.co/" in text
|
||||||
|
or re.search(r"\(\.+/", text) is not None
|
||||||
|
):
|
||||||
|
text = text.replace(".html", "")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Re pattern that catches examples blocks of the form `Example::`.
|
||||||
|
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
|
||||||
|
# Re pattern that catches rst blocks of the form `.. block_name::`.
|
||||||
|
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
|
||||||
|
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
|
||||||
|
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
|
||||||
|
|
||||||
|
|
||||||
|
def is_empty_line(line):
|
||||||
|
return len(line) == 0 or line.isspace()
|
||||||
|
|
||||||
|
|
||||||
|
def find_indent(line):
|
||||||
|
"""
|
||||||
|
Returns the number of spaces that start a line indent.
|
||||||
|
"""
|
||||||
|
search = re.search(r"^(\s*)(?:\S|$)", line)
|
||||||
|
if search is None:
|
||||||
|
return 0
|
||||||
|
return len(search.groups()[0])
|
||||||
|
|
||||||
|
|
||||||
|
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_special_chars(text):
|
||||||
|
"""
|
||||||
|
Converts { and < that have special meanings in MDX.
|
||||||
|
"""
|
||||||
|
text = text.replace("{", "&lcub;")
|
||||||
|
# We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
|
||||||
|
text = re.sub(
|
||||||
|
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
|
||||||
|
) # html void elements with no closing counterpart
|
||||||
|
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
|
||||||
|
while _re_lt_html.search(text):
|
||||||
|
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
|
||||||
|
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text)
|
||||||
|
text = text.replace("LTHTML", "<")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def parse_options(block_content):
|
||||||
|
"""
|
||||||
|
Parses the option in some rst block content.
|
||||||
|
"""
|
||||||
|
block_lines = block_content.split("\n")
|
||||||
|
block_indent = find_indent(block_lines[0])
|
||||||
|
current_option = None
|
||||||
|
result = {}
|
||||||
|
for line in block_lines:
|
||||||
|
if _re_rst_option.search(line) is not None:
|
||||||
|
current_option, value = _re_rst_option.search(line).groups()
|
||||||
|
result[current_option] = value.lstrip()
|
||||||
|
elif find_indent(line) > block_indent:
|
||||||
|
result[current_option] += " " + line.lstrip()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_min_indent(text, min_indent):
|
||||||
|
"""
|
||||||
|
Make sure all lines in a text are have a minimum indentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (`str`): The text to treat.
|
||||||
|
min_indent (`int`): The minimal indentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The processed text.
|
||||||
|
"""
|
||||||
|
lines = text.split("\n")
|
||||||
|
idx = 0
|
||||||
|
while idx < len(lines):
|
||||||
|
if is_empty_line(lines[idx]):
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
indent = find_indent(lines[idx])
|
||||||
|
if indent < min_indent:
|
||||||
|
while idx < len(lines) and (
|
||||||
|
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
|
||||||
|
):
|
||||||
|
if not is_empty_line(lines[idx]):
|
||||||
|
lines[idx] = " " * (min_indent - indent) + lines[idx]
|
||||||
|
idx += 1
|
||||||
|
else:
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rst_blocks(text, page_info):
|
||||||
|
"""
|
||||||
|
Converts rst special blocks (examples, notes) into MDX.
|
||||||
|
"""
|
||||||
|
if "package_name" not in page_info:
|
||||||
|
raise ValueError("`page_info` must contain at least the package_name.")
|
||||||
|
package_name = page_info["package_name"]
|
||||||
|
version = page_info.get("version", "main")
|
||||||
|
language = page_info.get("language", "en")
|
||||||
|
|
||||||
|
lines = text.split("\n")
|
||||||
|
idx = 0
|
||||||
|
new_lines = []
|
||||||
|
while idx < len(lines):
|
||||||
|
block_type = None
|
||||||
|
block_info = None
|
||||||
|
if _re_block.search(lines[idx]) is not None:
|
||||||
|
block_type = _re_block.search(lines[idx]).groups()[0]
|
||||||
|
if _re_block_info.search(lines[idx]) is not None:
|
||||||
|
block_info = _re_block_info.search(lines[idx]).groups()[0]
|
||||||
|
elif _re_example.search(lines[idx]) is not None:
|
||||||
|
block_type = "code-block-example"
|
||||||
|
block_info = "python"
|
||||||
|
example_name = _re_example.search(lines[idx]).groups()[0]
|
||||||
|
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
|
||||||
|
elif lines[idx].strip() == "..":
|
||||||
|
block_type = "comment"
|
||||||
|
elif lines[idx].strip() == "::":
|
||||||
|
block_type = "code-block"
|
||||||
|
|
||||||
|
if block_type is not None:
|
||||||
|
block_indent = find_indent(lines[idx])
|
||||||
|
# Find the next nonempty line
|
||||||
|
idx += 1
|
||||||
|
while idx < len(lines) and is_empty_line(lines[idx]):
|
||||||
|
idx += 1
|
||||||
|
# Grab the indent of the return line, this block will stop when we unindent under it (or has already)
|
||||||
|
example_indent = (
|
||||||
|
find_indent(lines[idx]) if idx < len(lines) else block_indent
|
||||||
|
)
|
||||||
|
|
||||||
|
if example_indent == block_indent:
|
||||||
|
block_content = ""
|
||||||
|
else:
|
||||||
|
block_lines = []
|
||||||
|
while idx < len(lines) and (
|
||||||
|
is_empty_line(lines[idx])
|
||||||
|
or find_indent(lines[idx]) >= example_indent
|
||||||
|
):
|
||||||
|
block_lines.append(lines[idx][example_indent:])
|
||||||
|
idx += 1
|
||||||
|
block_content = "\n".join(block_lines)
|
||||||
|
|
||||||
|
if block_type in ["code", "code-block"]:
|
||||||
|
prefix = "```" if block_info is None else f"```{block_info}"
|
||||||
|
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
|
||||||
|
elif block_type == "code-block-example":
|
||||||
|
prefix = f"<example>```{block_info}"
|
||||||
|
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
|
||||||
|
elif block_type == "note":
|
||||||
|
new_lines.append(
|
||||||
|
apply_min_indent(
|
||||||
|
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif block_type == "warning":
|
||||||
|
new_lines.append(
|
||||||
|
apply_min_indent(
|
||||||
|
"<Tip warning={true}>\n\n"
|
||||||
|
+ f"{block_content.strip()}\n\n</Tip>\n",
|
||||||
|
block_indent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif block_type == "raw":
|
||||||
|
new_lines.append(block_content.strip() + "\n")
|
||||||
|
elif block_type == "math":
|
||||||
|
new_lines.append(f"$${block_content.strip()}$$\n")
|
||||||
|
elif block_type == "comment":
|
||||||
|
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
|
||||||
|
elif block_type == "autofunction":
|
||||||
|
if block_info is not None:
|
||||||
|
new_lines.append(f"[[autodoc]] {block_info}\n")
|
||||||
|
elif block_type == "autoclass":
|
||||||
|
if block_info is not None:
|
||||||
|
block = f"[[autodoc]] {block_info}\n"
|
||||||
|
options = parse_options(block_content)
|
||||||
|
if "special-members" in options:
|
||||||
|
special_members = options["special-members"].split(", ")
|
||||||
|
for special_member in special_members:
|
||||||
|
block += f" - {special_member}\n"
|
||||||
|
if "members" in options:
|
||||||
|
members = options["members"]
|
||||||
|
if len(members) == 0:
|
||||||
|
block += " - all\n"
|
||||||
|
else:
|
||||||
|
for member in members.split(", "):
|
||||||
|
block += f" - {member}\n"
|
||||||
|
new_lines.append(block)
|
||||||
|
elif block_type == "image":
|
||||||
|
options = parse_options(block_content)
|
||||||
|
target = options.pop("target", None)
|
||||||
|
if block_info is not None:
|
||||||
|
options["src"] = block_info
|
||||||
|
else:
|
||||||
|
if target is None:
|
||||||
|
raise ValueError("Image source not defined.")
|
||||||
|
options["src"] = target
|
||||||
|
# Adapt path
|
||||||
|
options["src"] = options["src"].replace(
|
||||||
|
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
|
||||||
|
)
|
||||||
|
html_code = " ".join(
|
||||||
|
[f'{key}="{value}"' for key, value in options.items()]
|
||||||
|
)
|
||||||
|
new_lines.append(f"<img {html_code}/>\n")
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_lines.append(
|
||||||
|
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_lines.append(lines[idx])
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
return "\n".join(new_lines)
|
||||||
|
|
||||||
|
|
||||||
|
# Re pattern that catches rst args blocks of the form `Parameters:`.
|
||||||
|
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
|
||||||
|
# Re pattern that catches return blocks of the form `Return:`.
|
||||||
|
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
def split_return_line(line):
|
||||||
|
"""
|
||||||
|
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
|
||||||
|
"""
|
||||||
|
splits_on_colon = line.split(":")
|
||||||
|
idx = 1
|
||||||
|
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
|
||||||
|
idx += 2
|
||||||
|
if idx >= len(splits_on_colon):
|
||||||
|
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
|
||||||
|
return line, ""
|
||||||
|
return None, line
|
||||||
|
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
|
||||||
|
|
||||||
|
|
||||||
|
def split_raise_line(line):
|
||||||
|
"""
|
||||||
|
Split the raise line with format `SomeError some doc`.
|
||||||
|
"""
|
||||||
|
splits_on_colon = line.strip().split(" ")
|
||||||
|
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
|
||||||
|
if error_type and error_type[-1] == ":":
|
||||||
|
error_type = error_type[:-1]
|
||||||
|
return error_type, doc
|
||||||
|
|
||||||
|
|
||||||
|
def split_arg_line(line):
|
||||||
|
"""
|
||||||
|
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
|
||||||
|
"""
|
||||||
|
splits_on_colon = line.split(":")
|
||||||
|
idx = 1
|
||||||
|
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
|
||||||
|
idx += 2
|
||||||
|
if idx >= len(splits_on_colon):
|
||||||
|
return line, ""
|
||||||
|
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRstDocstringError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_re_parameters = re.compile(
|
||||||
|
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
|
||||||
|
)
|
||||||
|
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rst_docstring(docstring):
|
||||||
|
"""
|
||||||
|
Parses a docstring written in rst, in particular the list of arguments and the return type.
|
||||||
|
"""
|
||||||
|
lines = docstring.split("\n")
|
||||||
|
idx = 0
|
||||||
|
while idx < len(lines):
|
||||||
|
# Parameters section
|
||||||
|
if _re_args.search(lines[idx]) is not None:
|
||||||
|
# Title of the section.
|
||||||
|
lines[idx] = "<parameters>\n"
|
||||||
|
# Find the next nonempty line
|
||||||
|
idx += 1
|
||||||
|
while is_empty_line(lines[idx]):
|
||||||
|
idx += 1
|
||||||
|
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
|
||||||
|
# Returns or Raises block.
|
||||||
|
param_indent = find_indent(lines[idx])
|
||||||
|
while (
|
||||||
|
idx < len(lines)
|
||||||
|
and find_indent(lines[idx]) == param_indent
|
||||||
|
and _re_returns.search(lines[idx]) is None
|
||||||
|
):
|
||||||
|
intro, doc = split_arg_line(lines[idx])
|
||||||
|
# Line starting with a > after indent indicate a "section title" in the parameters.
|
||||||
|
if intro.lstrip().startswith(">"):
|
||||||
|
lines[idx] = intro.lstrip()
|
||||||
|
else:
|
||||||
|
lines[idx] = (
|
||||||
|
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
|
||||||
|
)
|
||||||
|
idx += 1
|
||||||
|
while idx < len(lines) and (
|
||||||
|
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
|
||||||
|
):
|
||||||
|
idx += 1
|
||||||
|
lines.insert(idx, "</parameters>\n")
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Returns section
|
||||||
|
elif _re_returns.search(lines[idx]) is not None:
|
||||||
|
# tag is either `return` or `yield`
|
||||||
|
tag = _re_returns.match(lines[idx]).group(1).lower()
|
||||||
|
# Title of the section.
|
||||||
|
lines[idx] = f"<{tag}s>\n"
|
||||||
|
# Find the next nonempty line
|
||||||
|
idx += 1
|
||||||
|
while is_empty_line(lines[idx]):
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Grab the indent of the return line, this block will stop when we unindent under it.
|
||||||
|
return_indent = find_indent(lines[idx])
|
||||||
|
raised_errors = []
|
||||||
|
# The line may contain the return type.
|
||||||
|
if tag in ["return", "yield"]:
|
||||||
|
return_type, return_description = split_return_line(lines[idx])
|
||||||
|
lines[idx] = return_description
|
||||||
|
idx += 1
|
||||||
|
while idx < len(lines) and (
|
||||||
|
is_empty_line(lines[idx])
|
||||||
|
or find_indent(lines[idx]) >= return_indent
|
||||||
|
):
|
||||||
|
idx += 1
|
||||||
|
else:
|
||||||
|
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
|
||||||
|
return_type, return_description = split_raise_line(lines[idx])
|
||||||
|
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
|
||||||
|
lines[idx] = "- " + raised_error + " -- " + return_description
|
||||||
|
md_link = _re_md_link.match(raised_error)
|
||||||
|
if md_link:
|
||||||
|
raised_error = md_link[1]
|
||||||
|
raised_error = re.sub(
|
||||||
|
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
|
||||||
|
)
|
||||||
|
if raised_error not in raised_errors:
|
||||||
|
raised_errors.append(raised_error)
|
||||||
|
idx += 1
|
||||||
|
while idx < len(lines) and (
|
||||||
|
is_empty_line(lines[idx])
|
||||||
|
or find_indent(lines[idx]) > return_indent
|
||||||
|
):
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
lines.insert(idx, f"</{tag}s>\n")
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Return block finished, we insert the return type if one was specified
|
||||||
|
if tag in ["return", "yield"] and return_type is not None:
|
||||||
|
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
|
||||||
|
elif len(raised_errors) > 0:
|
||||||
|
# raised errors
|
||||||
|
lines[
|
||||||
|
idx - 1
|
||||||
|
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
|
||||||
|
|
||||||
|
else:
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
result = "\n".join(lines)
|
||||||
|
|
||||||
|
# combine multiple <parameters> blocks into one block
|
||||||
|
if result.count("<parameters>") > 1:
|
||||||
|
parameters_blocks = _re_parameters.findall(result)
|
||||||
|
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
|
||||||
|
parameters_str = "\n".join(parameters_blocks)
|
||||||
|
result = _re_parameters.sub("", result)
|
||||||
|
result += f"\n<parameters>{parameters_str}</parameters>\n"
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
|
||||||
|
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_indent(text):
|
||||||
|
"""
|
||||||
|
Remove indents in text, except the one linked to lists (or sublists).
|
||||||
|
"""
|
||||||
|
lines = text.split("\n")
|
||||||
|
# List of indents to remember for nested lists
|
||||||
|
current_indents = []
|
||||||
|
# List of new indents to remember for nested lists
|
||||||
|
new_indents = []
|
||||||
|
is_inside_code = False
|
||||||
|
code_indent = 0
|
||||||
|
for idx, line in enumerate(lines):
|
||||||
|
# Line is an item in a list.
|
||||||
|
if _re_list.search(line) is not None:
|
||||||
|
indent = find_indent(line)
|
||||||
|
# Is it a new list / new level of nestedness?
|
||||||
|
if len(current_indents) == 0 or indent > current_indents[-1]:
|
||||||
|
current_indents.append(indent)
|
||||||
|
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
||||||
|
lines[idx] = " " * new_indent + line[indent:]
|
||||||
|
new_indent += len(_re_list.search(line).groups()[0]) + 1
|
||||||
|
new_indents.append(new_indent)
|
||||||
|
# Otherwise it's an existing level of list (current one, or previous one)
|
||||||
|
else:
|
||||||
|
# Let's find the proper level of indentation
|
||||||
|
level = len(current_indents) - 1
|
||||||
|
while level >= 0 and current_indents[level] != indent:
|
||||||
|
level -= 1
|
||||||
|
current_indents = current_indents[: level + 1]
|
||||||
|
new_indents = new_indents[:level]
|
||||||
|
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
||||||
|
lines[idx] = " " * new_indent + line[indent:]
|
||||||
|
new_indent += len(_re_list.search(line).groups()[0]) + 1
|
||||||
|
new_indents.append(new_indent)
|
||||||
|
|
||||||
|
# Line is an autodoc, we keep the indent for the list just after if there is one.
|
||||||
|
elif _re_autodoc.search(line) is not None:
|
||||||
|
indent = find_indent(line)
|
||||||
|
current_indents = [indent]
|
||||||
|
new_indents = [4]
|
||||||
|
lines[idx] = line.strip()
|
||||||
|
|
||||||
|
# Deal with empty lines separately
|
||||||
|
elif is_empty_line(line):
|
||||||
|
lines[idx] = ""
|
||||||
|
|
||||||
|
# Code blocks
|
||||||
|
elif line.lstrip().startswith("```"):
|
||||||
|
is_inside_code = not is_inside_code
|
||||||
|
if is_inside_code:
|
||||||
|
code_indent = find_indent(line)
|
||||||
|
lines[idx] = line[code_indent:]
|
||||||
|
elif is_inside_code:
|
||||||
|
lines[idx] = line[code_indent:]
|
||||||
|
|
||||||
|
else:
|
||||||
|
indent = find_indent(line)
|
||||||
|
if len(current_indents) > 0 and indent > current_indents[-1]:
|
||||||
|
lines[idx] = " " * new_indents[-1] + line[indent:]
|
||||||
|
elif len(current_indents) > 0:
|
||||||
|
# Let's find the proper level of indentation
|
||||||
|
level = len(current_indents) - 1
|
||||||
|
while level >= 0 and current_indents[level] > indent:
|
||||||
|
level -= 1
|
||||||
|
current_indents = current_indents[: level + 1]
|
||||||
|
if level >= 0:
|
||||||
|
if current_indents[level] < indent:
|
||||||
|
new_indents = new_indents[: level + 1]
|
||||||
|
else:
|
||||||
|
new_indents = new_indents[:level]
|
||||||
|
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
||||||
|
lines[idx] = " " * new_indent + line[indent:]
|
||||||
|
new_indents.append(new_indent)
|
||||||
|
else:
|
||||||
|
new_indents = []
|
||||||
|
lines[idx] = line[indent:]
|
||||||
|
else:
|
||||||
|
lines[idx] = line[indent:]
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def base_rst_to_mdx(text, page_info, unindent=True):
|
||||||
|
"""
|
||||||
|
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
|
||||||
|
"""
|
||||||
|
text = convert_rst_links(text, page_info)
|
||||||
|
text = convert_special_chars(text)
|
||||||
|
text = convert_rst_blocks(text, page_info)
|
||||||
|
# Convert * in lists to - to avoid the formatting conversion treat them as bold.
|
||||||
|
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
|
||||||
|
text = convert_rst_formatting(text)
|
||||||
|
return remove_indent(text) if unindent else text
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rst_docstring_to_mdx(docstring, page_info):
|
||||||
|
"""
|
||||||
|
Convert a docstring written in rst to mdx.
|
||||||
|
"""
|
||||||
|
text = parse_rst_docstring(docstring)
|
||||||
|
return base_rst_to_mdx(text, page_info)
|
||||||
|
|
||||||
|
|
||||||
|
def process_titles(lines):
|
||||||
|
"""Converts rst titles to markdown titles."""
|
||||||
|
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
|
||||||
|
title_levels = {}
|
||||||
|
new_lines = []
|
||||||
|
for line in lines:
|
||||||
|
if (
|
||||||
|
len(new_lines) > 0
|
||||||
|
and len(line) >= len(new_lines[-1])
|
||||||
|
and len(set(line)) == 1
|
||||||
|
and line[0] in title_chars
|
||||||
|
and line != "::"
|
||||||
|
):
|
||||||
|
char = line[0]
|
||||||
|
level = title_levels.get(char, len(title_levels) + 1)
|
||||||
|
if level not in title_levels:
|
||||||
|
title_levels[char] = level
|
||||||
|
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
return new_lines
|
||||||
|
|
||||||
|
|
||||||
|
# Matches lines with a pattern of a table new line in rst.
|
||||||
|
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
|
||||||
|
# Matches lines with a pattern of a table new line in rst, with a first column empty.
|
||||||
|
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
|
||||||
|
# Matches lines with a pattern of a first table line in rst.
|
||||||
|
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
|
||||||
|
# Re pattern that catches anchors of the type .. reference:
|
||||||
|
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
|
||||||
|
|
||||||
|
|
||||||
|
def split_pt_tf_code_blocks(text):
|
||||||
|
"""
|
||||||
|
Split PyTorch and TensorFlow specific block codes.
|
||||||
|
"""
|
||||||
|
lines = text.split("\n")
|
||||||
|
new_lines = []
|
||||||
|
idx = 0
|
||||||
|
while idx < len(lines):
|
||||||
|
if lines[idx].startswith("```"):
|
||||||
|
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
|
||||||
|
is_pytorch = False
|
||||||
|
is_tensorflow = False
|
||||||
|
idx += 1
|
||||||
|
while idx < len(lines) and lines[idx].strip() != "```":
|
||||||
|
if "## PYTORCH CODE" in lines[idx]:
|
||||||
|
is_pytorch = True
|
||||||
|
is_tensorflow = False
|
||||||
|
elif "## TENSORFLOW CODE" in lines[idx]:
|
||||||
|
is_tensorflow = True
|
||||||
|
is_pytorch = False
|
||||||
|
elif is_pytorch:
|
||||||
|
code_lines["pytorch"].append(lines[idx])
|
||||||
|
elif is_tensorflow:
|
||||||
|
code_lines["tensorflow"].append(lines[idx])
|
||||||
|
else:
|
||||||
|
code_lines["common"].append(lines[idx])
|
||||||
|
idx += 1
|
||||||
|
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
|
||||||
|
block_lines = ["<frameworkcontent>", "<pt>"]
|
||||||
|
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
|
||||||
|
block_lines.extend(["```", "</pt>", "<tf>"])
|
||||||
|
block_lines.extend(
|
||||||
|
code_lines["common"].copy() + code_lines["tensorflow"]
|
||||||
|
)
|
||||||
|
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
|
||||||
|
new_lines.extend(block_lines)
|
||||||
|
else:
|
||||||
|
block_lines = code_lines["common"] + ["```"]
|
||||||
|
new_lines.extend(block_lines)
|
||||||
|
idx += 1
|
||||||
|
else:
|
||||||
|
new_lines.append(lines[idx])
|
||||||
|
idx += 1
|
||||||
|
return "\n".join(new_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
|
||||||
|
"""
|
||||||
|
Convert a document written in rst to mdx.
|
||||||
|
"""
|
||||||
|
lines = rst_text.split("\n")
|
||||||
|
lines = process_titles(lines)
|
||||||
|
if add_imports:
|
||||||
|
new_lines = [
|
||||||
|
'<script lang="ts">',
|
||||||
|
' import Tip from "$lib/Tip.svelte";',
|
||||||
|
' import Youtube from "$lib/Youtube.svelte";',
|
||||||
|
' import Docstring from "$lib/Docstring.svelte";',
|
||||||
|
' import CodeBlock from "$lib/CodeBlock.svelte";',
|
||||||
|
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
|
||||||
|
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
|
||||||
|
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
|
||||||
|
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
|
||||||
|
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
|
||||||
|
' import Markdown from "$lib/Markdown.svelte";',
|
||||||
|
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
|
||||||
|
' import Added from "$lib/Added.svelte";',
|
||||||
|
' import Changed from "$lib/Changed.svelte";',
|
||||||
|
' import Deprecated from "$lib/Deprecated.svelte";',
|
||||||
|
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
|
||||||
|
' import PipelineTag from "$lib/PipelineTag.svelte";',
|
||||||
|
" ",
|
||||||
|
' export let fw: "pt" | "tf"',
|
||||||
|
"</script>",
|
||||||
|
"<svelte:head>",
|
||||||
|
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
|
||||||
|
"</svelte:head>",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
new_lines = []
|
||||||
|
for line in lines:
|
||||||
|
if _re_ignore_line_table.search(line) is not None:
|
||||||
|
continue
|
||||||
|
elif _re_ignore_line_table1.search(line) is not None:
|
||||||
|
continue
|
||||||
|
elif _re_sep_line_table.search(line) is not None:
|
||||||
|
line = line.replace("=", "-").replace("+", "|")
|
||||||
|
elif _re_anchor_section.search(line) is not None:
|
||||||
|
anchor_name = _re_anchor_section.search(line).groups()[0]
|
||||||
|
line = f"<a id='{anchor_name}'></a>"
|
||||||
|
new_lines.append(line)
|
||||||
|
text = "\n".join(new_lines)
|
||||||
|
|
||||||
|
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))
|
52
src/kernels/_versions.py
Normal file
52
src/kernels/_versions.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.hf_api import GitRefInfo
|
||||||
|
from packaging.specifiers import SpecifierSet
|
||||||
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||||
|
"""Get kernel versions that are available in the repository."""
|
||||||
|
versions = {}
|
||||||
|
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||||
|
if not tag.name.startswith("v"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
versions[Version(tag.name[1:])] = tag
|
||||||
|
except InvalidVersion:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return versions
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
|
||||||
|
"""
|
||||||
|
Get the locks for a kernel with the given version spec.
|
||||||
|
|
||||||
|
The version specifier can be any valid Python version specifier:
|
||||||
|
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||||
|
"""
|
||||||
|
versions = _get_available_versions(repo_id)
|
||||||
|
requirement = SpecifierSet(version_spec)
|
||||||
|
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||||
|
|
||||||
|
if len(accepted_versions) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return versions[accepted_versions[-1]]
|
||||||
|
|
||||||
|
|
||||||
|
def select_revision_or_version(
|
||||||
|
repo_id: str, revision: Optional[str], version: Optional[str]
|
||||||
|
) -> str:
|
||||||
|
if revision is not None and version is not None:
|
||||||
|
raise ValueError("Either a revision or a version must be specified, not both.")
|
||||||
|
elif revision is None and version is None:
|
||||||
|
revision = "main"
|
||||||
|
elif version is not None:
|
||||||
|
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
|
||||||
|
assert revision is not None
|
||||||
|
return revision
|
160
src/kernels/cli.py
Normal file
160
src/kernels/cli.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from kernels.compat import tomllib
|
||||||
|
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||||
|
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||||
|
|
||||||
|
from .doc import generate_readme_for_kernel
|
||||||
|
from .wheel import build_variant_to_wheel
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="kernel", description="Manage compute kernels"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(required=True)
|
||||||
|
|
||||||
|
download_parser = subparsers.add_parser("download", help="Download locked kernels")
|
||||||
|
download_parser.add_argument(
|
||||||
|
"project_dir",
|
||||||
|
type=Path,
|
||||||
|
help="The project directory",
|
||||||
|
)
|
||||||
|
download_parser.add_argument(
|
||||||
|
"--all-variants",
|
||||||
|
action="store_true",
|
||||||
|
help="Download all build variants of the kernel",
|
||||||
|
)
|
||||||
|
download_parser.set_defaults(func=download_kernels)
|
||||||
|
|
||||||
|
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
||||||
|
lock_parser.add_argument(
|
||||||
|
"project_dir",
|
||||||
|
type=Path,
|
||||||
|
help="The project directory",
|
||||||
|
)
|
||||||
|
lock_parser.set_defaults(func=lock_kernels)
|
||||||
|
|
||||||
|
to_wheel_parser = subparsers.add_parser(
|
||||||
|
"to-wheel", help="Convert a kernel to a wheel file"
|
||||||
|
)
|
||||||
|
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
|
||||||
|
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
|
||||||
|
to_wheel_parser.add_argument(
|
||||||
|
"--python-version",
|
||||||
|
type=str,
|
||||||
|
default="3.9",
|
||||||
|
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
|
||||||
|
)
|
||||||
|
to_wheel_parser.add_argument(
|
||||||
|
"--manylinux-version",
|
||||||
|
type=str,
|
||||||
|
default="2.28",
|
||||||
|
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
|
||||||
|
)
|
||||||
|
to_wheel_parser.set_defaults(func=kernels_to_wheel)
|
||||||
|
|
||||||
|
# Add generate-readme subcommand parser
|
||||||
|
generate_readme_parser = subparsers.add_parser(
|
||||||
|
"generate-readme",
|
||||||
|
help="Generate README snippets for a kernel's public functions",
|
||||||
|
)
|
||||||
|
generate_readme_parser.add_argument(
|
||||||
|
"repo_id",
|
||||||
|
type=str,
|
||||||
|
help="The kernel repo ID (e.g., kernels-community/activation)",
|
||||||
|
)
|
||||||
|
generate_readme_parser.add_argument(
|
||||||
|
"--revision",
|
||||||
|
type=str,
|
||||||
|
default="main",
|
||||||
|
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
|
||||||
|
)
|
||||||
|
generate_readme_parser.set_defaults(
|
||||||
|
func=lambda args: generate_readme_for_kernel(
|
||||||
|
repo_id=args.repo_id, revision=args.revision
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
def download_kernels(args):
|
||||||
|
lock_path = args.project_dir / "kernels.lock"
|
||||||
|
|
||||||
|
if not lock_path.exists():
|
||||||
|
print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
with open(args.project_dir / "kernels.lock", "r") as f:
|
||||||
|
lock_json = json.load(f)
|
||||||
|
|
||||||
|
all_successful = True
|
||||||
|
|
||||||
|
for kernel_lock_json in lock_json:
|
||||||
|
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||||
|
print(
|
||||||
|
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
if args.all_variants:
|
||||||
|
install_kernel_all_variants(
|
||||||
|
kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
install_kernel(
|
||||||
|
kernel_lock.repo_id,
|
||||||
|
kernel_lock.sha,
|
||||||
|
variant_locks=kernel_lock.variants,
|
||||||
|
)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
all_successful = False
|
||||||
|
|
||||||
|
if not all_successful:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def kernels_to_wheel(args):
|
||||||
|
variants_path = install_kernel_all_variants(
|
||||||
|
repo_id=args.repo_id, revision=f"v{args.version}"
|
||||||
|
)
|
||||||
|
for variant_path in variants_path.iterdir():
|
||||||
|
if not variant_path.is_dir():
|
||||||
|
continue
|
||||||
|
wheel_path = build_variant_to_wheel(
|
||||||
|
manylinux_version=args.manylinux_version,
|
||||||
|
python_version=args.python_version,
|
||||||
|
repo_id=args.repo_id,
|
||||||
|
version=args.version,
|
||||||
|
variant_path=variant_path,
|
||||||
|
wheel_dir=Path("."),
|
||||||
|
)
|
||||||
|
print(f"☸️ {wheel_path.name}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def lock_kernels(args):
|
||||||
|
with open(args.project_dir / "pyproject.toml", "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
|
||||||
|
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
||||||
|
|
||||||
|
all_locks = []
|
||||||
|
for kernel, version in kernel_versions.items():
|
||||||
|
all_locks.append(get_kernel_locks(kernel, version))
|
||||||
|
|
||||||
|
with open(args.project_dir / "kernels.lock", "w") as f:
|
||||||
|
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class _JSONEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o):
|
||||||
|
if dataclasses.is_dataclass(o):
|
||||||
|
return dataclasses.asdict(o)
|
||||||
|
return super().default(o)
|
242
src/kernels/doc.py
Normal file
242
src/kernels/doc.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
|
||||||
|
from .utils import get_kernel
|
||||||
|
|
||||||
|
_RE_PARAMETERS = re.compile(
|
||||||
|
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
|
||||||
|
)
|
||||||
|
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
|
||||||
|
_RE_RETURNTYPE = re.compile(
|
||||||
|
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_description_before_tags(docstring_mdx: str) -> str:
|
||||||
|
"""Extract the description part of a docstring before any tags."""
|
||||||
|
params_pos = docstring_mdx.find("<parameters>")
|
||||||
|
returns_pos = docstring_mdx.find("<returns>")
|
||||||
|
returntype_pos = docstring_mdx.find("<returntype>")
|
||||||
|
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
|
||||||
|
|
||||||
|
if positions:
|
||||||
|
first_tag_pos = min(positions)
|
||||||
|
return docstring_mdx[:first_tag_pos].strip()
|
||||||
|
else:
|
||||||
|
return docstring_mdx.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
|
||||||
|
"""Print the parameters section from a docstring."""
|
||||||
|
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||||
|
if matches:
|
||||||
|
header = "#" * header_level
|
||||||
|
print(f"\n{header} Parameters")
|
||||||
|
for match in matches:
|
||||||
|
print(f"\n{match[0].strip()}")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_returns_section(
|
||||||
|
docstring_mdx: str, *, context_name: str, header_level: int
|
||||||
|
) -> None:
|
||||||
|
"""Print the returns section from a docstring."""
|
||||||
|
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||||
|
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||||
|
|
||||||
|
if return_matches or returntype_matches:
|
||||||
|
header = "#" * header_level
|
||||||
|
print(f"\n{header} Returns")
|
||||||
|
|
||||||
|
if returntype_matches:
|
||||||
|
if len(returntype_matches) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"More than one <returntype> tag found in docstring for {context_name}"
|
||||||
|
)
|
||||||
|
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
|
||||||
|
|
||||||
|
if return_matches:
|
||||||
|
for match in return_matches:
|
||||||
|
print(f"\n{match[0].strip()}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_docstring(obj, use_dict_check: bool = False) -> str:
|
||||||
|
"""Get docstring from an object, with fallback to default message."""
|
||||||
|
# Check whether the class/method itself has docs and not just
|
||||||
|
# the superclass.
|
||||||
|
if use_dict_check:
|
||||||
|
has_doc = obj.__dict__.get("__doc__", None) is not None
|
||||||
|
else:
|
||||||
|
has_doc = getattr(obj, "__doc__", None) is not None
|
||||||
|
|
||||||
|
# We use inspect.getdoc because it does normalization.
|
||||||
|
doc = inspect.getdoc(obj)
|
||||||
|
|
||||||
|
return doc if has_doc and doc is not None else "No documentation available."
|
||||||
|
|
||||||
|
|
||||||
|
def _process_and_print_docstring(
|
||||||
|
docstring: str, *, kernel_name: str, context_name: str, header_level: int
|
||||||
|
) -> None:
|
||||||
|
"""Convert docstring to MDX and print description, parameters, and returns sections."""
|
||||||
|
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||||
|
docstring, page_info={"package_name": kernel_name}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print the description
|
||||||
|
description = _extract_description_before_tags(docstring_mdx)
|
||||||
|
print(f"\n{description}")
|
||||||
|
|
||||||
|
# Print parameters and returns sections
|
||||||
|
_print_parameters_section(docstring_mdx, header_level=header_level)
|
||||||
|
_print_returns_section(
|
||||||
|
docstring_mdx, context_name=context_name, header_level=header_level
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||||
|
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
|
||||||
|
kernel_name = repo_id.split("/")[-1].replace("-", "_")
|
||||||
|
|
||||||
|
generate_metadata(kernel_module)
|
||||||
|
generate_kernel_doc(kernel_module, kernel_name)
|
||||||
|
generate_function_doc(kernel_module, kernel_name)
|
||||||
|
generate_layers_doc(kernel_module, kernel_name)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_metadata(module: ModuleType) -> None:
|
||||||
|
metadata = getattr(module, "__kernel_metadata__", {})
|
||||||
|
if "tags" not in metadata:
|
||||||
|
metadata["tags"] = ["kernel"]
|
||||||
|
else:
|
||||||
|
if "kernel" not in metadata["tags"]:
|
||||||
|
metadata["tags"].append("kernel")
|
||||||
|
|
||||||
|
print("---")
|
||||||
|
print(yaml.dump(metadata), end="")
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
|
||||||
|
docstring = module.__doc__.strip() if module.__doc__ is not None else None
|
||||||
|
if docstring:
|
||||||
|
title, rest = docstring.split("\n", 1)
|
||||||
|
print(f"# {title.strip()}")
|
||||||
|
print(
|
||||||
|
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||||
|
print("\n## Functions")
|
||||||
|
|
||||||
|
# Track if we found any functions
|
||||||
|
found_functions = False
|
||||||
|
|
||||||
|
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||||
|
# Do not include imported functions.
|
||||||
|
if func.__module__ != kernel_module.__name__:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Exclude private functions.
|
||||||
|
if name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
found_functions = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
docstring = _get_docstring(func)
|
||||||
|
except ValueError:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n### Function `{name}`")
|
||||||
|
print(f"\n`{sig}`")
|
||||||
|
|
||||||
|
_process_and_print_docstring(
|
||||||
|
docstring, kernel_name=kernel_name, context_name=name, header_level=3
|
||||||
|
)
|
||||||
|
|
||||||
|
if not found_functions:
|
||||||
|
print("\nNo public top-level functions.")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||||
|
# Check if layers module is available
|
||||||
|
layers_module = getattr(kernel_module, "layers", None)
|
||||||
|
if layers_module is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n## Layers")
|
||||||
|
|
||||||
|
# Track if we found any classes
|
||||||
|
found_classes = False
|
||||||
|
|
||||||
|
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
|
||||||
|
# Exclude classes that were imported.
|
||||||
|
if cls.__module__ != layers_module.__name__:
|
||||||
|
continue
|
||||||
|
|
||||||
|
found_classes = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get docstring, but not from superclasses.
|
||||||
|
class_docstring = _get_docstring(cls, use_dict_check=True)
|
||||||
|
except Exception:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n### Class `{class_name}`")
|
||||||
|
|
||||||
|
# Always print class description (helper handles conversion and formatting)
|
||||||
|
class_docstring_mdx = convert_rst_docstring_to_mdx(
|
||||||
|
class_docstring, page_info={"package_name": kernel_name}
|
||||||
|
)
|
||||||
|
description = _extract_description_before_tags(class_docstring_mdx)
|
||||||
|
print(f"\n{description}")
|
||||||
|
|
||||||
|
# Document methods
|
||||||
|
print("\n#### Methods")
|
||||||
|
|
||||||
|
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||||
|
# Note: also skip __init__, since extension layers cannot have a constructor.
|
||||||
|
if method_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip methods from superclasses.
|
||||||
|
if method_name not in cls.__dict__:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
method_docstring = _get_docstring(method)
|
||||||
|
except ValueError:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n##### Method `{method_name}`")
|
||||||
|
print(f"\n`{sig}`")
|
||||||
|
|
||||||
|
_process_and_print_docstring(
|
||||||
|
method_docstring,
|
||||||
|
kernel_name=kernel_name,
|
||||||
|
context_name=method_name,
|
||||||
|
header_level=6,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not found_classes:
|
||||||
|
print("\nNo layers defined.")
|
758
src/kernels/layer.py
Normal file
758
src/kernels/layer.py
Normal file
@ -0,0 +1,758 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Flag, auto
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from types import MethodType
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ._interval_tree import IntervalTree
|
||||||
|
from ._versions import select_revision_or_version
|
||||||
|
from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
||||||
|
|
||||||
|
|
||||||
|
class Mode(Flag):
|
||||||
|
"""
|
||||||
|
Kernelize mode
|
||||||
|
|
||||||
|
The `Mode` flag is used by `kernelize` to select kernels for the given
|
||||||
|
mode. Mappings can be registered for specific modes.
|
||||||
|
|
||||||
|
* `INFERENCE`: The kernel is used for inference.
|
||||||
|
* `TRAINING`: The kernel is used for training.
|
||||||
|
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
|
||||||
|
* `FALLBACK`: In a kernel mapping, this kernel is used when no other mode
|
||||||
|
matches.
|
||||||
|
|
||||||
|
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
|
||||||
|
should be used for layers that are used for inference *with* `torch.compile`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_NONE = 0
|
||||||
|
FALLBACK = auto()
|
||||||
|
TRAINING = auto()
|
||||||
|
INFERENCE = auto()
|
||||||
|
TORCH_COMPILE = auto()
|
||||||
|
|
||||||
|
def __or__(self, other: Mode) -> Mode:
|
||||||
|
union = super().__or__(other)
|
||||||
|
|
||||||
|
if Mode.INFERENCE in union and Mode.TRAINING in union:
|
||||||
|
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
|
||||||
|
|
||||||
|
if Mode.FALLBACK in union and union != Mode.FALLBACK:
|
||||||
|
raise ValueError("Mode.FALLBACK cannot be combined with other modes.")
|
||||||
|
|
||||||
|
return union
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Device:
|
||||||
|
type: str
|
||||||
|
properties: Optional[CUDAProperties] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.properties is not None and isinstance(self.properties, CUDAProperties):
|
||||||
|
if self.type != "cuda":
|
||||||
|
raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
|
||||||
|
|
||||||
|
def create_repo(self) -> _DeviceRepos:
|
||||||
|
"""Create an appropriate repository set for this device type."""
|
||||||
|
if self.type == "cuda":
|
||||||
|
return _CUDARepos()
|
||||||
|
elif self.type == "mps":
|
||||||
|
return _MPSRepos()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown device type: {self.type}")
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Device):
|
||||||
|
return NotImplemented
|
||||||
|
return self.type == other.type and self.properties == other.properties
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.type, self.properties))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CUDAProperties:
|
||||||
|
min_capability: int
|
||||||
|
max_capability: int
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, CUDAProperties):
|
||||||
|
return NotImplemented
|
||||||
|
return (
|
||||||
|
self.min_capability == other.min_capability
|
||||||
|
and self.max_capability == other.max_capability
|
||||||
|
)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.min_capability, self.max_capability))
|
||||||
|
|
||||||
|
|
||||||
|
class LayerRepositoryProtocol(Protocol):
|
||||||
|
@property
|
||||||
|
def layer_name(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repo_id(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def revision(self) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
|
class LayerRepository:
|
||||||
|
"""
|
||||||
|
Repository and name of a layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
repo_id: str,
|
||||||
|
*,
|
||||||
|
layer_name: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
version: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a layer repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`): The Hub repository containing the layer.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||||
|
revision (branch, tag, or commit) to download.
|
||||||
|
Cannot be used together with `version`.
|
||||||
|
version (`str`, *optional*): The kernel version to download. This
|
||||||
|
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||||
|
Cannot be used together with `revision`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if revision is not None and version is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either a revision or a version must be specified, not both."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.repo_id = repo_id
|
||||||
|
self.layer_name = layer_name
|
||||||
|
|
||||||
|
# We are going to resolve these lazily, since we do not want
|
||||||
|
# to do a network request for every registered LayerRepository.
|
||||||
|
self._revision = revision
|
||||||
|
self._version = version
|
||||||
|
|
||||||
|
@property
|
||||||
|
@functools.lru_cache()
|
||||||
|
def revision(self) -> str:
|
||||||
|
return select_revision_or_version(
|
||||||
|
repo_id=self.repo_id, revision=self._revision, version=self._version
|
||||||
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (
|
||||||
|
isinstance(other, LayerRepository)
|
||||||
|
and self.layer_name == other.layer_name
|
||||||
|
and self.repo_id == other.repo_id
|
||||||
|
and self._revision == other._revision
|
||||||
|
and self._version == other._version
|
||||||
|
)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.layer_name, self.repo_id, self._revision, self._version))
|
||||||
|
|
||||||
|
|
||||||
|
class LockedLayerRepository:
|
||||||
|
"""
|
||||||
|
Repository and name of a layer.
|
||||||
|
|
||||||
|
In contrast to `LayerRepository`, this class uses repositories that
|
||||||
|
are locked inside a project.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
repo_id: str,
|
||||||
|
*,
|
||||||
|
lockfile: Optional[Path] = None,
|
||||||
|
layer_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a layer repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`): The Hub repository containing the layer.
|
||||||
|
"""
|
||||||
|
self.repo_id = repo_id
|
||||||
|
self.lockfile = lockfile
|
||||||
|
self.layer_name = layer_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
@functools.lru_cache()
|
||||||
|
def revision(self) -> str:
|
||||||
|
if self.lockfile is None:
|
||||||
|
locked_sha = _get_caller_locked_kernel(self.repo_id)
|
||||||
|
else:
|
||||||
|
with open(self.lockfile, "r") as f:
|
||||||
|
locked_sha = _get_locked_kernel(self.repo_id, f.read())
|
||||||
|
|
||||||
|
if locked_sha is None:
|
||||||
|
raise ValueError(f"Kernel `{self.repo_id}` is not locked")
|
||||||
|
|
||||||
|
return locked_sha
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (
|
||||||
|
isinstance(other, LockedLayerRepository)
|
||||||
|
and self.layer_name == other.layer_name
|
||||||
|
and self.repo_id == other.repo_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.layer_name, self.repo_id))
|
||||||
|
|
||||||
|
|
||||||
|
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class _DeviceRepos(ABC):
|
||||||
|
"""
|
||||||
|
Device-specific kernel layer repositories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def repos(
|
||||||
|
self,
|
||||||
|
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||||
|
"""
|
||||||
|
Insert a repository for a specific device and mode.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _MPSRepos(_DeviceRepos):
|
||||||
|
_repos: Dict[Mode, LayerRepositoryProtocol]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._repos = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repos(
|
||||||
|
self,
|
||||||
|
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
||||||
|
return self._repos
|
||||||
|
|
||||||
|
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||||
|
if device.type != "mps":
|
||||||
|
raise ValueError(f"Device type must be 'mps', got {device.type}")
|
||||||
|
|
||||||
|
self._repos = repos
|
||||||
|
|
||||||
|
|
||||||
|
class _CUDARepos(_DeviceRepos):
|
||||||
|
_repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.repos_by_capability = IntervalTree()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repos(
|
||||||
|
self,
|
||||||
|
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
||||||
|
capability = _find_capability()
|
||||||
|
return self.repos_by_capability.find_smallest_interval(capability)
|
||||||
|
|
||||||
|
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||||
|
assert device.properties is None or isinstance(
|
||||||
|
device.properties, CUDAProperties
|
||||||
|
)
|
||||||
|
|
||||||
|
min_capability = (
|
||||||
|
0 if device.properties is None else device.properties.min_capability
|
||||||
|
)
|
||||||
|
max_capability = (
|
||||||
|
sys.maxsize
|
||||||
|
if device.properties is None
|
||||||
|
else device.properties.max_capability
|
||||||
|
)
|
||||||
|
|
||||||
|
self.repos_by_capability.insert(min_capability, max_capability, repos)
|
||||||
|
|
||||||
|
|
||||||
|
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
|
||||||
|
"_KERNEL_MAPPING", default={}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def use_kernel_mapping(
|
||||||
|
mapping: Dict[
|
||||||
|
str,
|
||||||
|
Dict[
|
||||||
|
Union[Device, str],
|
||||||
|
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
*,
|
||||||
|
inherit_mapping: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Context manager that sets a mapping for a duration of the context.
|
||||||
|
|
||||||
|
When `inherit_mapping` is set to `True` the current mapping will be
|
||||||
|
extended by `mapping` inside the context. If it is `False`, only
|
||||||
|
`mapping` is used inside the context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ContextManager:
|
||||||
|
def __enter__(self):
|
||||||
|
# Mappings always stack on previous mappings.
|
||||||
|
if inherit_mapping:
|
||||||
|
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
|
||||||
|
else:
|
||||||
|
self.token = _KERNEL_MAPPING.set({})
|
||||||
|
register_kernel_mapping(mapping)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
_KERNEL_MAPPING.reset(self.token)
|
||||||
|
|
||||||
|
return ContextManager()
|
||||||
|
|
||||||
|
|
||||||
|
def register_kernel_mapping(
|
||||||
|
mapping: Dict[
|
||||||
|
str,
|
||||||
|
Dict[
|
||||||
|
Union[Device, str],
|
||||||
|
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Allows one to register a mapping between a layer name and the corresponding
|
||||||
|
kernel(s) to use, depending on the device. This should be used in conjunction
|
||||||
|
with `kernelize`.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import LayerRepository, register_kernel_mapping
|
||||||
|
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"LlamaRMSNorm": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="RmsNorm",
|
||||||
|
revision="layers",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Merge with existing mappings.
|
||||||
|
for new_kernel, new_device_repos in mapping.items():
|
||||||
|
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
|
||||||
|
for new_device, new_repo in new_device_repos.items():
|
||||||
|
device = (
|
||||||
|
Device(type=new_device) if isinstance(new_device, str) else new_device
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(new_repo, dict):
|
||||||
|
kernel_options = new_repo
|
||||||
|
else:
|
||||||
|
kernel_options = {Mode.FALLBACK: new_repo}
|
||||||
|
|
||||||
|
feature_repos = device_repo.setdefault(device.type, device.create_repo())
|
||||||
|
feature_repos.insert(device, kernel_options)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_kernel_forward_from_hub(
|
||||||
|
cls,
|
||||||
|
layer_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
|
||||||
|
|
||||||
|
This decorator stores the layer name and original forward method, which will be used
|
||||||
|
by the kernelize function to replace the forward implementation with the appropriate
|
||||||
|
kernel from the hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The layer class to decorate
|
||||||
|
layer_name: The name of the layer to use for kernel lookup
|
||||||
|
"""
|
||||||
|
cls.kernel_layer_name = layer_name
|
||||||
|
|
||||||
|
|
||||||
|
_MODE_FALLBACK_PRIORITY = {
|
||||||
|
Mode.INFERENCE: [
|
||||||
|
Mode.INFERENCE,
|
||||||
|
Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||||
|
Mode.TRAINING,
|
||||||
|
Mode.TRAINING | Mode.TORCH_COMPILE,
|
||||||
|
Mode.FALLBACK,
|
||||||
|
],
|
||||||
|
Mode.TRAINING: [
|
||||||
|
Mode.TRAINING,
|
||||||
|
Mode.TRAINING | Mode.TORCH_COMPILE,
|
||||||
|
Mode.FALLBACK,
|
||||||
|
],
|
||||||
|
Mode.INFERENCE
|
||||||
|
| Mode.TORCH_COMPILE: [
|
||||||
|
Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||||
|
Mode.TRAINING | Mode.TORCH_COMPILE,
|
||||||
|
Mode.FALLBACK,
|
||||||
|
],
|
||||||
|
Mode.TRAINING
|
||||||
|
| Mode.TORCH_COMPILE: [
|
||||||
|
Mode.TRAINING | Mode.TORCH_COMPILE,
|
||||||
|
Mode.FALLBACK,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _select_repository(
|
||||||
|
repositories: Dict[Mode, LayerRepositoryProtocol],
|
||||||
|
*,
|
||||||
|
mode: Mode,
|
||||||
|
) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
|
||||||
|
# Get the fallback priority list for the requested mode
|
||||||
|
if mode not in _MODE_FALLBACK_PRIORITY:
|
||||||
|
raise ValueError(f"Unsupported mode: {mode}")
|
||||||
|
|
||||||
|
fallback_modes = _MODE_FALLBACK_PRIORITY[mode]
|
||||||
|
|
||||||
|
# Try each mode in priority order
|
||||||
|
for fallback_mode in fallback_modes:
|
||||||
|
if fallback_mode in repositories:
|
||||||
|
return (repositories[fallback_mode], fallback_mode)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def kernelize(
|
||||||
|
model: "nn.Module",
|
||||||
|
*,
|
||||||
|
mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
|
||||||
|
device: Optional[Union[str, "torch.device"]] = None,
|
||||||
|
use_fallback: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Iterate over all modules in the model and replace the `forward` method of
|
||||||
|
extensible layers for which kernels are registered using `register_kernel_mapping`
|
||||||
|
or `use_kernel_mapping`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model to kernelize
|
||||||
|
mode: the mode that the kernel is going to be used in (e.g.
|
||||||
|
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
|
||||||
|
and `torch.compile`).
|
||||||
|
device: The device type to load kernels for. The device type will be inferred
|
||||||
|
from the parameters of the model when not provided.
|
||||||
|
use_fallback: Whether to use the original forward method of modules when no
|
||||||
|
compatible kernel could be found. If set to `False`, an exception will
|
||||||
|
be raised in such cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The kernelized model
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if mode == Mode.FALLBACK:
|
||||||
|
raise ValueError("Mode.FALLBACK can only be used to register kernel mappings.")
|
||||||
|
|
||||||
|
# Type check ignored because this causes a false negative on Python < 3.11.
|
||||||
|
# Looks similar to: https://github.com/python/mypy/issues/9642
|
||||||
|
# Remove once we start doing typing checks on >= 3.11.
|
||||||
|
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
|
||||||
|
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device_type = _find_device(model)
|
||||||
|
elif isinstance(device, str):
|
||||||
|
device_type = Device(type=torch.device(device).type)
|
||||||
|
else:
|
||||||
|
device_type = Device(device.type)
|
||||||
|
|
||||||
|
assert isinstance(device_type, Device)
|
||||||
|
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
module_class = type(module)
|
||||||
|
if not hasattr(module_class, "kernel_layer_name"):
|
||||||
|
continue
|
||||||
|
layer_name = module_class.kernel_layer_name
|
||||||
|
|
||||||
|
if _DISABLE_KERNEL_MAPPING:
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
continue
|
||||||
|
|
||||||
|
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
|
||||||
|
|
||||||
|
if kernel is None:
|
||||||
|
warnings.warn(
|
||||||
|
"\n"
|
||||||
|
f"No kernel mapping found for layer `{layer_name}`. "
|
||||||
|
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
|
||||||
|
f"you want to use to the mapping. Defaulting to original forward implementation."
|
||||||
|
)
|
||||||
|
if not use_fallback:
|
||||||
|
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get kernel options for the device
|
||||||
|
property_repos = kernel.get(device_type.type)
|
||||||
|
|
||||||
|
if property_repos is None:
|
||||||
|
if not use_fallback:
|
||||||
|
raise ValueError(
|
||||||
|
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
|
||||||
|
)
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
continue
|
||||||
|
|
||||||
|
repos = property_repos.repos
|
||||||
|
|
||||||
|
if repos is None:
|
||||||
|
if not use_fallback:
|
||||||
|
raise ValueError(
|
||||||
|
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
|
||||||
|
)
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
continue
|
||||||
|
|
||||||
|
repo_with_mode = _select_repository(
|
||||||
|
repos,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
if repo_with_mode is None:
|
||||||
|
if not use_fallback:
|
||||||
|
raise ValueError(
|
||||||
|
f"No repository for `{layer_name}` for configuration mode={mode}"
|
||||||
|
)
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
continue
|
||||||
|
|
||||||
|
repo, repo_mode = repo_with_mode
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Using layer `{repo.layer_name}` from repo `{repo.repo_id}` (revision: {repo.revision}) for layer `{layer_name}`"
|
||||||
|
)
|
||||||
|
logging.debug(f"kernelize mode: {mode}, repo mode: {repo_mode}")
|
||||||
|
|
||||||
|
layer = _get_layer_memoize(repo, module_class)
|
||||||
|
|
||||||
|
# Ideally we would do validation on the mapping where we check that
|
||||||
|
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
|
||||||
|
# the actual layer is compatible with that. Unfortunately, this would
|
||||||
|
# mean that we have to pre-download everything.
|
||||||
|
_validate_layer_has_mode(
|
||||||
|
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
_conditionally_replace_forward(
|
||||||
|
module=module,
|
||||||
|
layer=layer,
|
||||||
|
mode=mode,
|
||||||
|
use_fallback=use_fallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def use_kernel_forward_from_hub(layer_name: str):
|
||||||
|
"""
|
||||||
|
Make a layer extensible using the name `layer_name`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(cls):
|
||||||
|
replace_kernel_forward_from_hub(cls, layer_name)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _get_kernel_layer(
|
||||||
|
*, repo_id: str, layer_name: str, revision: str
|
||||||
|
) -> Type["nn.Module"]:
|
||||||
|
"""Get a layer from a kernel."""
|
||||||
|
|
||||||
|
kernel = get_kernel(repo_id, revision=revision)
|
||||||
|
|
||||||
|
if getattr(kernel, "layers", None) is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
|
||||||
|
)
|
||||||
|
|
||||||
|
layer = getattr(kernel.layers, layer_name, None)
|
||||||
|
if layer is None:
|
||||||
|
raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_layer(*, check_cls, cls):
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# The layer must have at least have the following properties: (1) it
|
||||||
|
# must be stateless; (2) the forward signature should correspond to
|
||||||
|
# the signature it is replacing; (3) forward should not call other
|
||||||
|
# methods.
|
||||||
|
|
||||||
|
if not issubclass(cls, nn.Module):
|
||||||
|
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||||
|
|
||||||
|
# We verify statelessness by checking that the does not have its own
|
||||||
|
# constructor (since the constructor could add member variables)...
|
||||||
|
if cls.__init__ is not nn.Module.__init__:
|
||||||
|
raise TypeError("Layer must not override nn.Module constructor.")
|
||||||
|
|
||||||
|
# ... or predefined member variables.
|
||||||
|
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||||
|
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||||
|
difference = cls_members - torch_module_members
|
||||||
|
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
||||||
|
if not difference <= {"can_torch_compile", "has_backward"}:
|
||||||
|
raise TypeError("Layer must not contain additional members.")
|
||||||
|
|
||||||
|
# Check whether the forward signatures are similar.
|
||||||
|
params = inspect.signature(cls.forward).parameters
|
||||||
|
ref_params = inspect.signature(check_cls.forward).parameters
|
||||||
|
|
||||||
|
if len(params) != len(ref_params):
|
||||||
|
raise TypeError(
|
||||||
|
"Forward signature does not match: different number of arguments."
|
||||||
|
)
|
||||||
|
|
||||||
|
for param, ref_param in zip(params.values(), ref_params.values()):
|
||||||
|
if param.kind != ref_param.kind:
|
||||||
|
raise TypeError(
|
||||||
|
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_device(model: "nn.Module") -> Device:
|
||||||
|
try:
|
||||||
|
param = next(model.parameters())
|
||||||
|
except StopIteration:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot determine model device, provide as `device` argument to `kernelize`."
|
||||||
|
)
|
||||||
|
|
||||||
|
return Device(type=param.device.type)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _find_capability() -> int:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability(device=None)
|
||||||
|
return major * 10 + minor
|
||||||
|
|
||||||
|
|
||||||
|
def _conditionally_replace_forward(
|
||||||
|
*,
|
||||||
|
module: "nn.Module",
|
||||||
|
layer: Type["nn.Module"],
|
||||||
|
mode: Mode,
|
||||||
|
use_fallback: bool,
|
||||||
|
):
|
||||||
|
module_class = type(module)
|
||||||
|
|
||||||
|
# Switch to fallback if the mode is not supported by the layer.
|
||||||
|
# Note that this is useful even after _validate_layer_has_mode because
|
||||||
|
# layers registered with the FALLBACK mode never get rejected by
|
||||||
|
# _validate_layer_has_mode. For such layers, we want to fall back in
|
||||||
|
# case the layer does not support the given mode.
|
||||||
|
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
|
||||||
|
layer, "can_torch_compile", False
|
||||||
|
)
|
||||||
|
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
|
||||||
|
|
||||||
|
if needs_fallback:
|
||||||
|
if use_fallback:
|
||||||
|
_replace_forward(module, module_class)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Available kernel does not support mode: {mode}")
|
||||||
|
else:
|
||||||
|
_replace_forward(module, layer)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
|
||||||
|
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_layer_has_mode(
|
||||||
|
*,
|
||||||
|
layer_name: str,
|
||||||
|
module: Type["nn.Module"],
|
||||||
|
repo: LayerRepositoryProtocol,
|
||||||
|
repo_mode: Mode,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check that a repository supports the mode that it was registered for.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
|
||||||
|
raise ValueError(
|
||||||
|
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
|
||||||
|
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if Mode.TORCH_COMPILE in repo_mode and not getattr(
|
||||||
|
module, "can_torch_compile", False
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
|
||||||
|
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_layer_memoize(
|
||||||
|
repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
|
||||||
|
) -> Type["nn.Module"]:
|
||||||
|
layer = _CACHED_LAYER.get(repo, None)
|
||||||
|
if layer is not None:
|
||||||
|
return layer
|
||||||
|
|
||||||
|
layer = _get_kernel_layer(
|
||||||
|
repo_id=repo.repo_id,
|
||||||
|
layer_name=repo.layer_name,
|
||||||
|
revision=repo.revision,
|
||||||
|
)
|
||||||
|
_validate_layer(check_cls=module_class, cls=layer)
|
||||||
|
_CACHED_LAYER[repo] = layer
|
||||||
|
|
||||||
|
return layer
|
@ -1,63 +1,42 @@
|
|||||||
|
import hashlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from packaging.specifiers import SpecifierSet
|
|
||||||
from packaging.version import InvalidVersion, Version
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
from kernels._versions import resolve_version_spec_as_ref
|
||||||
|
from kernels.compat import tomllib
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileLock:
|
class VariantLock:
|
||||||
filename: str
|
hash: str
|
||||||
blob_id: str
|
hash_type: str = "git_lfs_concat"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KernelLock:
|
class KernelLock:
|
||||||
repo_id: str
|
repo_id: str
|
||||||
sha: str
|
sha: str
|
||||||
files: List[FileLock]
|
variants: Dict[str, VariantLock]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, o: Dict):
|
def from_json(cls, o: Dict):
|
||||||
files = [FileLock(**f) for f in o["files"]]
|
variants = {
|
||||||
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
|
variant: VariantLock(**lock) for variant, lock in o["variants"].items()
|
||||||
|
}
|
||||||
|
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||||
|
|
||||||
|
|
||||||
def _get_available_versions(repo_id: str):
|
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||||
"""Get kernel versions that are available in the repository."""
|
|
||||||
versions = {}
|
|
||||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
|
||||||
if not tag.name.startswith("v"):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
versions[Version(tag.name[1:])] = tag
|
|
||||||
except InvalidVersion:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return versions
|
|
||||||
|
|
||||||
|
|
||||||
def get_kernel_locks(repo_id: str, version_spec: str):
|
|
||||||
"""
|
"""
|
||||||
Get the locks for a kernel with the given version spec.
|
Get the locks for a kernel with the given version spec.
|
||||||
|
|
||||||
The version specifier can be any valid Python version specifier:
|
The version specifier can be any valid Python version specifier:
|
||||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||||
"""
|
"""
|
||||||
versions = _get_available_versions(repo_id)
|
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
|
||||||
requirement = SpecifierSet(version_spec)
|
|
||||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
|
||||||
|
|
||||||
if len(accepted_versions) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tag_for_newest = versions[accepted_versions[-1]]
|
|
||||||
|
|
||||||
r = HfApi().repo_info(
|
r = HfApi().repo_info(
|
||||||
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
||||||
@ -72,17 +51,36 @@ def get_kernel_locks(repo_id: str, version_spec: str):
|
|||||||
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
file_locks = []
|
variant_files: Dict[str, List[Tuple[bytes, str]]] = {}
|
||||||
for sibling in r.siblings:
|
for sibling in r.siblings:
|
||||||
if sibling.rfilename.startswith("build/torch"):
|
if sibling.rfilename.startswith("build/torch"):
|
||||||
if sibling.blob_id is None:
|
if sibling.blob_id is None:
|
||||||
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
|
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
|
||||||
|
|
||||||
file_locks.append(
|
path = Path(sibling.rfilename)
|
||||||
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
|
variant = path.parts[1]
|
||||||
)
|
filename = Path(*path.parts[2:])
|
||||||
|
|
||||||
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
|
hash = sibling.lfs.sha256 if sibling.lfs is not None else sibling.blob_id
|
||||||
|
|
||||||
|
files = variant_files.setdefault(variant, [])
|
||||||
|
|
||||||
|
# Encode as posix for consistent slash handling, then encode
|
||||||
|
# as utf-8 for byte-wise sorting later.
|
||||||
|
files.append((filename.as_posix().encode("utf-8"), hash))
|
||||||
|
|
||||||
|
variant_locks = {}
|
||||||
|
for variant, files in variant_files.items():
|
||||||
|
m = hashlib.sha256()
|
||||||
|
for filename_bytes, hash in sorted(files):
|
||||||
|
# Filename as bytes.
|
||||||
|
m.update(filename_bytes)
|
||||||
|
# Git blob or LFS file hash as bytes.
|
||||||
|
m.update(bytes.fromhex(hash))
|
||||||
|
|
||||||
|
variant_locks[variant] = VariantLock(hash=f"sha256-{m.hexdigest()}")
|
||||||
|
|
||||||
|
return KernelLock(repo_id=repo_id, sha=r.sha, variants=variant_locks)
|
||||||
|
|
||||||
|
|
||||||
def write_egg_lockfile(cmd, basename, filename):
|
def write_egg_lockfile(cmd, basename, filename):
|
||||||
@ -101,7 +99,7 @@ def write_egg_lockfile(cmd, basename, filename):
|
|||||||
if kernel_versions is None:
|
if kernel_versions is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
lock_path = cwd / "hf-kernels.lock"
|
lock_path = cwd / "kernels.lock"
|
||||||
if not lock_path.exists():
|
if not lock_path.exists():
|
||||||
logging.warning(f"Lock file {lock_path} does not exist")
|
logging.warning(f"Lock file {lock_path} does not exist")
|
||||||
# Ensure that the file gets deleted in editable installs.
|
# Ensure that the file gets deleted in editable installs.
|
424
src/kernels/utils.py
Normal file
424
src/kernels/utils.py
Normal file
@ -0,0 +1,424 @@
|
|||||||
|
import ctypes
|
||||||
|
import hashlib
|
||||||
|
import importlib
|
||||||
|
import importlib.metadata
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
from importlib.metadata import Distribution
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from huggingface_hub import file_exists, snapshot_download
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
|
from kernels._versions import select_revision_or_version
|
||||||
|
from kernels.lockfile import KernelLock, VariantLock
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_dir() -> Optional[str]:
|
||||||
|
"""Returns the kernels cache directory."""
|
||||||
|
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
|
||||||
|
if cache_dir is not None:
|
||||||
|
logging.warning(
|
||||||
|
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
|
||||||
|
)
|
||||||
|
return cache_dir
|
||||||
|
|
||||||
|
return os.environ.get("KERNELS_CACHE", None)
|
||||||
|
|
||||||
|
|
||||||
|
CACHE_DIR: Optional[str] = _get_cache_dir()
|
||||||
|
|
||||||
|
|
||||||
|
def build_variant() -> str:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.version.cuda is not None:
|
||||||
|
cuda_version = parse(torch.version.cuda)
|
||||||
|
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
|
||||||
|
elif torch.version.hip is not None:
|
||||||
|
rocm_version = parse(torch.version.hip.split("-")[0])
|
||||||
|
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
compute_framework = "metal"
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
compute_framework = "xpu"
|
||||||
|
else:
|
||||||
|
raise AssertionError(
|
||||||
|
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_version = parse(torch.__version__)
|
||||||
|
cpu = platform.machine()
|
||||||
|
os = platform.system().lower()
|
||||||
|
|
||||||
|
if os == "darwin":
|
||||||
|
cpu = "aarch64" if cpu == "arm64" else cpu
|
||||||
|
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
|
||||||
|
|
||||||
|
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||||
|
|
||||||
|
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
|
||||||
|
|
||||||
|
|
||||||
|
def universal_build_variant() -> str:
|
||||||
|
# Once we support other frameworks, detection goes here.
|
||||||
|
return "torch-universal"
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
|
||||||
|
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
||||||
|
# it would also be used for other imports. So, we make a module name that
|
||||||
|
# depends on the path for it to be unique using the hex-encoded hash of
|
||||||
|
# the path.
|
||||||
|
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
||||||
|
module_name = f"{module_name}_{path_hash}"
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
if module is None:
|
||||||
|
raise ImportError(f"Cannot load module {module_name} from spec")
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module) # type: ignore
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def install_kernel(
|
||||||
|
repo_id: str,
|
||||||
|
revision: str,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||||
|
) -> Tuple[str, Path]:
|
||||||
|
"""
|
||||||
|
Download a kernel for the current environment to the cache.
|
||||||
|
|
||||||
|
The output path is validated againt `hash` when set.
|
||||||
|
"""
|
||||||
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _load_kernel_from_path(repo_path, package_name, variant_locks)
|
||||||
|
except FileNotFoundError:
|
||||||
|
# Redo with more specific error message.
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_kernel_from_path(
|
||||||
|
repo_path: Path,
|
||||||
|
package_name: str,
|
||||||
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||||
|
) -> Tuple[str, Path]:
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
universal_variant_path = repo_path / "build" / universal_variant
|
||||||
|
|
||||||
|
if not variant_path.exists() and universal_variant_path.exists():
|
||||||
|
# Fall back to universal variant.
|
||||||
|
variant = universal_variant
|
||||||
|
variant_path = universal_variant_path
|
||||||
|
|
||||||
|
if variant_locks is not None:
|
||||||
|
variant_lock = variant_locks.get(variant)
|
||||||
|
if variant_lock is None:
|
||||||
|
raise ValueError(f"No lock found for build variant: {variant}")
|
||||||
|
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
|
||||||
|
|
||||||
|
module_init_path = variant_path / package_name / "__init__.py"
|
||||||
|
|
||||||
|
if not os.path.exists(module_init_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Kernel at path `{repo_path}` does not have build: {variant}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return package_name, variant_path
|
||||||
|
|
||||||
|
|
||||||
|
def install_kernel_all_variants(
|
||||||
|
repo_id: str,
|
||||||
|
revision: str,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||||
|
) -> Path:
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns="build/*",
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if variant_locks is not None:
|
||||||
|
for entry in (repo_path / "build").iterdir():
|
||||||
|
variant = entry.parts[-1]
|
||||||
|
|
||||||
|
variant_lock = variant_locks.get(variant)
|
||||||
|
if variant_lock is None:
|
||||||
|
raise ValueError(f"No lock found for build variant: {variant}")
|
||||||
|
|
||||||
|
validate_kernel(
|
||||||
|
repo_path=repo_path, variant=variant, hash=variant_lock.hash
|
||||||
|
)
|
||||||
|
|
||||||
|
return repo_path / "build"
|
||||||
|
|
||||||
|
|
||||||
|
def get_kernel(
|
||||||
|
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||||
|
) -> ModuleType:
|
||||||
|
"""
|
||||||
|
Load a kernel from the kernel hub.
|
||||||
|
This function downloads a kernel to the local Hugging Face Hub cache
|
||||||
|
directory (if it was not downloaded before) and then loads the kernel.
|
||||||
|
Args:
|
||||||
|
repo_id (`str`): The Hub repository containing the kernel.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||||
|
revision (branch, tag, or commit) to download.
|
||||||
|
Cannot be used together with `version`.
|
||||||
|
version (`str`, *optional*): The kernel version to download. This
|
||||||
|
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||||
|
Cannot be used together with `revision`.
|
||||||
|
Returns:
|
||||||
|
`ModuleType`: The imported kernel module.
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from kernels import get_kernel
|
||||||
|
kernel = get_kernel("username/my-kernel")
|
||||||
|
result = kernel.kernel_function(input_data)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
revision = select_revision_or_version(repo_id, revision, version)
|
||||||
|
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||||
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||||
|
"""
|
||||||
|
Import a kernel from a local kernel repository path.
|
||||||
|
"""
|
||||||
|
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
|
||||||
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def has_kernel(
|
||||||
|
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether a kernel build exists for the current environment
|
||||||
|
(Torch version and compute framework).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`): The Hub repository containing the kernel.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||||
|
revision (branch, tag, or commit) to download.
|
||||||
|
Cannot be used together with `version`.
|
||||||
|
version (`str`, *optional*): The kernel version to download. This
|
||||||
|
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||||
|
Cannot be used together with `revision`.
|
||||||
|
Returns:
|
||||||
|
`bool`: `true` if a kernel is avaialble for the current environment.
|
||||||
|
"""
|
||||||
|
revision = select_revision_or_version(repo_id, revision, version)
|
||||||
|
|
||||||
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
|
||||||
|
if file_exists(
|
||||||
|
repo_id,
|
||||||
|
revision=revision,
|
||||||
|
filename=f"build/{universal_variant}/{package_name}/__init__.py",
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return file_exists(
|
||||||
|
repo_id,
|
||||||
|
revision=revision,
|
||||||
|
filename=f"build/{variant}/{package_name}/__init__.py",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||||
|
"""
|
||||||
|
Get a pre-downloaded, locked kernel.
|
||||||
|
|
||||||
|
If `lockfile` is not specified, the lockfile will be loaded from the
|
||||||
|
caller's package metadata.
|
||||||
|
"""
|
||||||
|
if lockfile is None:
|
||||||
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
|
else:
|
||||||
|
with open(lockfile, "r") as f:
|
||||||
|
locked_sha = _get_locked_kernel(repo_id, f.read())
|
||||||
|
|
||||||
|
if locked_sha is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project."
|
||||||
|
)
|
||||||
|
|
||||||
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
|
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
revision=locked_sha,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
universal_variant_path = repo_path / "build" / universal_variant
|
||||||
|
if not variant_path.exists() and universal_variant_path.exists():
|
||||||
|
# Fall back to universal variant.
|
||||||
|
variant = universal_variant
|
||||||
|
variant_path = universal_variant_path
|
||||||
|
|
||||||
|
module_init_path = variant_path / package_name / "__init__.py"
|
||||||
|
if not os.path.exists(module_init_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
|
||||||
|
)
|
||||||
|
|
||||||
|
return import_from_path(package_name, variant_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||||
|
"""Get a kernel using a lock file."""
|
||||||
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
|
|
||||||
|
if locked_sha is None:
|
||||||
|
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
||||||
|
|
||||||
|
package_name, package_path = install_kernel(
|
||||||
|
repo_id, locked_sha, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
|
||||||
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
||||||
|
for dist in _get_caller_distributions():
|
||||||
|
lock_json = dist.read_text("kernels.lock")
|
||||||
|
if lock_json is None:
|
||||||
|
continue
|
||||||
|
locked_sha = _get_locked_kernel(repo_id, lock_json)
|
||||||
|
if locked_sha is not None:
|
||||||
|
return locked_sha
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
|
||||||
|
for kernel_lock_json in json.loads(lock_json):
|
||||||
|
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||||
|
if kernel_lock.repo_id == repo_id:
|
||||||
|
return kernel_lock.sha
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_distributions() -> List[Distribution]:
|
||||||
|
module = _get_caller_module()
|
||||||
|
if module is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Look up all possible distributions that this module could be from.
|
||||||
|
package = module.__name__.split(".")[0]
|
||||||
|
dist_names = importlib.metadata.packages_distributions().get(package)
|
||||||
|
if dist_names is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_module() -> Optional[ModuleType]:
|
||||||
|
stack = inspect.stack()
|
||||||
|
# Get first module in the stack that is not the current module.
|
||||||
|
first_module = inspect.getmodule(stack[0][0])
|
||||||
|
for frame in stack[1:]:
|
||||||
|
module = inspect.getmodule(frame[0])
|
||||||
|
if module is not None and module != first_module:
|
||||||
|
return module
|
||||||
|
return first_module
|
||||||
|
|
||||||
|
|
||||||
|
def validate_kernel(*, repo_path: Path, variant: str, hash: str):
|
||||||
|
"""Validate the given build variant of a kernel against a hasht."""
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
|
||||||
|
# Get the file paths. The first element is a byte-encoded relative path
|
||||||
|
# used for sorting. The second element is the absolute path.
|
||||||
|
files: List[Tuple[bytes, Path]] = []
|
||||||
|
# Ideally we'd use Path.walk, but it's only available in Python 3.12.
|
||||||
|
for dirpath, _, filenames in os.walk(variant_path):
|
||||||
|
for filename in filenames:
|
||||||
|
file_abs = Path(dirpath) / filename
|
||||||
|
|
||||||
|
# Python likes to create files when importing modules from the
|
||||||
|
# cache, only hash files that are symlinked blobs.
|
||||||
|
if file_abs.is_symlink():
|
||||||
|
files.append(
|
||||||
|
(
|
||||||
|
file_abs.relative_to(variant_path).as_posix().encode("utf-8"),
|
||||||
|
file_abs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
for filename_bytes, full_path in sorted(files):
|
||||||
|
m.update(filename_bytes)
|
||||||
|
|
||||||
|
blob_filename = full_path.resolve().name
|
||||||
|
if len(blob_filename) == 40:
|
||||||
|
# SHA-1 hashed, so a Git blob.
|
||||||
|
m.update(git_hash_object(full_path.read_bytes()))
|
||||||
|
elif len(blob_filename) == 64:
|
||||||
|
# SHA-256 hashed, so a Git LFS blob.
|
||||||
|
m.update(hashlib.sha256(full_path.read_bytes()).digest())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected blob filename length: {len(blob_filename)}")
|
||||||
|
|
||||||
|
computedHash = f"sha256-{m.hexdigest()}"
|
||||||
|
if computedHash != hash:
|
||||||
|
raise ValueError(
|
||||||
|
f"Lock file specifies kernel with hash {hash}, but downloaded kernel has hash: {computedHash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def git_hash_object(data: bytes, object_type: str = "blob"):
|
||||||
|
"""Calculate git SHA1 of data."""
|
||||||
|
header = f"{object_type} {len(data)}\0".encode()
|
||||||
|
m = hashlib.sha1()
|
||||||
|
m.update(header)
|
||||||
|
m.update(data)
|
||||||
|
return m.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def package_name_from_repo_id(repo_id: str) -> str:
|
||||||
|
return repo_id.split("/")[-1].replace("-", "_")
|
186
src/kernels/wheel.py
Normal file
186
src/kernels/wheel.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import email.policy
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from email.message import Message
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
KERNELS_VERSION = version("kernels")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
KERNELS_VERSION = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metadata:
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
cuda_version: Optional[str]
|
||||||
|
cxx_abi_version: Optional[str]
|
||||||
|
torch_version: Optional[str]
|
||||||
|
os: Optional[str]
|
||||||
|
platform: Optional[str]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_universal(self) -> bool:
|
||||||
|
return self.platform is None
|
||||||
|
|
||||||
|
|
||||||
|
def build_variant_to_wheel(
|
||||||
|
repo_id: str,
|
||||||
|
*,
|
||||||
|
version: str,
|
||||||
|
variant_path: Path,
|
||||||
|
wheel_dir: Path,
|
||||||
|
manylinux_version: str = "2.28",
|
||||||
|
python_version: str = "3.9",
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Create a wheel file from the variant path.
|
||||||
|
"""
|
||||||
|
name = repo_id.split("/")[-1].replace("_", "-")
|
||||||
|
metadata = extract_metadata(name, version, variant_path)
|
||||||
|
return build_wheel(
|
||||||
|
metadata,
|
||||||
|
variant_path=variant_path,
|
||||||
|
wheel_dir=wheel_dir,
|
||||||
|
manylinux_version=manylinux_version,
|
||||||
|
python_version=python_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
|
||||||
|
"""
|
||||||
|
Extract metadata from the variant path.
|
||||||
|
"""
|
||||||
|
if variant_path.name == "torch-universal":
|
||||||
|
return Metadata(
|
||||||
|
name=name,
|
||||||
|
version=version,
|
||||||
|
cuda_version=None,
|
||||||
|
cxx_abi_version=None,
|
||||||
|
torch_version=None,
|
||||||
|
os=None,
|
||||||
|
platform=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not variant_path.name.startswith("torch"):
|
||||||
|
raise ValueError("Currently only conversion of Torch kernels is supported.")
|
||||||
|
|
||||||
|
variant_parts = variant_path.name.removeprefix("torch").split("-")
|
||||||
|
if len(variant_parts) != 5:
|
||||||
|
raise ValueError(f"Invalid variant name: {variant_path.name}")
|
||||||
|
|
||||||
|
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
|
||||||
|
cpp_abi_version = variant_parts[1].removeprefix("cxx")
|
||||||
|
cuda_version = variant_parts[2].removeprefix("cu")
|
||||||
|
platform = variant_parts[3].replace("-", "_")
|
||||||
|
os = variant_parts[4]
|
||||||
|
|
||||||
|
return Metadata(
|
||||||
|
name=name,
|
||||||
|
version=version,
|
||||||
|
cuda_version=cuda_version,
|
||||||
|
cxx_abi_version=cpp_abi_version,
|
||||||
|
torch_version=torch_version,
|
||||||
|
os=os,
|
||||||
|
platform=platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_wheel(
|
||||||
|
metadata: Metadata,
|
||||||
|
*,
|
||||||
|
variant_path: Path,
|
||||||
|
wheel_dir: Path,
|
||||||
|
manylinux_version: str = "2.28",
|
||||||
|
python_version: str = "3.9",
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Build the wheel file.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from wheel.wheelfile import WheelFile # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
|
||||||
|
)
|
||||||
|
|
||||||
|
name = metadata.name.replace("-", "_")
|
||||||
|
python_version_flat = python_version.replace(".", "")
|
||||||
|
|
||||||
|
if metadata.is_universal:
|
||||||
|
python_tag = f"py{python_version_flat}"
|
||||||
|
abi_tag = "none"
|
||||||
|
platform_tag = "any"
|
||||||
|
wheel_filename = (
|
||||||
|
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
|
||||||
|
)
|
||||||
|
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
|
||||||
|
root_is_purelib = "true"
|
||||||
|
requires_dist_torch = "torch"
|
||||||
|
else:
|
||||||
|
python_tag = f"cp{python_version_flat}"
|
||||||
|
abi_tag = "abi3"
|
||||||
|
|
||||||
|
if (
|
||||||
|
metadata.torch_version is None
|
||||||
|
or metadata.cuda_version is None
|
||||||
|
or metadata.cxx_abi_version is None
|
||||||
|
or metadata.os is None
|
||||||
|
or metadata.platform is None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
|
||||||
|
)
|
||||||
|
|
||||||
|
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
|
||||||
|
|
||||||
|
if metadata.os == "linux":
|
||||||
|
platform_tag = (
|
||||||
|
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
|
||||||
|
|
||||||
|
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
|
||||||
|
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
|
||||||
|
root_is_purelib = "false"
|
||||||
|
requires_dist_torch = f"torch=={metadata.torch_version}.*"
|
||||||
|
|
||||||
|
wheel_path = wheel_dir / wheel_filename
|
||||||
|
|
||||||
|
wheel_msg = Message(email.policy.compat32)
|
||||||
|
wheel_msg.add_header("Wheel-Version", "1.0")
|
||||||
|
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
|
||||||
|
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
|
||||||
|
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
|
||||||
|
|
||||||
|
metadata_msg = Message(email.policy.compat32)
|
||||||
|
metadata_msg.add_header("Metadata-Version", "2.1")
|
||||||
|
metadata_msg.add_header("Name", name)
|
||||||
|
metadata_msg.add_header("Version", metadata.version)
|
||||||
|
metadata_msg.add_header("Summary", f"{name} kernel")
|
||||||
|
metadata_msg.add_header("Requires-Python", ">=3.9")
|
||||||
|
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
|
||||||
|
|
||||||
|
source_pkg_dir = variant_path / name
|
||||||
|
|
||||||
|
with WheelFile(wheel_path, "w") as wheel_file:
|
||||||
|
for root, dirnames, filenames in os.walk(source_pkg_dir):
|
||||||
|
for filename in filenames:
|
||||||
|
if filename.endswith(".pyc"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
abs_filepath = os.path.join(root, filename)
|
||||||
|
entry_name = os.path.relpath(abs_filepath, variant_path)
|
||||||
|
wheel_file.write(abs_filepath, entry_name)
|
||||||
|
|
||||||
|
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
|
||||||
|
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
|
||||||
|
|
||||||
|
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
|
||||||
|
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
|
||||||
|
|
||||||
|
return wheel_path
|
10
tests/conftest.py
Normal file
10
tests/conftest.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_setup(item):
|
||||||
|
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
|
||||||
|
pytest.skip("skipping Linux-only test on non-Linux platform")
|
||||||
|
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
||||||
|
pytest.skip("skipping macOS-only test on non-macOS platform")
|
94
tests/kernel_locking/kernels.lock
Normal file
94
tests/kernel_locking/kernels.lock
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"repo_id": "kernels-community/activation",
|
||||||
|
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
|
||||||
|
"variants": {
|
||||||
|
"torch25-cxx11-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx11-cu121-x86_64-linux": {
|
||||||
|
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx11-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu121-x86_64-linux": {
|
||||||
|
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu126-aarch64-linux": {
|
||||||
|
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu126-x86_64-linux": {
|
||||||
|
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu126-aarch64-linux": {
|
||||||
|
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu126-x86_64-linux": {
|
||||||
|
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch27-cxx11-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch27-cxx11-cu126-aarch64-linux": {
|
||||||
|
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch27-cxx11-cu126-x86_64-linux": {
|
||||||
|
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch27-cxx11-cu128-aarch64-linux": {
|
||||||
|
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch27-cxx11-cu128-x86_64-linux": {
|
||||||
|
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"repo_id": "kernels-community/triton-scaled-mm",
|
||||||
|
"sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
|
||||||
|
"variants": {
|
||||||
|
"torch-universal": {
|
||||||
|
"hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
3
tests/kernel_locking/pyproject.toml
Normal file
3
tests/kernel_locking/pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[tool.kernels.dependencies]
|
||||||
|
"kernels-community/activation" = ">=0.0.2"
|
||||||
|
"kernels-community/triton-scaled-mm" = ">=0.0.2"
|
12
tests/layer_locking/kernels.lock
Normal file
12
tests/layer_locking/kernels.lock
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"repo_id": "kernels-test/versions",
|
||||||
|
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
|
||||||
|
"variants": {
|
||||||
|
"torch-universal": {
|
||||||
|
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
2
tests/layer_locking/pyproject.toml
Normal file
2
tests/layer_locking/pyproject.toml
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[tool.kernels.dependencies]
|
||||||
|
"kernels-test/versions" = ">=0.1.0,<0.2.0"
|
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from hf_kernels import get_kernel
|
|
||||||
|
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -8,6 +9,24 @@ def kernel():
|
|||||||
return get_kernel("kernels-community/activation")
|
return get_kernel("kernels-community/activation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def local_kernel():
|
||||||
|
package_name, path = install_kernel("kernels-community/activation", "main")
|
||||||
|
# Path is the build variant path (build/torch-<...>), so the grandparent
|
||||||
|
# is the kernel repository path.
|
||||||
|
return get_local_kernel(path.parent.parent, package_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metal_kernel():
|
||||||
|
return get_kernel("kernels-test/relu-metal")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def universal_kernel():
|
||||||
|
return get_kernel("kernels-community/triton-scaled-mm")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def device():
|
def device():
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
@ -15,6 +34,7 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
def test_gelu_fast(kernel, device):
|
def test_gelu_fast(kernel, device):
|
||||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
@ -28,3 +48,78 @@ def test_gelu_fast(kernel, device):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_local_kernel(local_kernel, device):
|
||||||
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
|
||||||
|
local_kernel.gelu_fast(y, x)
|
||||||
|
|
||||||
|
expected = torch.tensor(
|
||||||
|
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.darwin_only
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||||
|
def test_relu_metal(metal_kernel, dtype):
|
||||||
|
x = torch.arange(-10, 10, dtype=dtype, device="mps")
|
||||||
|
y = metal_kernel.relu(x)
|
||||||
|
assert torch.allclose(y, torch.relu(x))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"kernel_exists",
|
||||||
|
[
|
||||||
|
("kernels-community/activation", "main", True),
|
||||||
|
("kernels-community/triton-layer-norm", "main", True),
|
||||||
|
# Repo only contains Torch 2.4 kernels (and we don't
|
||||||
|
# support/test against this version).
|
||||||
|
("kernels-test/only-torch-2.4", "main", False),
|
||||||
|
("google-bert/bert-base-uncased", "87565a309", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_has_kernel(kernel_exists):
|
||||||
|
repo_id, revision, kernel = kernel_exists
|
||||||
|
assert has_kernel(repo_id, revision=revision) == kernel
|
||||||
|
|
||||||
|
|
||||||
|
def test_version():
|
||||||
|
kernel = get_kernel("kernels-test/versions")
|
||||||
|
assert kernel.version() == "0.2.0"
|
||||||
|
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
|
||||||
|
assert kernel.version() == "0.2.0"
|
||||||
|
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
|
||||||
|
assert kernel.version() == "0.1.1"
|
||||||
|
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
|
||||||
|
assert kernel.version() == "0.1.1"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||||
|
get_kernel("kernels-test/versions", version=">0.2.0")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||||
|
kernel = get_kernel(
|
||||||
|
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_universal_kernel(universal_kernel):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||||
|
B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
|
||||||
|
scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
|
||||||
|
scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")
|
||||||
|
|
||||||
|
out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
|
||||||
|
out_check = (A * scale_a) @ (B * scale_b)
|
||||||
|
out_check = out_check.to(torch.float16)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from hf_kernels import get_kernel
|
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -15,18 +16,21 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
def test_gelu_small(kernel, device, benchmark):
|
def test_gelu_small(kernel, device, benchmark):
|
||||||
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
def test_gelu_medium(kernel, device, benchmark):
|
def test_gelu_medium(kernel, device, benchmark):
|
||||||
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
def test_gelu_large(kernel, device, benchmark):
|
def test_gelu_large(kernel, device, benchmark):
|
||||||
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
|
230
tests/test_interval_tree.py
Normal file
230
tests/test_interval_tree.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
import random
|
||||||
|
from typing import Generic, List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernels._interval_tree import IntervalTree, _Node
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleIntervalStore(Generic[T]):
|
||||||
|
"""A simple O(n) implementation that stores intervals in a list."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.intervals: List[Tuple[int, int, T]] = []
|
||||||
|
|
||||||
|
def insert(self, start: int, end: int, data: T) -> None:
|
||||||
|
"""Insert an interval into the store."""
|
||||||
|
# Replace data if the interval already exists.
|
||||||
|
for i, (existing_start, existing_end, existing_data) in enumerate(
|
||||||
|
self.intervals
|
||||||
|
):
|
||||||
|
if existing_start == start and existing_end == end:
|
||||||
|
self.intervals[i] = (start, end, data)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.intervals.append((start, end, data))
|
||||||
|
|
||||||
|
def find_smallest_interval(self, point: int) -> Optional[T]:
|
||||||
|
"""Find the best match using linear search."""
|
||||||
|
matches = []
|
||||||
|
for start, end, data in self.intervals:
|
||||||
|
if start <= point <= end:
|
||||||
|
matches.append((start, end, data))
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return the smallest interval, sort by memory location when
|
||||||
|
# there are multiple matches with the same interval size. This
|
||||||
|
# mirrors the ordering in the intervan tree.
|
||||||
|
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
|
||||||
|
return best_match[2]
|
||||||
|
|
||||||
|
|
||||||
|
def is_balanced(tree: IntervalTree[T]) -> bool:
|
||||||
|
"""Check if the AVL tree is properly balanced."""
|
||||||
|
|
||||||
|
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
|
||||||
|
if node is None:
|
||||||
|
return True, 0
|
||||||
|
|
||||||
|
# Left and right subtrees should be balanced.
|
||||||
|
left_balanced, left_height = check_balance(node.left)
|
||||||
|
if not left_balanced:
|
||||||
|
return False, -1
|
||||||
|
|
||||||
|
right_balanced, right_height = check_balance(node.right)
|
||||||
|
if not right_balanced:
|
||||||
|
return False, -1
|
||||||
|
|
||||||
|
# The difference in height should not exceed 1.
|
||||||
|
if abs(left_height - right_height) > 1:
|
||||||
|
return False, -1
|
||||||
|
|
||||||
|
# Check if the height is correct.
|
||||||
|
expected_height = 1 + max(left_height, right_height)
|
||||||
|
if node.height != expected_height:
|
||||||
|
return False, -1
|
||||||
|
|
||||||
|
return True, expected_height
|
||||||
|
|
||||||
|
balanced, _ = check_balance(tree.root)
|
||||||
|
return balanced
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def populated_tree() -> IntervalTree[str]:
|
||||||
|
"""Provides a pre-populated IntervalTree for testing."""
|
||||||
|
tree = IntervalTree[str]()
|
||||||
|
kernels = [
|
||||||
|
(80, 89, "Kernel_A_General_80_89"),
|
||||||
|
(86, 89, "Kernel_B_Ampere_86_89"),
|
||||||
|
(80, 86, "Kernel_C_Older_Ampere_80_86"),
|
||||||
|
(70, 75, "Kernel_D_Volta_70_75"),
|
||||||
|
(86, 87, "Kernel_E_Specific_86_87"),
|
||||||
|
]
|
||||||
|
for start, end, name in kernels:
|
||||||
|
tree.insert(start, end, name)
|
||||||
|
return tree
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
|
||||||
|
# Check that the smallest inteval is selected when there are
|
||||||
|
# multiple matching intervals.
|
||||||
|
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_single_match(populated_tree):
|
||||||
|
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
|
||||||
|
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_match_outside_all_ranges(populated_tree):
|
||||||
|
# Check that no interval is found when the value is out of range
|
||||||
|
# (too small/too large).
|
||||||
|
assert populated_tree.find_smallest_interval(65) is None
|
||||||
|
assert populated_tree.find_smallest_interval(95) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_match_in_gap_between_ranges(populated_tree):
|
||||||
|
# Check that no interval is found when the value is between two
|
||||||
|
# intervals.
|
||||||
|
assert populated_tree.find_smallest_interval(78) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_boundary_conditions_start_and_end(populated_tree):
|
||||||
|
# Test exact upper/lower bounds of intervals.
|
||||||
|
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
|
||||||
|
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_tree():
|
||||||
|
# Searching in an empty tree should return None.
|
||||||
|
empty_tree = IntervalTree[str]()
|
||||||
|
assert empty_tree.find_smallest_interval(100) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_equally_specific_matches():
|
||||||
|
# Check that we pick the match in a stable way when there is are
|
||||||
|
# multiple matching intervals with the same size.
|
||||||
|
tree = IntervalTree[str]()
|
||||||
|
str1 = "First_Narrow_Kernel"
|
||||||
|
str2 = "Second_Narrow_Kernel"
|
||||||
|
tree.insert(10, 20, "Wide_Kernel")
|
||||||
|
tree.insert(12, 17, str1)
|
||||||
|
tree.insert(14, 19, str2)
|
||||||
|
|
||||||
|
if id(str1) < id(str2):
|
||||||
|
assert tree.find_smallest_interval(15) == str1
|
||||||
|
else:
|
||||||
|
assert tree.find_smallest_interval(15) == str2
|
||||||
|
|
||||||
|
|
||||||
|
def test_property_based_interval_tree():
|
||||||
|
# Quick-check property-based testing:
|
||||||
|
#
|
||||||
|
# - Verify that the tree is balanced after each insertion.
|
||||||
|
# - Verify the query against a simple list-based implementation.
|
||||||
|
|
||||||
|
random.seed(42) # For reproducible tests
|
||||||
|
|
||||||
|
test_points = list(range(0, 101))
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
tree = IntervalTree[str]()
|
||||||
|
simple = SimpleIntervalStore[str]()
|
||||||
|
|
||||||
|
intervals = []
|
||||||
|
for i in range(100):
|
||||||
|
start = random.randint(0, 90)
|
||||||
|
end = random.randint(start, 100)
|
||||||
|
data = f"interval_{i}_s{start}_e{end}"
|
||||||
|
intervals.append((start, end, data))
|
||||||
|
|
||||||
|
for i, (start, end, data) in enumerate(intervals):
|
||||||
|
tree.insert(start, end, data)
|
||||||
|
simple.insert(start, end, data)
|
||||||
|
|
||||||
|
# Check that tree is still balanced
|
||||||
|
assert is_balanced(
|
||||||
|
tree
|
||||||
|
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
|
||||||
|
|
||||||
|
for point in test_points:
|
||||||
|
tree_result = tree.find_smallest_interval(point)
|
||||||
|
simple_result = simple.find_smallest_interval(point)
|
||||||
|
|
||||||
|
assert tree_result == simple_result, (
|
||||||
|
f"Mismatch for point {point} after inserting {i+1} intervals. "
|
||||||
|
f"Tree: {tree_result}, Simple: {simple_result}. "
|
||||||
|
f"Last inserted: ({start}, {end})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_property_based_edge_cases():
|
||||||
|
random.seed(123)
|
||||||
|
|
||||||
|
tree = IntervalTree[str]()
|
||||||
|
simple = SimpleIntervalStore[str]()
|
||||||
|
|
||||||
|
# Single-point intervals.
|
||||||
|
for i in range(10):
|
||||||
|
point = random.randint(0, 100)
|
||||||
|
data = f"single_point_{i}_{point}"
|
||||||
|
tree.insert(point, point, data)
|
||||||
|
simple.insert(point, point, data)
|
||||||
|
|
||||||
|
assert is_balanced(
|
||||||
|
tree
|
||||||
|
), f"Tree unbalanced after inserting single point {point}"
|
||||||
|
|
||||||
|
# Test the exact point and neighbors
|
||||||
|
for test_point in [point - 1, point, point + 1]:
|
||||||
|
if 0 <= test_point <= 100:
|
||||||
|
tree_result = tree.find_smallest_interval(test_point)
|
||||||
|
simple_result = simple.find_smallest_interval(test_point)
|
||||||
|
assert tree_result == simple_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_unique_intervals_override():
|
||||||
|
"""Test that inserting an interval with the same start/end overrides the previous value."""
|
||||||
|
tree = IntervalTree[str]()
|
||||||
|
|
||||||
|
tree.insert(10, 20, "original_value")
|
||||||
|
assert tree.find_smallest_interval(15) == "original_value"
|
||||||
|
|
||||||
|
tree.insert(10, 20, "new_value")
|
||||||
|
assert tree.find_smallest_interval(15) == "new_value"
|
||||||
|
|
||||||
|
tree.insert(10, 25, "different_interval")
|
||||||
|
results = tree.search(15)
|
||||||
|
assert "new_value" in results
|
||||||
|
assert "different_interval" in results
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
tree.insert(10, 20, "final_value")
|
||||||
|
assert tree.find_smallest_interval(15) == "final_value"
|
||||||
|
|
||||||
|
assert is_balanced(tree)
|
60
tests/test_kernel_locking.py
Normal file
60
tests/test_kernel_locking.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from kernels import load_kernel
|
||||||
|
from kernels.cli import download_kernels
|
||||||
|
from kernels.layer import (
|
||||||
|
LockedLayerRepository,
|
||||||
|
Mode,
|
||||||
|
kernelize,
|
||||||
|
use_kernel_forward_from_hub,
|
||||||
|
use_kernel_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Mock download arguments class.
|
||||||
|
@dataclass
|
||||||
|
class DownloadArgs:
|
||||||
|
all_variants: bool
|
||||||
|
project_dir: Path
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_all_hash_validation():
|
||||||
|
project_dir = Path(__file__).parent / "kernel_locking"
|
||||||
|
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_load_locked():
|
||||||
|
project_dir = Path(__file__).parent / "kernel_locking"
|
||||||
|
# Also validates that hashing works correctly.
|
||||||
|
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
|
||||||
|
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_locked():
|
||||||
|
project_dir = Path(__file__).parent / "layer_locking"
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("Version")
|
||||||
|
class Version(nn.Module):
|
||||||
|
def forward(self) -> str:
|
||||||
|
return "0.0.0"
|
||||||
|
|
||||||
|
version = Version()
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
"cuda": LockedLayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
lockfile=project_dir / "kernels.lock",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
):
|
||||||
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
assert version() == "0.1.1"
|
953
tests/test_layer.py
Normal file
953
tests/test_layer.py
Normal file
@ -0,0 +1,953 @@
|
|||||||
|
import sys
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from kernels import (
|
||||||
|
Device,
|
||||||
|
LayerRepository,
|
||||||
|
Mode,
|
||||||
|
kernelize,
|
||||||
|
register_kernel_mapping,
|
||||||
|
use_kernel_forward_from_hub,
|
||||||
|
)
|
||||||
|
from kernels.layer import (
|
||||||
|
_KERNEL_MAPPING,
|
||||||
|
CUDAProperties,
|
||||||
|
_validate_layer,
|
||||||
|
use_kernel_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"SiluAndMulNoCompile": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-test/op-without-fake-test",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"SiluAndMulStringDevice": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
|
||||||
|
|
||||||
|
class SiluAndMul(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Used to check that we called hub kernel.
|
||||||
|
self.n_calls = 0
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
self.n_calls += 1
|
||||||
|
d = input.shape[-1] // 2
|
||||||
|
return F.silu(input[..., :d]) * input[..., d:]
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
|
||||||
|
class SiluAndMulNoCompileKernel(SiluAndMul):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("SiluAndMul")
|
||||||
|
class SiluAndMulWithKernel(SiluAndMul):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("SiluAndMulStringDevice")
|
||||||
|
class SiluAndMulStringDevice(SiluAndMul):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("Linear")
|
||||||
|
class TorchLinearWithCounter(nn.Linear):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# Used to check that we called hub kernel.
|
||||||
|
self.n_calls = 0
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
self.n_calls += 1
|
||||||
|
return super().forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
def test_arg_kinds():
|
||||||
|
@use_kernel_forward_from_hub("ArgKind")
|
||||||
|
class ArgKind(nn.Module):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
arg1,
|
||||||
|
arg2,
|
||||||
|
*,
|
||||||
|
kwarg1,
|
||||||
|
kwarg2=42,
|
||||||
|
):
|
||||||
|
return (arg1, arg2, kwarg1, kwarg2)
|
||||||
|
|
||||||
|
arg_kind = ArgKind()
|
||||||
|
assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
|
||||||
|
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||||
|
def test_hub_forward(cls, device):
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
|
||||||
|
silu_and_mul = SiluAndMul()
|
||||||
|
X = torch.randn((32, 64), device=device)
|
||||||
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
|
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
|
||||||
|
Y_kernel = silu_and_mul_with_kernel(X)
|
||||||
|
|
||||||
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
|
|
||||||
|
assert silu_and_mul.n_calls == 1
|
||||||
|
if device == "cuda":
|
||||||
|
assert silu_and_mul_with_kernel.n_calls == 0
|
||||||
|
else:
|
||||||
|
assert silu_and_mul_with_kernel.n_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_capability():
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
Device(
|
||||||
|
type="cuda",
|
||||||
|
properties=CUDAProperties(
|
||||||
|
min_capability=75, max_capability=sys.maxsize
|
||||||
|
),
|
||||||
|
): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
|
||||||
|
# Check that we called out to the kernel.
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
Device(
|
||||||
|
type="cuda",
|
||||||
|
properties=CUDAProperties(
|
||||||
|
min_capability=sys.maxsize, max_capability=sys.maxsize
|
||||||
|
),
|
||||||
|
): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
|
||||||
|
# Check that we didn't call out to the kernel because there is
|
||||||
|
# is no kernel with a matching capability..
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_fallback_works():
|
||||||
|
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
|
||||||
|
class SiluAndMulWithKernelFallback(SiluAndMul):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check that we don't raise an exception for a non-existing kernel.
|
||||||
|
silu_and_mul = SiluAndMulWithKernelFallback()
|
||||||
|
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||||
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
|
def test_torch_compile_layer_without_fallback(cls, device):
|
||||||
|
silu_and_mul = SiluAndMul()
|
||||||
|
|
||||||
|
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
||||||
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
|
silu_and_mul_with_kernel = cls()
|
||||||
|
silu_and_mul_with_kernel.eval()
|
||||||
|
|
||||||
|
ctx = (
|
||||||
|
pytest.raises(ValueError, match="does not support mode")
|
||||||
|
if cls is SiluAndMulNoCompileKernel
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
with ctx:
|
||||||
|
silu_and_mul_with_kernel = kernelize(
|
||||||
|
silu_and_mul_with_kernel,
|
||||||
|
device=device,
|
||||||
|
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||||
|
use_fallback=False,
|
||||||
|
)
|
||||||
|
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
||||||
|
|
||||||
|
Y_compiled = silu_and_mul_compiled(X)
|
||||||
|
|
||||||
|
torch.testing.assert_close(Y_compiled, Y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||||
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
|
def test_torch_compile_layer_with_fallback(cls, device):
|
||||||
|
silu_and_mul = SiluAndMul()
|
||||||
|
|
||||||
|
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
||||||
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
|
silu_and_mul_with_kernel = cls()
|
||||||
|
silu_and_mul_with_kernel.eval()
|
||||||
|
silu_and_mul_with_kernel = kernelize(
|
||||||
|
silu_and_mul_with_kernel,
|
||||||
|
device=device,
|
||||||
|
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||||
|
)
|
||||||
|
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
||||||
|
|
||||||
|
Y_compiled = silu_and_mul_compiled(X)
|
||||||
|
|
||||||
|
torch.testing.assert_close(Y_compiled, Y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_mapping_contexts():
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
}
|
||||||
|
|
||||||
|
extra_mapping1 = {
|
||||||
|
"TestKernel": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
revision="layers",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with use_kernel_mapping(extra_mapping1):
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
"TestKernel",
|
||||||
|
}
|
||||||
|
|
||||||
|
extra_mapping2 = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-community/non-existing",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
revision="layers",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with use_kernel_mapping(extra_mapping2):
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
"TestKernel",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
||||||
|
== "kernels-community/non-existing"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
"TestKernel",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
||||||
|
== "kernels-community/activation"
|
||||||
|
)
|
||||||
|
|
||||||
|
with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
||||||
|
== "kernels-community/non-existing"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
"TestKernel",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
||||||
|
== "kernels-community/activation"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
|
"SiluAndMul",
|
||||||
|
"SiluAndMulStringDevice",
|
||||||
|
"SiluAndMulNoCompile",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_kernel_layer():
|
||||||
|
class BadLayer(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.foo = 42
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="not override"):
|
||||||
|
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
|
||||||
|
|
||||||
|
class BadLayer2(nn.Module):
|
||||||
|
foo: int = 42
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="not contain additional members"):
|
||||||
|
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
|
||||||
|
|
||||||
|
class BadLayer3(nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="different number of arguments"):
|
||||||
|
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
|
||||||
|
|
||||||
|
class BadLayer4(nn.Module):
|
||||||
|
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||||
|
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_invalid_mode_for_mapping_rejected():
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearNoBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="does not support backward"):
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_kernel_modes():
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
# Case 1: layer without further specification, becomes the
|
||||||
|
# base layer.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
||||||
|
kernelize(linear)
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
# Case 2: register a kernel just for training. If no base kernel
|
||||||
|
# layer is registered, we fall back to the original layer.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# Training has a kernel, so fallback.
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
||||||
|
kernelize(linear)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
# Case 3: register a kernel just for training and one for fallback.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.FALLBACK: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# Falls back to TRAINING.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# Falls back to the TRAINING kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
||||||
|
kernelize(linear)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
# Case 4: register a kernel with two preferences.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.TRAINING
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# Falls back to the TRAINING | TORCH_COMPILE kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# Uses TRAINING | TORCH_COMPILE kernel.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear)
|
||||||
|
linear(X)
|
||||||
|
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_fallback_used_when_training():
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
# Case 1: kernel with explicit backward support should always
|
||||||
|
# use the kernel.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
linear.train()
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
linear.eval()
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
# Case 2: kernel with implicit backward support should always
|
||||||
|
# use the kernel.
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearImplicitBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
linear.train()
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
linear.eval()
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_mode_rejected():
|
||||||
|
with pytest.raises(ValueError, match="mutually exclusive"):
|
||||||
|
_ = Mode.INFERENCE | Mode.TRAINING
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="cannot be combined with other modes"):
|
||||||
|
_ = Mode.FALLBACK | Mode.TORCH_COMPILE
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="can only be used to register kernel mappings"
|
||||||
|
):
|
||||||
|
kernelize(torch.nn.Linear(32, 32), mode=Mode.FALLBACK)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="mode must contain"):
|
||||||
|
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_kernel_modes_inference():
|
||||||
|
"""Test inference-specific fallback scenarios."""
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
# Case 1: register a kernel just for inference
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# No training kernel, so fallback to original
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
# Case 2: register a kernel just for inference + torch.compile
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE falls back to INFERENCE | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# No training kernel, so fallback to original
|
||||||
|
assert linear.n_calls == 3
|
||||||
|
|
||||||
|
# Case 3: register both inference kernels
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.INFERENCE
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# Uses exact INFERENCE kernel
|
||||||
|
assert linear.n_calls == 3
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# Uses exact INFERENCE | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 3
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# No training kernel, so fallback to original
|
||||||
|
assert linear.n_calls == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_kernel_modes_mixed():
|
||||||
|
"""Test mixed training and inference kernel scenarios."""
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
# Case 1: register both base inference and training kernels
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
# Case 2: register all four kernel modes
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.INFERENCE
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.TRAINING
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# Uses exact INFERENCE kernel
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# Uses exact TRAINING kernel
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# Uses exact INFERENCE | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# Uses exact TRAINING | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.linux_only
|
||||||
|
def test_kernel_modes_cross_fallback():
|
||||||
|
"""Test cross-mode fallback scenarios from inference to training modes."""
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
|
# Case 1: Only training kernel registered - inference should fall back to training
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.TRAINING: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE falls back to TRAINING kernel
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING uses the kernel directly
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
# Case 2: Only training + torch.compile kernel registered
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.TRAINING
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE falls back to TRAINING | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# INFERENCE | TORCH_COMPILE falls back to TRAINING | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING falls back to TRAINING | TORCH_COMPILE kernel
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE uses the kernel directly
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
# Case 3: Test that training modes don't fall back to inference modes
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
"cuda": {
|
||||||
|
Mode.INFERENCE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
Mode.INFERENCE
|
||||||
|
| Mode.TORCH_COMPILE: LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
# TRAINING should NOT fall back to inference kernels, use original
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
|
linear(X)
|
||||||
|
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
|
||||||
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_versions():
|
||||||
|
@use_kernel_forward_from_hub("Version")
|
||||||
|
class Version(nn.Module):
|
||||||
|
def forward(self) -> str:
|
||||||
|
return "0.0.0"
|
||||||
|
|
||||||
|
version = Version()
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
version="<1.0.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
version="<0.2.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
version=">0.1.0,<0.2.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
version=">0.2.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||||
|
kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||||
|
use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Version": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/versions",
|
||||||
|
layer_name="Version",
|
||||||
|
revision="v0.1.0",
|
||||||
|
version="<1.0.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
Reference in New Issue
Block a user