Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-25 08:10:39 +08:00)

Compare commits: hf-kernels ... v0.4.3 (34 commits)
| SHA1 |
|---|
| 6fd2112e22 |
| 70f56ff856 |
| 7178b0b86c |
| 0bbf90a564 |
| 27d6ffcb80 |
| f7bd21438b |
| 6174febb4b |
| ff55bc201b |
| 3808108d62 |
| c4a16ef462 |
| 9762794dd2 |
| b7d6867c52 |
| fbcd0f2ebd |
| 5af46eca94 |
| 747dd66876 |
| 920590a592 |
| 5208ac4be5 |
| 22eaba2826 |
| 9521ba79a0 |
| 9861a5bdef |
| 1c7c87c960 |
| df45cf2795 |
| cf0413efe5 |
| 851c13f666 |
| b6a393612f |
| 18ecd0ce69 |
| b4ef1d60e5 |
| a40756f306 |
| 3671158f47 |
| 2ddd473cf7 |
| 497dffb89e |
| f036fd09cb |
| 3e4c83c798 |
| 4116d6019e |
.github/workflows/lint.yml (new file, 10 lines, vendored)
							| @ -0,0 +1,10 @@ | ||||
| name: Lints | ||||
| on: [push, pull_request] | ||||
| jobs: | ||||
|   lint: | ||||
|     name: Run lints | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - name: Run ruff | ||||
|         uses: astral-sh/ruff-action@v3 | ||||
.github/workflows/test.yml (16 lines changed, vendored)
							| @ -1,4 +1,4 @@ | ||||
| name: Test hf-kernels | ||||
| name: Test kernels | ||||
|  | ||||
| on: | ||||
|   push: | ||||
| @ -26,6 +26,9 @@ jobs: | ||||
|         python-version: ["3.10", "3.12"] | ||||
|         torch-version: ["2.5.1", "2.6.0"] | ||||
|  | ||||
|     env: | ||||
|       UV_PYTHON_PREFERENCE: only-managed | ||||
|  | ||||
|     steps: | ||||
|       - name: Checkout code | ||||
|         uses: actions/checkout@v4 | ||||
| @ -41,5 +44,16 @@ jobs: | ||||
|       - name: Install the project | ||||
|         run: uv sync --all-extras --dev | ||||
|  | ||||
|       - name: Install setuptools for Triton-based test | ||||
|         run: uv pip install setuptools | ||||
|  | ||||
|       - name: Check typing | ||||
|         run: uv run mypy src/kernels | ||||
|  | ||||
|       - name: Run tests | ||||
|         run: uv run pytest tests | ||||
|  | ||||
|       - name: Import check without torch | ||||
|         run: | | ||||
|           uv pip uninstall torch | ||||
|           python -c "import kernels" | ||||
|  | ||||
LICENSE (new file, 201 lines)
							| @ -0,0 +1,201 @@ | ||||
|                                  Apache License | ||||
|                            Version 2.0, January 2004 | ||||
|                         http://www.apache.org/licenses/ | ||||
|  | ||||
|    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||
|  | ||||
|    1. Definitions. | ||||
|  | ||||
|       "License" shall mean the terms and conditions for use, reproduction, | ||||
|       and distribution as defined by Sections 1 through 9 of this document. | ||||
|  | ||||
|       "Licensor" shall mean the copyright owner or entity authorized by | ||||
|       the copyright owner that is granting the License. | ||||
|  | ||||
|       "Legal Entity" shall mean the union of the acting entity and all | ||||
|       other entities that control, are controlled by, or are under common | ||||
|       control with that entity. For the purposes of this definition, | ||||
|       "control" means (i) the power, direct or indirect, to cause the | ||||
|       direction or management of such entity, whether by contract or | ||||
|       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||
|       outstanding shares, or (iii) beneficial ownership of such entity. | ||||
|  | ||||
|       "You" (or "Your") shall mean an individual or Legal Entity | ||||
|       exercising permissions granted by this License. | ||||
|  | ||||
|       "Source" form shall mean the preferred form for making modifications, | ||||
|       including but not limited to software source code, documentation | ||||
|       source, and configuration files. | ||||
|  | ||||
|       "Object" form shall mean any form resulting from mechanical | ||||
|       transformation or translation of a Source form, including but | ||||
|       not limited to compiled object code, generated documentation, | ||||
|       and conversions to other media types. | ||||
|  | ||||
|       "Work" shall mean the work of authorship, whether in Source or | ||||
|       Object form, made available under the License, as indicated by a | ||||
|       copyright notice that is included in or attached to the work | ||||
|       (an example is provided in the Appendix below). | ||||
|  | ||||
|       "Derivative Works" shall mean any work, whether in Source or Object | ||||
|       form, that is based on (or derived from) the Work and for which the | ||||
|       editorial revisions, annotations, elaborations, or other modifications | ||||
|       represent, as a whole, an original work of authorship. For the purposes | ||||
|       of this License, Derivative Works shall not include works that remain | ||||
|       separable from, or merely link (or bind by name) to the interfaces of, | ||||
|       the Work and Derivative Works thereof. | ||||
|  | ||||
|       "Contribution" shall mean any work of authorship, including | ||||
|       the original version of the Work and any modifications or additions | ||||
|       to that Work or Derivative Works thereof, that is intentionally | ||||
|       submitted to Licensor for inclusion in the Work by the copyright owner | ||||
|       or by an individual or Legal Entity authorized to submit on behalf of | ||||
|       the copyright owner. For the purposes of this definition, "submitted" | ||||
|       means any form of electronic, verbal, or written communication sent | ||||
|       to the Licensor or its representatives, including but not limited to | ||||
|       communication on electronic mailing lists, source code control systems, | ||||
|       and issue tracking systems that are managed by, or on behalf of, the | ||||
|       Licensor for the purpose of discussing and improving the Work, but | ||||
|       excluding communication that is conspicuously marked or otherwise | ||||
|       designated in writing by the copyright owner as "Not a Contribution." | ||||
|  | ||||
|       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||
|       on behalf of whom a Contribution has been received by Licensor and | ||||
|       subsequently incorporated within the Work. | ||||
|  | ||||
|    2. Grant of Copyright License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       copyright license to reproduce, prepare Derivative Works of, | ||||
|       publicly display, publicly perform, sublicense, and distribute the | ||||
|       Work and such Derivative Works in Source or Object form. | ||||
|  | ||||
|    3. Grant of Patent License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       (except as stated in this section) patent license to make, have made, | ||||
|       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||
|       where such license applies only to those patent claims licensable | ||||
|       by such Contributor that are necessarily infringed by their | ||||
|       Contribution(s) alone or by combination of their Contribution(s) | ||||
|       with the Work to which such Contribution(s) was submitted. If You | ||||
|       institute patent litigation against any entity (including a | ||||
|       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||
|       or a Contribution incorporated within the Work constitutes direct | ||||
|       or contributory patent infringement, then any patent licenses | ||||
|       granted to You under this License for that Work shall terminate | ||||
|       as of the date such litigation is filed. | ||||
|  | ||||
|    4. Redistribution. You may reproduce and distribute copies of the | ||||
|       Work or Derivative Works thereof in any medium, with or without | ||||
|       modifications, and in Source or Object form, provided that You | ||||
|       meet the following conditions: | ||||
|  | ||||
|       (a) You must give any other recipients of the Work or | ||||
|           Derivative Works a copy of this License; and | ||||
|  | ||||
|       (b) You must cause any modified files to carry prominent notices | ||||
|           stating that You changed the files; and | ||||
|  | ||||
|       (c) You must retain, in the Source form of any Derivative Works | ||||
|           that You distribute, all copyright, patent, trademark, and | ||||
|           attribution notices from the Source form of the Work, | ||||
|           excluding those notices that do not pertain to any part of | ||||
|           the Derivative Works; and | ||||
|  | ||||
|       (d) If the Work includes a "NOTICE" text file as part of its | ||||
|           distribution, then any Derivative Works that You distribute must | ||||
|           include a readable copy of the attribution notices contained | ||||
|           within such NOTICE file, excluding those notices that do not | ||||
|           pertain to any part of the Derivative Works, in at least one | ||||
|           of the following places: within a NOTICE text file distributed | ||||
|           as part of the Derivative Works; within the Source form or | ||||
|           documentation, if provided along with the Derivative Works; or, | ||||
|           within a display generated by the Derivative Works, if and | ||||
|           wherever such third-party notices normally appear. The contents | ||||
|           of the NOTICE file are for informational purposes only and | ||||
|           do not modify the License. You may add Your own attribution | ||||
|           notices within Derivative Works that You distribute, alongside | ||||
|           or as an addendum to the NOTICE text from the Work, provided | ||||
|           that such additional attribution notices cannot be construed | ||||
|           as modifying the License. | ||||
|  | ||||
|       You may add Your own copyright statement to Your modifications and | ||||
|       may provide additional or different license terms and conditions | ||||
|       for use, reproduction, or distribution of Your modifications, or | ||||
|       for any such Derivative Works as a whole, provided Your use, | ||||
|       reproduction, and distribution of the Work otherwise complies with | ||||
|       the conditions stated in this License. | ||||
|  | ||||
|    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||
|       any Contribution intentionally submitted for inclusion in the Work | ||||
|       by You to the Licensor shall be under the terms and conditions of | ||||
|       this License, without any additional terms or conditions. | ||||
|       Notwithstanding the above, nothing herein shall supersede or modify | ||||
|       the terms of any separate license agreement you may have executed | ||||
|       with Licensor regarding such Contributions. | ||||
|  | ||||
|    6. Trademarks. This License does not grant permission to use the trade | ||||
|       names, trademarks, service marks, or product names of the Licensor, | ||||
|       except as required for reasonable and customary use in describing the | ||||
|       origin of the Work and reproducing the content of the NOTICE file. | ||||
|  | ||||
|    7. Disclaimer of Warranty. Unless required by applicable law or | ||||
|       agreed to in writing, Licensor provides the Work (and each | ||||
|       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||
|       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
|       implied, including, without limitation, any warranties or conditions | ||||
|       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||
|       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||
|       appropriateness of using or redistributing the Work and assume any | ||||
|       risks associated with Your exercise of permissions under this License. | ||||
|  | ||||
|    8. Limitation of Liability. In no event and under no legal theory, | ||||
|       whether in tort (including negligence), contract, or otherwise, | ||||
|       unless required by applicable law (such as deliberate and grossly | ||||
|       negligent acts) or agreed to in writing, shall any Contributor be | ||||
|       liable to You for damages, including any direct, indirect, special, | ||||
|       incidental, or consequential damages of any character arising as a | ||||
|       result of this License or out of the use or inability to use the | ||||
|       Work (including but not limited to damages for loss of goodwill, | ||||
|       work stoppage, computer failure or malfunction, or any and all | ||||
|       other commercial damages or losses), even if such Contributor | ||||
|       has been advised of the possibility of such damages. | ||||
|  | ||||
|    9. Accepting Warranty or Additional Liability. While redistributing | ||||
|       the Work or Derivative Works thereof, You may choose to offer, | ||||
|       and charge a fee for, acceptance of support, warranty, indemnity, | ||||
|       or other liability obligations and/or rights consistent with this | ||||
|       License. However, in accepting such obligations, You may act only | ||||
|       on Your own behalf and on Your sole responsibility, not on behalf | ||||
|       of any other Contributor, and only if You agree to indemnify, | ||||
|       defend, and hold each Contributor harmless for any liability | ||||
|       incurred by, or claims asserted against, such Contributor by reason | ||||
|       of your accepting any such warranty or additional liability. | ||||
|  | ||||
|    END OF TERMS AND CONDITIONS | ||||
|  | ||||
|    APPENDIX: How to apply the Apache License to your work. | ||||
|  | ||||
|       To apply the Apache License to your work, attach the following | ||||
|       boilerplate notice, with the fields enclosed by brackets "[]" | ||||
|       replaced with your own identifying information. (Don't include | ||||
|       the brackets!)  The text should be enclosed in the appropriate | ||||
|       comment syntax for the file format. We also recommend that a | ||||
|       file or class name and description of purpose be included on the | ||||
|       same "printed page" as the copyright notice for easier | ||||
|       identification within third-party archives. | ||||
|  | ||||
|    Copyright [yyyy] [name of copyright owner] | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
README.md (91 lines changed)
							| @ -1,4 +1,4 @@ | ||||
| # hf-kernels | ||||
| # kernels | ||||
|  | ||||
| The Kernel Hub allows Python libraries and applications to load compute | ||||
| kernels directly from the [Hub](https://hf.co/). To support this kind | ||||
| @ -12,16 +12,20 @@ packages in that they are made to be: | ||||
|   the different PyTorch build configurations (various CUDA versions | ||||
|   and C++ ABIs). Furthermore, older C library versions must be supported. | ||||
|  | ||||
| ## Usage | ||||
| ## 🚀 Quick Start | ||||
|  | ||||
| Kernels depends on `torch>=2.5` and CUDA for now.  | ||||
| Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA): | ||||
|  | ||||
| ```bash | ||||
| pip install kernels | ||||
| ``` | ||||
|  | ||||
| Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub: | ||||
|  | ||||
| ```python | ||||
| import torch | ||||
|  | ||||
| from hf_kernels import get_kernel | ||||
| from kernels import get_kernel | ||||
|  | ||||
| # Download optimized kernels from the Hugging Face hub | ||||
| activation = get_kernel("kernels-community/activation") | ||||
| @ -36,75 +40,14 @@ activation.gelu_fast(y, x) | ||||
| print(y) | ||||
| ``` | ||||
|  | ||||
| These kernels can be built from the [kernel-builder library](https://github.com/huggingface/kernel-builder).  | ||||
| You can [search for kernels](https://huggingface.co/models?other=kernel) on | ||||
| the Hub. | ||||
|  | ||||
| If you're looking to better understand how these kernels are structured, or looking to build your own kernels,  | ||||
| please take a look at the following guide:  | ||||
| [writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md). | ||||
| ## 📚 Documentation | ||||
|  | ||||
| ## Installation | ||||
|  | ||||
| To install `hf-kernels`, we recommend installing from the pypi package: | ||||
|  | ||||
| ```bash | ||||
| pip install hf-kernels | ||||
| ``` | ||||
|  | ||||
| You should then be able to run the script above (also in [examples/basic.py](examples/basic.py)): | ||||
| ```bash | ||||
| python examples/basic.py | ||||
| ``` | ||||
|  | ||||
| ## Docker Reference | ||||
|  | ||||
| Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: | ||||
|  | ||||
| ```bash | ||||
| docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||
| docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||
| ``` | ||||
|  | ||||
| ## Locking kernel versions | ||||
|  | ||||
| Projects that use `setuptools` can lock the kernel versions that should be | ||||
| used. First specify the accepted versions in `pyproject.toml` and make | ||||
| sure that `hf-kernels` is a build dependency: | ||||
|  | ||||
| ```toml | ||||
| [build-system] | ||||
| requires = ["hf-kernels", "setuptools"] | ||||
| build-backend = "setuptools.build_meta" | ||||
|  | ||||
| [tool.kernels.dependencies] | ||||
| "kernels-community/activation" = ">=0.0.1" | ||||
| ``` | ||||
|  | ||||
| Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with | ||||
| the locked revisions. The locked revision will be used when loading a kernel with | ||||
| `get_locked_kernel`: | ||||
|  | ||||
| ```python | ||||
| from hf_kernels import get_locked_kernel | ||||
|  | ||||
| activation = get_locked_kernel("kernels-community/activation") | ||||
| ``` | ||||
|  | ||||
| **Note:** the lock file is included in the package metadata, so it will only be visible | ||||
| to `hf-kernels` after doing an (editable or regular) installation of your project. | ||||
|  | ||||
| ## Pre-downloading locked kernels | ||||
|  | ||||
| Locked kernels can be pre-downloaded by running `hf-kernel download .` in your | ||||
| project directory. This will download the kernels to your local Hugging Face | ||||
| Hub cache. | ||||
|  | ||||
| The pre-downloaded kernels are used by the `get_locked_kernel` function. | ||||
| `get_locked_kernel` will download a kernel when it is not pre-downloaded. If you | ||||
| want kernel loading to error when a kernel is not pre-downloaded, you can use | ||||
| the `load_kernel` function instead: | ||||
|  | ||||
| ```python | ||||
| from hf_kernels import load_kernel | ||||
|  | ||||
| activation = load_kernel("kernels-community/activation") | ||||
| ``` | ||||
| - [Using layers](docs/layers.md) | ||||
| - [Locking kernel versions](docs/locking.md) | ||||
| - [Environment variables](docs/env.md) | ||||
| - [Using kernels in a Docker container](docs/docker.md) | ||||
| - [Kernel requirements](docs/kernel-requirements.md) | ||||
| - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||
|  | ||||
| @ -31,13 +31,13 @@ WORKDIR /app/kernel-test | ||||
| # install Python dependencies | ||||
| RUN uv add torch==2.5.0 numpy | ||||
|  | ||||
| # copy hf-kernels lib | ||||
| COPY src ./hf-kernels/src | ||||
| COPY pyproject.toml ./hf-kernels/pyproject.toml | ||||
| COPY README.md ./hf-kernels/README.md | ||||
| # copy kernels lib | ||||
| COPY src ./kernels/src | ||||
| COPY pyproject.toml ./kernels/pyproject.toml | ||||
| COPY README.md ./kernels/README.md | ||||
|  | ||||
| # install library | ||||
| RUN uv pip install -e hf-kernels | ||||
| RUN uv pip install -e kernels | ||||
|  | ||||
| # copy examples | ||||
| COPY examples ./examples | ||||
| @ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility | ||||
|  | ||||
| # command to run the script | ||||
| CMD ["uv", "run", "examples/basic.py"] | ||||
| # CMD ["ls", "hf-kernels"] | ||||
| # CMD ["ls", "kernels"] | ||||
|  | ||||
docs/docker.md (new file, 8 lines)
							| @ -0,0 +1,8 @@ | ||||
| # Using kernels in a Docker container | ||||
|  | ||||
| Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: | ||||
|  | ||||
| ```bash | ||||
| docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||
| docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||
| ``` | ||||
docs/env.md (new file, 10 lines)
							| @ -0,0 +1,10 @@ | ||||
| # Environment variables | ||||
|  | ||||
| ## `KERNELS_CACHE` | ||||
|  | ||||
| The directory to use as the local kernel cache. If not set, the cache | ||||
| of the `huggingface_hub` package is used. | ||||
|  | ||||
| ## `DISABLE_KERNEL_MAPPING` | ||||
|  | ||||
| Disables kernel mappings for [`layers`](layers.md). | ||||
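|  | ||||
| A minimal sketch of setting these variables (the cache path is illustrative): | ||||
|  | ||||
| ```bash | ||||
| # Use a project-local kernel cache instead of the huggingface_hub cache. | ||||
| export KERNELS_CACHE=./kernel-cache | ||||
| # Disable kernel mappings for layers. | ||||
| export DISABLE_KERNEL_MAPPING=1 | ||||
| ``` | ||||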
| @ -26,13 +26,24 @@ recommended build variants are: | ||||
| - `torch26-cxx98-cu124-x86_64-linux` | ||||
| - `torch26-cxx98-cu126-x86_64-linux` | ||||
|  | ||||
| This list will be updated as new PyTorch versions are released. Each | ||||
| variant directory should contain a single directory with the same name | ||||
| This list will be updated as new PyTorch versions are released. Kernels | ||||
| that are in pure Python (e.g. Triton kernels) only need to provide a | ||||
| single build variant: | ||||
|  | ||||
| - `torch-universal` | ||||
|  | ||||
| Each variant directory should contain a single directory with the same name | ||||
| as the repository (replacing `-` by `_`). For instance, kernels in the | ||||
| `kernels-community/activation` repository have directories like | ||||
| `build/<variant>/activation`. This directory | ||||
| must be a Python package with an `__init__.py` file. | ||||
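|  | ||||
| For illustration, a kernel in the `kernels-community/activation` repository could be laid out as follows (variant names taken from the list above): | ||||
|  | ||||
| ```text | ||||
| build/torch26-cxx98-cu126-x86_64-linux/activation/__init__.py | ||||
| build/torch-universal/activation/__init__.py | ||||
| ``` | ||||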
|  | ||||
| ## Versioning | ||||
|  | ||||
| Kernels are versioned on the Hub using Git tags. Version tags must be of | ||||
| the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md) | ||||
| to resolve the version constraints. | ||||
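|  | ||||
| For instance, publishing a new version might look like this (the tag name is illustrative): | ||||
|  | ||||
| ```bash | ||||
| git tag v0.1.0 | ||||
| git push origin v0.1.0 | ||||
| ``` | ||||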
|  | ||||
| ## Native Python module | ||||
|  | ||||
| Kernels will typically contain a native Python module with precompiled | ||||
| @ -41,16 +52,31 @@ requirements: | ||||
|  | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| - Compatible with glibc 2.27 or later. This means that no symbols | ||||
|   from later versions must be used. To achieve this, the module should | ||||
|   be built against this glibc version. **Warning:** libgcc must also be | ||||
|   built against glibc 2.27 to avoid leaking symbols. | ||||
| - No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols | ||||
|   must be static. | ||||
| - No dynamic library dependencies outside Torch or CUDA libraries | ||||
|   installed as dependencies of Torch. | ||||
| - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based). | ||||
|   This means that the extension **must not** use symbols versions higher than: | ||||
|  | ||||
| (These requirements will be updated as new PyTorch versions are released.) | ||||
|   - GLIBC 2.28 | ||||
|   - GLIBCXX 3.4.24 | ||||
|   - CXXABI 1.3.11 | ||||
|   - GCC 7.0.0 | ||||
|  | ||||
|   These requirements can be checked with the ABI checker (see below). | ||||
|  | ||||
| - No dynamic library dependencies outside: | ||||
|  | ||||
|   - Torch; | ||||
|   - CUDA/ROCm libraries installed as dependencies of Torch. | ||||
|  | ||||
| The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | ||||
| [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | ||||
|  | ||||
| ```bash | ||||
| $ cargo install kernel-abi-check | ||||
| $ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so | ||||
| 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 | ||||
| ✅ No compatibility issues found | ||||
| ``` | ||||
|  | ||||
| ## Torch extension | ||||
|  | ||||
| @ -71,6 +97,80 @@ might use two different commits that happen to have the same version | ||||
| number. Git tags are not stable, so they do not provide a good way | ||||
| of guaranteeing uniqueness of the namespace. | ||||
|  | ||||
| ## Layers | ||||
|  | ||||
| A kernel can provide layers in addition to kernel functions. A layer from | ||||
| the Hub can replace the `forward` method of an existing layer for a certain | ||||
| device type. This makes it possible to provide more performant kernels for | ||||
| existing layers. See the [layers documentation](layers.md) for more information | ||||
| on how to use layers. | ||||
|  | ||||
| ### Writing layers | ||||
|  | ||||
| To make the extension of layers safe, the layers must fulfill the following | ||||
| requirements: | ||||
|  | ||||
| - The layers are subclasses of `torch.nn.Module`. | ||||
| - The layers are pure, meaning that they do not have their own state. This | ||||
|   means that: | ||||
|   - The layer must not define its own constructor. | ||||
|   - The layer must not use class variables. | ||||
| - No methods other than `forward` may be defined. | ||||
| - The `forward` method has a signature that is compatible with the | ||||
|   `forward` method that it is extending. | ||||
|  | ||||
| This is an example of a pure layer: | ||||
|  | ||||
| ```python | ||||
| class SiluAndMul(nn.Module): | ||||
|     def forward(self, x: torch.Tensor): | ||||
|         d = x.shape[-1] // 2 | ||||
|         output_shape = x.shape[:-1] + (d,) | ||||
|         out = torch.empty(output_shape, dtype=x.dtype, device=x.device) | ||||
|         ops.silu_and_mul(out, x) | ||||
|         return out | ||||
| ``` | ||||
|  | ||||
| For some layers, the `forward` method has to use state from the adopting class. | ||||
| In these cases, we recommend using type annotations to indicate what member | ||||
| variables are expected. For instance: | ||||
|  | ||||
| ```python | ||||
| class LlamaRMSNorm(nn.Module): | ||||
|     weight: torch.Tensor | ||||
|     variance_epsilon: float | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         return rms_norm_fn( | ||||
|             hidden_states, | ||||
|             self.weight, | ||||
|             bias=None, | ||||
|             residual=None, | ||||
|             eps=self.variance_epsilon, | ||||
|             dropout_p=0.0, | ||||
|             prenorm=False, | ||||
|             residual_in_fp32=False, | ||||
|         ) | ||||
| ``` | ||||
|  | ||||
| This layer expects the adopting layer to have `weight` and `variance_epsilon` | ||||
| member variables and uses them in the `forward` method. | ||||
|  | ||||
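| A hypothetical adopting layer could provide that state as follows (`MyRMSNorm` and its arguments are illustrative): | ||||
|  | ||||
| ```python | ||||
| class MyRMSNorm(nn.Module): | ||||
|     def __init__(self, hidden_size: int, eps: float = 1e-6): | ||||
|         super().__init__() | ||||
|         # Members expected by the Hub layer's `forward` method. | ||||
|         self.weight = nn.Parameter(torch.ones(hidden_size)) | ||||
|         self.variance_epsilon = eps | ||||
| ``` | ||||
|  | ||||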
| ### Exporting layers | ||||
|  | ||||
| To accommodate portable loading, `layers` must be defined in the main | ||||
| `__init__.py` file. For example: | ||||
|  | ||||
| ```python | ||||
| from . import layers | ||||
|  | ||||
| __all__ = [ | ||||
|   # ... | ||||
|   "layers", | ||||
|   # ... | ||||
| ] | ||||
| ``` | ||||
|  | ||||
| ## Python requirements | ||||
|  | ||||
| - Python code must be compatible with Python 3.9 and later. | ||||
|  | ||||
docs/layers.md (new file, 79 lines)
							| @ -0,0 +1,79 @@ | ||||
| # Layers | ||||
|  | ||||
| A kernel can provide layers in addition to kernel functions. A layer from | ||||
| the Hub can replace the `forward` method of an existing layer for a certain | ||||
| device type. This makes it possible to provide more performant kernels for | ||||
| existing layers. | ||||
|  | ||||
| See [Kernel requirements](kernel-requirements.md) for more information on the | ||||
| requirements of Hub layers. | ||||
|  | ||||
| ## Making a layer extensible with kernels from the hub | ||||
|  | ||||
| ### Using a decorator | ||||
|  | ||||
| A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||
| decorator. For example: | ||||
|  | ||||
| ```python | ||||
| @use_kernel_forward_from_hub("SiluAndMul") | ||||
| class SiluAndMul(nn.Module): | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         d = input.shape[-1] // 2 | ||||
|         return F.silu(input[..., :d]) * input[..., d:] | ||||
| ``` | ||||
|  | ||||
| The decorator changes the layer, so that other implementations of the `forward` | ||||
| method can be registered using the name `SiluAndMul`. | ||||
|  | ||||
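| The decorated layer remains an ordinary module; whether a Hub kernel replaces its `forward` depends on the registered mapping (see below). A usage sketch (shapes illustrative): | ||||
|  | ||||
| ```python | ||||
| silu_and_mul = SiluAndMul() | ||||
| # The last dimension is halved: 128 -> 64. | ||||
| y = silu_and_mul(torch.randn(4, 128)) | ||||
| ``` | ||||
|  | ||||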
| ### External layers | ||||
|  | ||||
| An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||
| decorator can be made extensible by monkeypatching it using the `replace_kernel_forward_from_hub` function. | ||||
|  | ||||
| ```python | ||||
| from somelibrary import SiluAndMul | ||||
|  | ||||
| replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
|  | ||||
| The `register_kernel_mapping` call maps the name `SiluAndMul` to actual | ||||
| hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer) | ||||
| section for more information. | ||||
|  | ||||
| **Warning:** we strongly recommend using layers with a decorator, since | ||||
| it signifies that the maintainer intends to keep the `forward` signature | ||||
| compatible with layers from the hub. | ||||
|  | ||||
| ## Registering a hub kernel for a layer | ||||
|  | ||||
| Once a layer is made extensible, users can register hub kernels for it | ||||
| by name using the `register_kernel_mapping` function. For example: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             revision="layers", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
|  | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
|  | ||||
| This will register the kernel mapping in the current context, which is | ||||
| normally global. It is recommended to scope the mapping to where it is | ||||
| used with the `use_kernel_mapping` context manager: | ||||
|  | ||||
| ```python | ||||
| with use_kernel_mapping(kernel_layer_mapping): | ||||
|     # Use the layer for which the mapping is applied. | ||||
|     ... | ||||
| ``` | ||||
|  | ||||
| This ensures that the mapping is not active anymore outside the | ||||
| `with`-scope. | ||||
docs/locking.md (new file, 44 lines)
							| @ -0,0 +1,44 @@ | ||||
| # Locking kernel versions | ||||
|  | ||||
| Projects that use `setuptools` can lock the kernel versions that should be | ||||
| used. First specify the accepted versions in `pyproject.toml` and make | ||||
| sure that `kernels` is a build dependency: | ||||
|  | ||||
| ```toml | ||||
| [build-system] | ||||
| requires = ["kernels", "setuptools"] | ||||
| build-backend = "setuptools.build_meta" | ||||
|  | ||||
| [tool.kernels.dependencies] | ||||
| "kernels-community/activation" = ">=0.0.1" | ||||
| ``` | ||||
|  | ||||
| Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with | ||||
| the locked revisions. The locked revision will be used when loading a kernel with | ||||
| `get_locked_kernel`: | ||||
|  | ||||
| ```python | ||||
| from kernels import get_locked_kernel | ||||
|  | ||||
| activation = get_locked_kernel("kernels-community/activation") | ||||
| ``` | ||||
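|  | ||||
| The lock file itself is JSON; an entry might look roughly like this (field names follow the lockfile code; the values are hypothetical): | ||||
|  | ||||
| ```json | ||||
| [ | ||||
|   { | ||||
|     "repo_id": "kernels-community/activation", | ||||
|     "sha": "<locked commit SHA>", | ||||
|     "variants": {} | ||||
|   } | ||||
| ] | ||||
| ``` | ||||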
|  | ||||
| **Note:** the lock file is included in the package metadata, so it will only be visible | ||||
| to `kernels` after doing an (editable or regular) installation of your project. | ||||
|  | ||||
| ## Pre-downloading locked kernels | ||||
|  | ||||
| Locked kernels can be pre-downloaded by running `kernels download .` in your | ||||
| project directory. This will download the kernels to your local Hugging Face | ||||
| Hub cache. | ||||
|  | ||||
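| For example, from the project root: | ||||
|  | ||||
| ```bash | ||||
| kernels download . | ||||
| ``` | ||||
|  | ||||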
| The pre-downloaded kernels are used by the `get_locked_kernel` function. | ||||
| `get_locked_kernel` will download a kernel when it is not pre-downloaded. If you | ||||
| want kernel loading to error when a kernel is not pre-downloaded, you can use | ||||
| the `load_kernel` function instead: | ||||
|  | ||||
| ```python | ||||
| from kernels import load_kernel | ||||
|  | ||||
| activation = load_kernel("kernels-community/activation") | ||||
| ``` | ||||
| @ -1,6 +1,6 @@ | ||||
| import torch | ||||
|  | ||||
| from hf_kernels import get_kernel | ||||
| from kernels import get_kernel | ||||
|  | ||||
| print("Starting examples/basic.py demo") | ||||
|  | ||||
|  | ||||
flake.lock (new file, 134 lines, generated)
							| @ -0,0 +1,134 @@ | ||||
| { | ||||
|   "nodes": { | ||||
|     "flake-compat": { | ||||
|       "locked": { | ||||
|         "lastModified": 1733328505, | ||||
|         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", | ||||
|         "owner": "edolstra", | ||||
|         "repo": "flake-compat", | ||||
|         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "edolstra", | ||||
|         "repo": "flake-compat", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "flake-utils": { | ||||
|       "inputs": { | ||||
|         "systems": "systems" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1731533236, | ||||
|         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", | ||||
|         "owner": "numtide", | ||||
|         "repo": "flake-utils", | ||||
|         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "numtide", | ||||
|         "repo": "flake-utils", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "flake-utils_2": { | ||||
|       "inputs": { | ||||
|         "systems": "systems_2" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1731533236, | ||||
|         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", | ||||
|         "owner": "numtide", | ||||
|         "repo": "flake-utils", | ||||
|         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "numtide", | ||||
|         "repo": "flake-utils", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "nixpkgs": { | ||||
|       "locked": { | ||||
|         "lastModified": 1737453259, | ||||
|         "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", | ||||
|         "owner": "danieldk", | ||||
|         "repo": "nixpkgs", | ||||
|         "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "danieldk", | ||||
|         "ref": "outlines-v0.1.4-tgi", | ||||
|         "repo": "nixpkgs", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "root": { | ||||
|       "inputs": { | ||||
|         "flake-utils": "flake-utils", | ||||
|         "nixpkgs": [ | ||||
|           "tgi-nix", | ||||
|           "nixpkgs" | ||||
|         ], | ||||
|         "tgi-nix": "tgi-nix" | ||||
|       } | ||||
|     }, | ||||
|     "systems": { | ||||
|       "locked": { | ||||
|         "lastModified": 1681028828, | ||||
|         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", | ||||
|         "owner": "nix-systems", | ||||
|         "repo": "default", | ||||
|         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "nix-systems", | ||||
|         "repo": "default", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "systems_2": { | ||||
|       "locked": { | ||||
|         "lastModified": 1681028828, | ||||
|         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", | ||||
|         "owner": "nix-systems", | ||||
|         "repo": "default", | ||||
|         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "nix-systems", | ||||
|         "repo": "default", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "tgi-nix": { | ||||
|       "inputs": { | ||||
|         "flake-compat": "flake-compat", | ||||
|         "flake-utils": "flake-utils_2", | ||||
|         "nixpkgs": "nixpkgs" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1741617161, | ||||
|         "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=", | ||||
|         "owner": "huggingface", | ||||
|         "repo": "text-generation-inference-nix", | ||||
|         "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "huggingface", | ||||
|         "ref": "kernels-0.2.0", | ||||
|         "repo": "text-generation-inference-nix", | ||||
|         "type": "github" | ||||
|       } | ||||
|     } | ||||
|   }, | ||||
|   "root": "root", | ||||
|   "version": 7 | ||||
| } | ||||
flake.nix (new file, 54 lines)
							| @ -0,0 +1,54 @@ | ||||
| { | ||||
|   inputs = { | ||||
|     tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0"; | ||||
|     nixpkgs.follows = "tgi-nix/nixpkgs"; | ||||
|     flake-utils.url = "github:numtide/flake-utils"; | ||||
|   }; | ||||
|   outputs = | ||||
|     { | ||||
|       self, | ||||
|       nixpkgs, | ||||
|       flake-utils, | ||||
|       tgi-nix, | ||||
|     }: | ||||
|     flake-utils.lib.eachDefaultSystem ( | ||||
|       system: | ||||
|       let | ||||
|         pkgs = import nixpkgs { | ||||
|           inherit system; | ||||
|           inherit (tgi-nix.lib) config; | ||||
|           overlays = [ | ||||
|             tgi-nix.overlays.default | ||||
|           ]; | ||||
|         }; | ||||
|       in | ||||
|       { | ||||
|         formatter = pkgs.nixfmt-rfc-style; | ||||
|         devShells = with pkgs; rec { | ||||
|           default = mkShell { | ||||
|             buildInputs = | ||||
|               [ | ||||
|                 black | ||||
|                 mypy | ||||
|                 pyright | ||||
|                 ruff | ||||
|               ] | ||||
|               ++ (with python3.pkgs; [ | ||||
|                 huggingface-hub | ||||
|                 pytest | ||||
|                 pytest-benchmark | ||||
|                 torch | ||||
|                 venvShellHook | ||||
|               ]); | ||||
|  | ||||
|             venvDir = "./.venv"; | ||||
|  | ||||
|             postVenvCreation = '' | ||||
|               unset SOURCE_DATE_EPOCH | ||||
|               ( python -m pip install --no-build-isolation --no-dependencies -e . ) | ||||
|             ''; | ||||
|           }; | ||||
|         }; | ||||
|       } | ||||
|     ); | ||||
| } | ||||
| @ -1,20 +1,20 @@ | ||||
| [project] | ||||
| name = "hf-kernels" | ||||
| version = "0.1.6" | ||||
| description = "Download cuda kernels" | ||||
| name = "kernels" | ||||
| version = "0.4.3" | ||||
| description = "Download compute kernels" | ||||
| authors = [ | ||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||
|   { name = "Daniel de Kok", email = "daniel@huggingface.co" }, | ||||
|   { name = "David Holtz", email = "david@huggingface.co" }, | ||||
|   { name = "Nicolas Patry", email = "nicolas@huggingface.co" }, | ||||
| ] | ||||
| license = { text = "Apache-2.0" } | ||||
| readme = "README.md" | ||||
| requires-python = ">= 3.9" | ||||
| dependencies = [ | ||||
|   "huggingface-hub>=0.26.3", | ||||
|   "packaging>=24.2", | ||||
|   "tomli>=2.0.1; python_version<'3.11'", | ||||
|   "torch>=2.4", | ||||
|   "huggingface_hub>=0.26.0,<1.0", | ||||
|   "packaging>=20.0", | ||||
|   "tomli>=2.0; python_version<'3.11'", | ||||
| ] | ||||
|  | ||||
| [build-system] | ||||
| @ -23,18 +23,46 @@ build-backend = "setuptools.build_meta" | ||||
|  | ||||
| [dependency-groups] | ||||
| dev = [ | ||||
|   "mypy == 1.14.1", | ||||
|   "pytest >=8", | ||||
|   # Whatever version is compatible with pytest. | ||||
|   "pytest-benchmark", | ||||
|   "torch >=2.5", | ||||
| ] | ||||
|  | ||||
| [project.optional-dependencies] | ||||
| torch = ["torch"] | ||||
|  | ||||
| [project.scripts] | ||||
| hf-kernels = "hf_kernels.cli:main" | ||||
| kernels = "kernels.cli:main" | ||||
|  | ||||
| [project.entry-points."egg_info.writers"] | ||||
| "hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile" | ||||
| "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||
|  | ||||
| #[build-system] | ||||
| #requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"] | ||||
| #build-backend = "hf_kernels.build" | ||||
| #backend-path = ["src"] | ||||
|  | ||||
| [tool.ruff] | ||||
| exclude = [ | ||||
|   ".eggs", | ||||
|   ".git", | ||||
|   ".git-rewrite", | ||||
|   ".hg", | ||||
|   ".mypy_cache", | ||||
|   ".nox", | ||||
|   ".pants.d", | ||||
|   ".pytype", | ||||
|   ".ruff_cache", | ||||
|   ".svn", | ||||
|   ".tox", | ||||
|   ".venv", | ||||
|   ".venv*", | ||||
|   "__pypackages__", | ||||
|   "_build", | ||||
|   "build", | ||||
|   "dist", | ||||
|   "venv", | ||||
| ] | ||||
| line-length = 119 | ||||
| # Ignored rules: | ||||
| # "E501" -> line length violation | ||||
| lint.ignore = ["E501"] | ||||
| lint.select = ["E", "F", "I", "W"] | ||||
|  | ||||
| @ -1,3 +0,0 @@ | ||||
| from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel | ||||
|  | ||||
| __all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"] | ||||
| @ -1,144 +0,0 @@ | ||||
| """ | ||||
| Python shims for the PEP 517 and PEP 660 build backend. | ||||
|  | ||||
| Major imports in this module are required to be lazy: | ||||
| ``` | ||||
| $ hyperfine \ | ||||
|      "/usr/bin/python3 -c \"print('hi')\"" \ | ||||
|      "/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\"" | ||||
| Base: Time (mean ± σ):      11.0 ms ±   1.7 ms    [User: 8.5 ms, System: 2.5 ms] | ||||
| With import: Time (mean ± σ):      15.2 ms ±   2.0 ms    [User: 12.3 ms, System: 2.9 ms] | ||||
| Base 1.38 ± 0.28 times faster than with import | ||||
| ``` | ||||
|  | ||||
| The same thing goes for the typing module, so we use Python 3.10 type annotations that | ||||
| don't require importing typing but then quote them so earlier Python version ignore | ||||
| them while IDEs and type checker can see through the quotes. | ||||
| """ | ||||
|  | ||||
| from hf_kernels.compat import tomllib | ||||
|  | ||||
| TYPE_CHECKING = False | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Mapping, Sequence  # noqa:I001 | ||||
|     from typing import Any  # noqa:I001 | ||||
|  | ||||
|  | ||||
| def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None: | ||||
|     import sys | ||||
|  | ||||
|     if config_settings: | ||||
|         print("Warning: Config settings are not supported", file=sys.stderr) | ||||
|  | ||||
|  | ||||
| def call( | ||||
|     args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None | ||||
| ) -> str: | ||||
|     """Invoke a uv subprocess and return the filename from stdout.""" | ||||
|     import shutil | ||||
|     import subprocess | ||||
|     import sys | ||||
|  | ||||
|     warn_config_settings(config_settings) | ||||
|     # Unlike `find_uv_bin`, this mechanism must work according to PEP 517 | ||||
|     import os | ||||
|  | ||||
|     cwd = os.getcwd() | ||||
|     filename = os.path.join(cwd, "pyproject.toml") | ||||
|     with open(filename, "rb") as f: | ||||
|         data = tomllib.load(f) | ||||
|  | ||||
|     for kernel, _ in ( | ||||
|         data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items() | ||||
|     ): | ||||
|         from hf_kernels.utils import install_kernel | ||||
|  | ||||
|         install_kernel(kernel, revision="main") | ||||
|     uv_bin = shutil.which("uv") | ||||
|     if uv_bin is None: | ||||
|         raise RuntimeError("uv was not properly installed") | ||||
|     # Forward stderr, capture stdout for the filename | ||||
|     result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE) | ||||
|     if result.returncode != 0: | ||||
|         sys.exit(result.returncode) | ||||
|     # If there was extra stdout, forward it (there should not be extra stdout) | ||||
|     stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True) | ||||
|     sys.stdout.writelines(stdout[:-1]) | ||||
|     # Fail explicitly instead of an irrelevant stacktrace | ||||
|     if not stdout: | ||||
|         print("uv subprocess did not return a filename on stdout", file=sys.stderr) | ||||
|         sys.exit(1) | ||||
|     return stdout[-1].strip() | ||||
|  | ||||
|  | ||||
| def build_sdist( | ||||
|     sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None | ||||
| ) -> str: | ||||
|     """PEP 517 hook `build_sdist`.""" | ||||
|     args = ["build-backend", "build-sdist", sdist_directory] | ||||
|     return call(args, config_settings) | ||||
|  | ||||
|  | ||||
| def build_wheel( | ||||
|     wheel_directory: str, | ||||
|     config_settings: "Mapping[Any, Any] | None" = None, | ||||
|     metadata_directory: "str | None" = None, | ||||
| ) -> str: | ||||
|     """PEP 517 hook `build_wheel`.""" | ||||
|     args = ["build-backend", "build-wheel", wheel_directory] | ||||
|     if metadata_directory: | ||||
|         args.extend(["--metadata-directory", metadata_directory]) | ||||
|     return call(args, config_settings) | ||||
|  | ||||
|  | ||||
| def get_requires_for_build_sdist( | ||||
|     config_settings: "Mapping[Any, Any] | None" = None, | ||||
| ) -> "Sequence[str]": | ||||
|     """PEP 517 hook `get_requires_for_build_sdist`.""" | ||||
|     warn_config_settings(config_settings) | ||||
|     return [] | ||||
|  | ||||
|  | ||||
| def get_requires_for_build_wheel( | ||||
|     config_settings: "Mapping[Any, Any] | None" = None, | ||||
| ) -> "Sequence[str]": | ||||
|     """PEP 517 hook `get_requires_for_build_wheel`.""" | ||||
|     warn_config_settings(config_settings) | ||||
|     return [] | ||||
|  | ||||
|  | ||||
| def prepare_metadata_for_build_wheel( | ||||
|     metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None | ||||
| ) -> str: | ||||
|     """PEP 517 hook `prepare_metadata_for_build_wheel`.""" | ||||
|     args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory] | ||||
|     return call(args, config_settings) | ||||
|  | ||||
|  | ||||
| def build_editable( | ||||
|     wheel_directory: str, | ||||
|     config_settings: "Mapping[Any, Any] | None" = None, | ||||
|     metadata_directory: "str | None" = None, | ||||
| ) -> str: | ||||
|     """PEP 660 hook `build_editable`.""" | ||||
|     args = ["build-backend", "build-editable", wheel_directory] | ||||
|  | ||||
|     if metadata_directory: | ||||
|         args.extend(["--metadata-directory", metadata_directory]) | ||||
|     return call(args, config_settings) | ||||
|  | ||||
|  | ||||
| def get_requires_for_build_editable( | ||||
|     config_settings: "Mapping[Any, Any] | None" = None, | ||||
| ) -> "Sequence[str]": | ||||
|     """PEP 660 hook `get_requires_for_build_editable`.""" | ||||
|     warn_config_settings(config_settings) | ||||
|     return [] | ||||
|  | ||||
|  | ||||
| def prepare_metadata_for_build_editable( | ||||
|     metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None | ||||
| ) -> str: | ||||
|     """PEP 660 hook `prepare_metadata_for_build_editable`.""" | ||||
|     args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory] | ||||
|     return call(args, config_settings) | ||||
src/kernels/__init__.py (new file, 27 lines)
							| @ -0,0 +1,27 @@ | ||||
| from kernels.layer import ( | ||||
|     Device, | ||||
|     LayerRepository, | ||||
|     register_kernel_mapping, | ||||
|     replace_kernel_forward_from_hub, | ||||
|     use_kernel_forward_from_hub, | ||||
|     use_kernel_mapping, | ||||
| ) | ||||
| from kernels.utils import ( | ||||
|     get_kernel, | ||||
|     get_locked_kernel, | ||||
|     install_kernel, | ||||
|     load_kernel, | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     "get_kernel", | ||||
|     "get_locked_kernel", | ||||
|     "load_kernel", | ||||
|     "install_kernel", | ||||
|     "use_kernel_forward_from_hub", | ||||
|     "use_kernel_mapping", | ||||
|     "register_kernel_mapping", | ||||
|     "replace_kernel_forward_from_hub", | ||||
|     "LayerRepository", | ||||
|     "Device", | ||||
| ] | ||||
| @ -4,14 +4,14 @@ import json | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from hf_kernels.compat import tomllib | ||||
| from hf_kernels.lockfile import KernelLock, get_kernel_locks | ||||
| from hf_kernels.utils import build_variant, install_kernel, install_kernel_all_variants | ||||
| from kernels.compat import tomllib | ||||
| from kernels.lockfile import KernelLock, get_kernel_locks | ||||
| from kernels.utils import install_kernel, install_kernel_all_variants | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         prog="hf-kernel", description="Manage compute kernels" | ||||
|         prog="kernel", description="Manage compute kernels" | ||||
|     ) | ||||
|     subparsers = parser.add_subparsers(required=True) | ||||
|  | ||||
| @ -41,13 +41,13 @@ def main(): | ||||
|  | ||||
|  | ||||
| def download_kernels(args): | ||||
|     lock_path = args.project_dir / "hf-kernels.lock" | ||||
|     lock_path = args.project_dir / "kernels.lock" | ||||
|  | ||||
|     if not lock_path.exists(): | ||||
|         print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr) | ||||
|         print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr) | ||||
|         sys.exit(1) | ||||
|  | ||||
|     with open(args.project_dir / "hf-kernels.lock", "r") as f: | ||||
|     with open(args.project_dir / "kernels.lock", "r") as f: | ||||
|         lock_json = json.load(f) | ||||
|  | ||||
|     all_successful = True | ||||
| @ -67,7 +67,7 @@ def download_kernels(args): | ||||
|                 install_kernel( | ||||
|                     kernel_lock.repo_id, | ||||
|                     kernel_lock.sha, | ||||
|                     variant_lock=kernel_lock.variants[build_variant()], | ||||
|                     variant_locks=kernel_lock.variants, | ||||
|                 ) | ||||
|             except FileNotFoundError as e: | ||||
|                 print(e, file=sys.stderr) | ||||
| @ -87,7 +87,7 @@ def lock_kernels(args): | ||||
|     for kernel, version in kernel_versions.items(): | ||||
|         all_locks.append(get_kernel_locks(kernel, version)) | ||||
|  | ||||
|     with open(args.project_dir / "hf-kernels.lock", "w") as f: | ||||
|     with open(args.project_dir / "kernels.lock", "w") as f: | ||||
|         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||||
|  | ||||
|  | ||||
src/kernels/layer.py (new file, 259 lines)
							| @ -0,0 +1,259 @@ | ||||
| import inspect | ||||
| import os | ||||
| import warnings | ||||
| from contextvars import ContextVar | ||||
| from copy import deepcopy | ||||
| from dataclasses import dataclass, field | ||||
| from typing import TYPE_CHECKING, Callable, Dict, Union | ||||
|  | ||||
| from .utils import get_kernel | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from torch import nn | ||||
|  | ||||
| _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0"))) | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class Device: | ||||
|     type: str | ||||
|  | ||||
|     # In the future we might add compute capabilities, etc. | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return isinstance(other, Device) and self.type == other.type | ||||
|  | ||||
|     def __hash__(self): | ||||
|         return hash(self.type) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class LayerRepository: | ||||
|     """ | ||||
|     Repository and name of a layer. | ||||
|     """ | ||||
|  | ||||
|     layer_name: str = field( | ||||
|         metadata={"help": "The name of the layer in the kernel repository."} | ||||
|     ) | ||||
|     repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."}) | ||||
|     revision: str = field( | ||||
|         default="main", metadata={"help": "The revision of the layer."} | ||||
|     ) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return ( | ||||
|             isinstance(other, LayerRepository) | ||||
|             and self.layer_name == other.layer_name | ||||
|             and self.repo_id == other.repo_id | ||||
|             and self.revision == other.revision | ||||
|         ) | ||||
|  | ||||
|     def __hash__(self): | ||||
|         return hash((self.layer_name, self.repo_id, self.revision)) | ||||
|  | ||||
|  | ||||
| _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar( | ||||
|     "_KERNEL_MAPPING", default={} | ||||
| ) | ||||
|  | ||||
|  | ||||
| def use_kernel_mapping( | ||||
|     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]], | ||||
|     *, | ||||
|     inherit_mapping: bool = True, | ||||
| ): | ||||
|     """ | ||||
|     Context manager that sets a mapping for the duration of the context. | ||||
|  | ||||
|     When `inherit_mapping` is set to `True` the current mapping will be | ||||
|     extended by `mapping` inside the context. If it is `False`, only | ||||
|     `mapping` is used inside the context. | ||||
|     """ | ||||
|  | ||||
|     class ContextManager: | ||||
|         def __enter__(self): | ||||
|             # Mappings always stack on previous mappings. | ||||
|             if inherit_mapping: | ||||
|                 self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get())) | ||||
|             else: | ||||
|                 self.token = _KERNEL_MAPPING.set({}) | ||||
|             register_kernel_mapping(mapping) | ||||
|  | ||||
|         def __exit__(self, exc_type, exc_value, traceback): | ||||
|             _KERNEL_MAPPING.reset(self.token) | ||||
|  | ||||
|     return ContextManager() | ||||
|  | ||||
|  | ||||
| def register_kernel_mapping( | ||||
|     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]] | ||||
| ): | ||||
|     """ | ||||
|     Allows one to register a mapping between a layer name and the corresponding kernel to use, depending on the device. | ||||
|     This should be used in conjunction with the `use_kernel_forward_from_hub` decorator on the class. | ||||
|     Example usage: | ||||
|  | ||||
|     ```python | ||||
|     from kernels import LayerRepository, register_kernel_mapping | ||||
|  | ||||
|     kernel_layer_mapping = { | ||||
|       "LlamaRMSNorm": { | ||||
|           "cuda": LayerRepository( | ||||
|               repo_id="kernels-community/activation", | ||||
|               layer_name="RmsNorm", | ||||
|               revision="layers", | ||||
|           ), | ||||
|       }, | ||||
|     } | ||||
|     register_kernel_mapping(kernel_layer_mapping) | ||||
|     ``` | ||||
|     """ | ||||
|     # Merge with existing mappings. | ||||
|     for new_kernel, new_device_repos in mapping.items(): | ||||
|         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {}) | ||||
|         for new_device, new_repo in new_device_repos.items(): | ||||
|             if isinstance(new_device, str): | ||||
|                 device_repo[Device(type=new_device)] = new_repo | ||||
|             else: | ||||
|                 device_repo[new_device] = new_repo | ||||
|  | ||||
|  | ||||
| def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True): | ||||
|     """ | ||||
|     Replace the forward function of a layer using a layer from the kernel hub. | ||||
|     This function monkeypatches a layer, replacing the `forward` method | ||||
|     of the layer with that of a layer from the hub. The replacement is done | ||||
|     when a layer matching `layer_name` and device type is registered through | ||||
|     `register_kernel_mapping`. The device type is inferred from the first | ||||
|     argument to `forward`. | ||||
|     """ | ||||
|  | ||||
|     fallback_forward = cls.forward | ||||
|  | ||||
|     cached_forward: Dict[LayerRepository, Callable] = {} | ||||
|  | ||||
|     def forward(self, x, *args, **kwargs): | ||||
|         if _DISABLE_KERNEL_MAPPING: | ||||
|             return fallback_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|         kernel = _KERNEL_MAPPING.get().get(layer_name) | ||||
|         if kernel is None: | ||||
|             warnings.warn( | ||||
|                 "\n" | ||||
|                 f"No kernel mapping found for layer `{layer_name}`. " | ||||
|                 f"Check if the layer name matches one of the kernels in the mapping or add the kernel " | ||||
|                 f"you want to use to the mapping. Defaulting to original forward implementation." | ||||
|             ) | ||||
|             if not use_fallback: | ||||
|                 raise ValueError(f"No layer mapping for `{layer_name}`") | ||||
|             return fallback_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|         device = getattr(x, "device", None) | ||||
|         if device is None: | ||||
|             return fallback_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|         repo = kernel.get(Device(type=device.type)) | ||||
|         if repo is None: | ||||
|             if not use_fallback: | ||||
|                 raise ValueError( | ||||
|                     f"No layer mapping for `{layer_name}` with device type `{device.type}`" | ||||
|                 ) | ||||
|             return fallback_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|         # Short-circuit if we already loaded the layer. | ||||
|         layer_forward = cached_forward.get(repo, None) | ||||
|         if layer_forward is not None: | ||||
|             return layer_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|         layer = _get_kernel_layer( | ||||
|             repo_id=repo.repo_id, | ||||
|             layer_name=repo.layer_name, | ||||
|             revision=repo.revision, | ||||
|         ) | ||||
|  | ||||
|         # We have to validate against the original signature. | ||||
|         orig_forward = cls.forward | ||||
|         try: | ||||
|             cls.forward = fallback_forward | ||||
|             _validate_layer(check_cls=cls, cls=layer) | ||||
|         finally: | ||||
|             cls.forward = orig_forward | ||||
|  | ||||
|         layer_forward = layer.forward | ||||
|         cached_forward[repo] = layer_forward | ||||
|  | ||||
|         return layer_forward(self, x, *args, **kwargs) | ||||
|  | ||||
|     cls.forward = forward | ||||
|  | ||||
|  | ||||
| def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True): | ||||
|     """ | ||||
|     Replace the forward function of a layer using a layer from the kernel hub. | ||||
|     This decorator can be applied to a layer and replaces the forward method | ||||
|     of the layer with that of a layer from the hub. The replacement is done | ||||
|     when a layer matching `layer_name` and device type is registered through | ||||
|     `register_kernel_mapping`. The device type is inferred from the first | ||||
|     argument to `forward`. | ||||
|     """ | ||||
|  | ||||
|     def decorator(cls): | ||||
|         replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback) | ||||
|         return cls | ||||
|  | ||||
|     return decorator | ||||
|  | ||||
|  | ||||
| def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module": | ||||
|     """Get a layer from a kernel.""" | ||||
|  | ||||
|     kernel = get_kernel(repo_id, revision=revision) | ||||
|  | ||||
|     if getattr(kernel, "layers", None) is None: | ||||
|         raise ValueError( | ||||
|             f"Kernel `{repo_id}` at revision `{revision}` does not define any layers." | ||||
|         ) | ||||
|  | ||||
|     layer = getattr(kernel.layers, layer_name, None) | ||||
|     if layer is None: | ||||
|         raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.") | ||||
|     return layer | ||||
|  | ||||
|  | ||||
| def _validate_layer(*, check_cls, cls): | ||||
|     # The layer must at least have the following properties: (1) it | ||||
|     # must be stateless; (2) the forward signature should correspond to | ||||
|     # the signature it is replacing; (3) forward should not call other | ||||
|     # methods. | ||||
|  | ||||
|     from torch import nn | ||||
|  | ||||
|     if not issubclass(cls, nn.Module): | ||||
|         raise TypeError(f"Layer `{cls}` is not a Torch layer.") | ||||
|  | ||||
|     # We verify statelessness by checking that the class does not have its own | ||||
|     # constructor (since the constructor could add member variables)... | ||||
|     if cls.__init__ is not nn.Module.__init__: | ||||
|         raise TypeError("Layer must not override nn.Module constructor.") | ||||
|  | ||||
|     # ... or predefined member variables. | ||||
|     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)} | ||||
|     cls_members = {name for name, _ in inspect.getmembers(cls)} | ||||
|     if cls_members - torch_module_members != set(): | ||||
|         raise TypeError("Layer must not contain additional members.") | ||||
|  | ||||
|     # Check whether the forward signatures are similar. | ||||
|     params = inspect.signature(cls.forward).parameters | ||||
|     ref_params = inspect.signature(check_cls.forward).parameters | ||||
|  | ||||
|     if len(params) != len(ref_params): | ||||
|         raise TypeError( | ||||
|             "Forward signature does not match: different number of arguments." | ||||
|         ) | ||||
|  | ||||
|     for param, ref_param in zip(params.values(), ref_params.values()): | ||||
|         if param.kind != ref_param.kind: | ||||
|             raise TypeError( | ||||
|                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" | ||||
|             ) | ||||
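
The layer API above composes as follows. This is a condensed version of what `tests/test_layer.py` below does, assuming a CUDA device; on other devices the original `forward` serves as the fallback:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import (
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }
)

# On CUDA inputs the hub layer's forward runs; elsewhere the Python
# implementation above is used.
y = SiluAndMul()(torch.randn(32, 64, device="cuda"))
```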
| @ -1,13 +1,14 @@ | ||||
| from dataclasses import dataclass | ||||
| import hashlib | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import Dict | ||||
| from typing import Dict, List, Tuple | ||||
| 
 | ||||
| from huggingface_hub import HfApi | ||||
| from huggingface_hub.hf_api import GitRefInfo | ||||
| from packaging.specifiers import SpecifierSet | ||||
| from packaging.version import InvalidVersion, Version | ||||
| 
 | ||||
| from hf_kernels.compat import tomllib | ||||
| from kernels.compat import tomllib | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| @ -30,7 +31,7 @@ class KernelLock: | ||||
|         return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) | ||||
| 
 | ||||
| 
 | ||||
| def _get_available_versions(repo_id: str): | ||||
| def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: | ||||
|     """Get kernel versions that are available in the repository.""" | ||||
|     versions = {} | ||||
|     for tag in HfApi().list_repo_refs(repo_id).tags: | ||||
| @ -44,7 +45,7 @@ def _get_available_versions(repo_id: str): | ||||
|     return versions | ||||
| 
 | ||||
| 
 | ||||
| def get_kernel_locks(repo_id: str, version_spec: str): | ||||
| def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | ||||
|     """ | ||||
|     Get the locks for a kernel with the given version spec. | ||||
| 
 | ||||
| @ -75,7 +76,7 @@ def get_kernel_locks(repo_id: str, version_spec: str): | ||||
|             f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}" | ||||
|         ) | ||||
| 
 | ||||
|     variant_files = {} | ||||
|     variant_files: Dict[str, List[Tuple[bytes, str]]] = {} | ||||
|     for sibling in r.siblings: | ||||
|         if sibling.rfilename.startswith("build/torch"): | ||||
|             if sibling.blob_id is None: | ||||
| @ -96,9 +97,9 @@ def get_kernel_locks(repo_id: str, version_spec: str): | ||||
|     variant_locks = {} | ||||
|     for variant, files in variant_files.items(): | ||||
|         m = hashlib.sha256() | ||||
|         for filename, hash in sorted(files): | ||||
|         for filename_bytes, hash in sorted(files): | ||||
|             # Filename as bytes. | ||||
|             m.update(filename) | ||||
|             m.update(filename_bytes) | ||||
|             # Git blob or LFS file hash as bytes. | ||||
|             m.update(bytes.fromhex(hash)) | ||||
| 
 | ||||
| @ -123,7 +124,7 @@ def write_egg_lockfile(cmd, basename, filename): | ||||
|     if kernel_versions is None: | ||||
|         return | ||||
| 
 | ||||
|     lock_path = cwd / "hf-kernels.lock" | ||||
|     lock_path = cwd / "kernels.lock" | ||||
|     if not lock_path.exists(): | ||||
|         logging.warning(f"Lock file {lock_path} does not exist") | ||||
|         # Ensure that the file gets deleted in editable installs. | ||||
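
The per-variant hash recorded in the lock file is a single SHA-256 digest over the sorted (path, blob hash) pairs of a build variant. A minimal sketch of the scheme, with a hypothetical file listing:

```python
import hashlib

# (relative path as bytes, git blob or LFS hash as hex) per file.
files = [
    (b"build/torch26-cxx11-cu124-x86_64-linux/activation/__init__.py", "ab" * 20),
]

m = hashlib.sha256()
for filename_bytes, blob_hash in sorted(files):
    m.update(filename_bytes)            # filename as bytes
    m.update(bytes.fromhex(blob_hash))  # blob or LFS hash as bytes
print(f"sha256-{m.hexdigest()}")
```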
| @ -4,41 +4,62 @@ import importlib | ||||
| import importlib.metadata | ||||
| import inspect | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| from pathlib import Path | ||||
| import platform | ||||
| import sys | ||||
| from importlib.metadata import Distribution | ||||
| from pathlib import Path | ||||
| from types import ModuleType | ||||
| from typing import Dict, List, Optional, Tuple | ||||
| 
 | ||||
| from huggingface_hub import hf_hub_download, snapshot_download | ||||
| from huggingface_hub import snapshot_download | ||||
| from packaging.version import parse | ||||
| 
 | ||||
| from hf_kernels.compat import tomllib | ||||
| from hf_kernels.lockfile import KernelLock, VariantLock | ||||
| 
 | ||||
| CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None) | ||||
| from kernels.lockfile import KernelLock, VariantLock | ||||
| 
 | ||||
| 
 | ||||
| def build_variant(): | ||||
| def _get_cache_dir() -> Optional[str]: | ||||
|     """Returns the kernels cache directory.""" | ||||
|     cache_dir = os.environ.get("HF_KERNELS_CACHE", None) | ||||
|     if cache_dir is not None: | ||||
|         logging.warning( | ||||
|             "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead" | ||||
|         ) | ||||
|         return cache_dir | ||||
| 
 | ||||
|     return os.environ.get("KERNELS_CACHE", None) | ||||
| 
 | ||||
| 
 | ||||
| CACHE_DIR: Optional[str] = _get_cache_dir() | ||||
| 
 | ||||
| 
 | ||||
| def build_variant() -> str: | ||||
|     import torch | ||||
| 
 | ||||
|     if torch.version.cuda is None: | ||||
|         raise AssertionError( | ||||
|             "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled." | ||||
|         ) | ||||
|     if torch.version.cuda is not None: | ||||
|         cuda_version = parse(torch.version.cuda) | ||||
|         compute_framework = f"cu{cuda_version.major}{cuda_version.minor}" | ||||
|     elif torch.version.hip is not None: | ||||
|         rocm_version = parse(torch.version.hip.split("-")[0]) | ||||
|         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" | ||||
|     else: | ||||
|         raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.") | ||||
| 
 | ||||
|     torch_version = parse(torch.__version__) | ||||
|     cuda_version = parse(torch.version.cuda) | ||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||
|     cpu = platform.machine() | ||||
|     os = platform.system().lower() | ||||
| 
 | ||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}" | ||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}" | ||||
| 
 | ||||
| 
 | ||||
| def import_from_path(module_name: str, file_path): | ||||
| def universal_build_variant() -> str: | ||||
|     # Once we support other frameworks, detection goes here. | ||||
|     return "torch-universal" | ||||
| 
 | ||||
| 
 | ||||
| def import_from_path(module_name: str, file_path: Path) -> ModuleType: | ||||
|     # We cannot use the module name as-is, after adding it to `sys.modules`, | ||||
|     # it would also be used for other imports. So, we make a module name that | ||||
|     # depends on the path for it to be unique using the hex-encoded hash of | ||||
| @ -46,9 +67,13 @@ def import_from_path(module_name: str, file_path): | ||||
|     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value) | ||||
|     module_name = f"{module_name}_{path_hash}" | ||||
|     spec = importlib.util.spec_from_file_location(module_name, file_path) | ||||
|     if spec is None: | ||||
|         raise ImportError(f"Cannot load spec for {module_name} from {file_path}") | ||||
|     module = importlib.util.module_from_spec(spec) | ||||
|     if module is None: | ||||
|         raise ImportError(f"Cannot load module {module_name} from spec") | ||||
|     sys.modules[module_name] = module | ||||
|     spec.loader.exec_module(module) | ||||
|     spec.loader.exec_module(module)  # type: ignore | ||||
|     return module | ||||
| 
 | ||||
| 
 | ||||
| @ -56,29 +81,41 @@ def install_kernel( | ||||
|     repo_id: str, | ||||
|     revision: str, | ||||
|     local_files_only: bool = False, | ||||
|     variant_lock: Optional[VariantLock] = None, | ||||
| ) -> Tuple[str, str]: | ||||
|     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||
| ) -> Tuple[str, Path]: | ||||
|     """ | ||||
|     Download a kernel for the current environment to the cache. | ||||
| 
 | ||||
|     The output path is validated against the hashes in `variant_locks` when set. | ||||
|     """ | ||||
|     package_name = repo_id.split("/")[-1] | ||||
|     package_name = package_name.replace("-", "_") | ||||
|     package_name = package_name_from_repo_id(repo_id) | ||||
|     variant = build_variant() | ||||
|     repo_path = snapshot_download( | ||||
|     universal_variant = universal_build_variant() | ||||
|     repo_path = Path( | ||||
|         snapshot_download( | ||||
|             repo_id, | ||||
|         allow_patterns=f"build/{variant}/*", | ||||
|             allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"], | ||||
|             cache_dir=CACHE_DIR, | ||||
|             revision=revision, | ||||
|             local_files_only=local_files_only, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     if variant_lock is not None: | ||||
|     variant_path = repo_path / "build" / variant | ||||
|     universal_variant_path = repo_path / "build" / universal_variant | ||||
| 
 | ||||
|     if not variant_path.exists() and universal_variant_path.exists(): | ||||
|         # Fall back to universal variant. | ||||
|         variant = universal_variant | ||||
|         variant_path = universal_variant_path | ||||
| 
 | ||||
|     if variant_locks is not None: | ||||
|         variant_lock = variant_locks.get(variant) | ||||
|         if variant_lock is None: | ||||
|             raise ValueError(f"No lock found for build variant: {variant}") | ||||
|         validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash) | ||||
| 
 | ||||
|     variant_path = f"{repo_path}/build/{variant}" | ||||
|     module_init_path = f"{variant_path}/{package_name}/__init__.py" | ||||
|     module_init_path = variant_path / package_name / "__init__.py" | ||||
| 
 | ||||
|     if not os.path.exists(module_init_path): | ||||
|         raise FileNotFoundError( | ||||
| @ -93,7 +130,7 @@ def install_kernel_all_variants( | ||||
|     revision: str, | ||||
|     local_files_only: bool = False, | ||||
|     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||
| ) -> str: | ||||
| ) -> Path: | ||||
|     repo_path = Path( | ||||
|         snapshot_download( | ||||
|             repo_id, | ||||
| @ -116,52 +153,64 @@ def install_kernel_all_variants( | ||||
|                 repo_path=repo_path, variant=variant, hash=variant_lock.hash | ||||
|             ) | ||||
| 
 | ||||
|     return f"{repo_path}/build" | ||||
|     return repo_path / "build" | ||||
| 
 | ||||
| 
 | ||||
| def get_metadata(repo_id: str, revision: str, local_files_only: bool = False): | ||||
|     with open( | ||||
|         hf_hub_download( | ||||
|             repo_id, | ||||
|             "build.toml", | ||||
|             cache_dir=CACHE_DIR, | ||||
|             revision=revision, | ||||
|             local_files_only=local_files_only, | ||||
|         ), | ||||
|         "rb", | ||||
|     ) as f: | ||||
|         return tomllib.load(f) | ||||
| 
 | ||||
| 
 | ||||
| def get_kernel(repo_id: str, revision: str = "main"): | ||||
| def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | ||||
|     package_name, package_path = install_kernel(repo_id, revision=revision) | ||||
|     return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py") | ||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||
| 
 | ||||
| 
 | ||||
| def load_kernel(repo_id: str): | ||||
|     """Get a pre-downloaded, locked kernel.""" | ||||
| def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||
|     """ | ||||
|     Get a pre-downloaded, locked kernel. | ||||
| 
 | ||||
|     If `lockfile` is not specified, the lockfile will be loaded from the | ||||
|     caller's package metadata. | ||||
|     """ | ||||
|     if lockfile is None: | ||||
|         locked_sha = _get_caller_locked_kernel(repo_id) | ||||
|     else: | ||||
|         with open(lockfile, "r") as f: | ||||
|             locked_sha = _get_locked_kernel(repo_id, f.read()) | ||||
| 
 | ||||
|     if locked_sha is None: | ||||
|         raise ValueError(f"Kernel `{repo_id}` is not locked") | ||||
| 
 | ||||
|     filename = hf_hub_download( | ||||
|         repo_id, | ||||
|         "build.toml", | ||||
|         cache_dir=CACHE_DIR, | ||||
|         local_files_only=True, | ||||
|         revision=locked_sha, | ||||
|         raise ValueError( | ||||
|             f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project." | ||||
|         ) | ||||
|     with open(filename, "rb") as f: | ||||
|         metadata = tomllib.load(f) | ||||
|     package_name = metadata["torch"]["name"] | ||||
| 
 | ||||
|     repo_path = os.path.dirname(filename) | ||||
|     package_path = f"{repo_path}/build/{build_variant()}" | ||||
|     return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py") | ||||
|     package_name = package_name_from_repo_id(repo_id) | ||||
| 
 | ||||
|     variant = build_variant() | ||||
|     universal_variant = universal_build_variant() | ||||
| 
 | ||||
|     repo_path = Path( | ||||
|         snapshot_download( | ||||
|             repo_id, | ||||
|             allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"], | ||||
|             cache_dir=CACHE_DIR, | ||||
|             revision=locked_sha, | ||||
|             local_files_only=True, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     variant_path = repo_path / "build" / variant | ||||
|     universal_variant_path = repo_path / "build" / universal_variant | ||||
|     if not variant_path.exists() and universal_variant_path.exists(): | ||||
|         # Fall back to universal variant. | ||||
|         variant = universal_variant | ||||
|         variant_path = universal_variant_path | ||||
| 
 | ||||
|     module_init_path = variant_path / package_name / "__init__.py" | ||||
|     if not os.path.exists(module_init_path): | ||||
|         raise FileNotFoundError( | ||||
|             f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`" | ||||
|         ) | ||||
| 
 | ||||
|     return import_from_path(package_name, variant_path / package_name / "__init__.py") | ||||
| 
 | ||||
| 
 | ||||
| def get_locked_kernel(repo_id: str, local_files_only: bool = False): | ||||
| def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | ||||
|     """Get a kernel using a lock file.""" | ||||
|     locked_sha = _get_caller_locked_kernel(repo_id) | ||||
| 
 | ||||
| @ -172,13 +221,21 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False): | ||||
|         repo_id, locked_sha, local_files_only=local_files_only | ||||
|     ) | ||||
| 
 | ||||
|     return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py") | ||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||
| 
 | ||||
| 
 | ||||
| def _get_caller_locked_kernel(repo_id: str) -> Optional[str]: | ||||
|     for dist in _get_caller_distributions(): | ||||
|         lock_json = dist.read_text("hf-kernels.lock") | ||||
|         if lock_json is not None: | ||||
|         lock_json = dist.read_text("kernels.lock") | ||||
|         if lock_json is None: | ||||
|             continue | ||||
|         locked_sha = _get_locked_kernel(repo_id, lock_json) | ||||
|         if locked_sha is not None: | ||||
|             return locked_sha | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
| def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]: | ||||
|     for kernel_lock_json in json.loads(lock_json): | ||||
|         kernel_lock = KernelLock.from_json(kernel_lock_json) | ||||
|         if kernel_lock.repo_id == repo_id: | ||||
| @ -211,9 +268,9 @@ def _get_caller_module() -> Optional[ModuleType]: | ||||
|     return first_module | ||||
| 
 | ||||
| 
 | ||||
| def validate_kernel(*, repo_path: str, variant: str, hash: str): | ||||
| def validate_kernel(*, repo_path: Path, variant: str, hash: str): | ||||
|     """Validate the given build variant of a kernel against a hasht.""" | ||||
|     variant_path = Path(repo_path) / "build" / variant | ||||
|     variant_path = repo_path / "build" / variant | ||||
| 
 | ||||
|     # Get the file paths. The first element is a byte-encoded relative path | ||||
|     # used for sorting. The second element is the absolute path. | ||||
| @ -235,8 +292,8 @@ def validate_kernel(*, repo_path: str, variant: str, hash: str): | ||||
| 
 | ||||
|     m = hashlib.sha256() | ||||
| 
 | ||||
|     for filename, full_path in sorted(files): | ||||
|         m.update(filename) | ||||
|     for filename_bytes, full_path in sorted(files): | ||||
|         m.update(filename_bytes) | ||||
| 
 | ||||
|         blob_filename = full_path.resolve().name | ||||
|         if len(blob_filename) == 40: | ||||
| @ -262,3 +319,7 @@ def git_hash_object(data: bytes, object_type: str = "blob"): | ||||
|     m.update(header) | ||||
|     m.update(data) | ||||
|     return m.digest() | ||||
| 
 | ||||
| 
 | ||||
| def package_name_from_repo_id(repo_id: str) -> str: | ||||
|     return repo_id.split("/")[-1].replace("-", "_") | ||||
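
With the reworked `load_kernel`, a locked kernel can be loaded fully offline, assuming the project was locked and the builds were pre-downloaded:

```python
from pathlib import Path

from kernels import load_kernel

# Assumes `kernels lock <project>` wrote kernels.lock and
# `kernels download <project>` pre-fetched the build variants into the cache.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("my-project") / "kernels.lock",
)
```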
| @ -52,5 +52,15 @@ | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       } | ||||
|     } | ||||
|   }, | ||||
|   { | ||||
|     "repo_id": "kernels-community/triton-scaled-mm", | ||||
|     "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f", | ||||
|     "variants": { | ||||
|       "torch-universal": { | ||||
|         "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| ] | ||||
| @ -1,2 +1,3 @@ | ||||
| [tool.kernels.dependencies] | ||||
| "kernels-community/activation" = ">=0.0.2" | ||||
| "kernels-community/triton-scaled-mm" = ">=0.0.2" | ||||
| @ -1,6 +1,7 @@ | ||||
| import pytest | ||||
| import torch | ||||
| from hf_kernels import get_kernel | ||||
|  | ||||
| from kernels import get_kernel | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @ -8,6 +9,11 @@ def kernel(): | ||||
|     return get_kernel("kernels-community/activation") | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def universal_kernel(): | ||||
|     return get_kernel("kernels-community/triton-scaled-mm") | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def device(): | ||||
|     if not torch.cuda.is_available(): | ||||
| @ -28,3 +34,17 @@ def test_gelu_fast(kernel, device): | ||||
|     ) | ||||
|  | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|  | ||||
| def test_universal_kernel(universal_kernel): | ||||
|     torch.manual_seed(0) | ||||
|     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") | ||||
|     B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda") | ||||
|     scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda") | ||||
|     scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda") | ||||
|  | ||||
|     out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16) | ||||
|     out_check = (A * scale_a) @ (B * scale_b) | ||||
|     out_check = out_check.to(torch.float16) | ||||
|  | ||||
|     torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1) | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| import pytest | ||||
| import torch | ||||
| from hf_kernels import get_kernel | ||||
|  | ||||
| from kernels import get_kernel | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
|  | ||||
| @ -1,7 +1,8 @@ | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from hf_kernels.cli import download_kernels | ||||
| from kernels import load_kernel | ||||
| from kernels.cli import download_kernels | ||||
| 
 | ||||
| 
 | ||||
| # Mock download arguments class. | ||||
| @ -11,11 +12,13 @@ class DownloadArgs: | ||||
|     project_dir: Path | ||||
| 
 | ||||
| 
 | ||||
| def test_download_hash_validation(): | ||||
|     project_dir = Path(__file__).parent / "hash_validation" | ||||
|     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) | ||||
| 
 | ||||
| 
 | ||||
| def test_download_all_hash_validation(): | ||||
|     project_dir = Path(__file__).parent / "hash_validation" | ||||
|     project_dir = Path(__file__).parent / "kernel_locking" | ||||
|     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) | ||||
| 
 | ||||
| 
 | ||||
| def test_load_locked(): | ||||
|     project_dir = Path(__file__).parent / "kernel_locking" | ||||
|     # Also validates that hashing works correctly. | ||||
|     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) | ||||
|     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") | ||||
							
								
								
									
								205 tests/test_layer.py Normal file
							| @ -0,0 +1,205 @@ | ||||
| import pytest | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| from kernels import ( | ||||
|     Device, | ||||
|     LayerRepository, | ||||
|     register_kernel_mapping, | ||||
|     use_kernel_forward_from_hub, | ||||
| ) | ||||
| from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping | ||||
|  | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         Device(type="cuda"): LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             revision="layers", | ||||
|         ) | ||||
|     }, | ||||
|     "SiluAndMulStringDevice": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             revision="layers", | ||||
|         ) | ||||
|     }, | ||||
| } | ||||
|  | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
|  | ||||
|  | ||||
| class SiluAndMul(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         # Used to check that we called hub kernel. | ||||
|         self.n_calls = 0 | ||||
|  | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         self.n_calls += 1 | ||||
|         d = input.shape[-1] // 2 | ||||
|         return F.silu(input[..., :d]) * input[..., d:] | ||||
|  | ||||
|  | ||||
| @use_kernel_forward_from_hub("SiluAndMul") | ||||
| class SiluAndMulWithKernel(SiluAndMul): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @use_kernel_forward_from_hub("SiluAndMulStringDevice") | ||||
| class SiluAndMulStringDevice(SiluAndMul): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def test_arg_kinds(): | ||||
|     @use_kernel_forward_from_hub("ArgKind") | ||||
|     class ArgKind(nn.Module): | ||||
|         def forward( | ||||
|             self, | ||||
|             arg1, | ||||
|             arg2, | ||||
|             *, | ||||
|             kwarg1, | ||||
|             kwarg2=42, | ||||
|         ): | ||||
|             return (arg1, arg2, kwarg1, kwarg2) | ||||
|  | ||||
|     arg_kind = ArgKind() | ||||
|     assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42) | ||||
|     assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) | ||||
| @pytest.mark.parametrize("device", ["cuda", "cpu"]) | ||||
| def test_hub_forward(cls, device): | ||||
|     torch.random.manual_seed(0) | ||||
|  | ||||
|     silu_and_mul = SiluAndMul() | ||||
|     X = torch.randn((32, 64), device=device) | ||||
|     Y = silu_and_mul(X) | ||||
|  | ||||
|     silu_and_mul_with_kernel = cls() | ||||
|     Y_kernel = silu_and_mul_with_kernel(X) | ||||
|  | ||||
|     torch.testing.assert_close(Y_kernel, Y) | ||||
|  | ||||
|     assert silu_and_mul.n_calls == 1 | ||||
|     if device == "cuda": | ||||
|         assert silu_and_mul_with_kernel.n_calls == 0 | ||||
|     else: | ||||
|         assert silu_and_mul_with_kernel.n_calls == 1 | ||||
|  | ||||
|  | ||||
| def test_layer_fallback_works(): | ||||
|     @use_kernel_forward_from_hub("SiluAndMulNonExisting") | ||||
|     class SiluAndMulWithKernelFallback(SiluAndMul): | ||||
|         pass | ||||
|  | ||||
|     # Check that we don't raise an exception for a non-existing kernel. | ||||
|     SiluAndMulWithKernelFallback() | ||||
|  | ||||
|  | ||||
| def test_mapping_contexts(): | ||||
|     assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"} | ||||
|  | ||||
|     extra_mapping1 = { | ||||
|         "TestKernel": { | ||||
|             Device(type="cuda"): LayerRepository( | ||||
|                 repo_id="kernels-community/activation", | ||||
|                 layer_name="SiluAndMul", | ||||
|                 revision="layers", | ||||
|             ) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     with use_kernel_mapping(extra_mapping1): | ||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "TestKernel", | ||||
|         } | ||||
|  | ||||
|         extra_mapping2 = { | ||||
|             "SiluAndMul": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                     repo_id="kernels-community/non-existing", | ||||
|                     layer_name="SiluAndMul", | ||||
|                     revision="layers", | ||||
|                 ) | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         with use_kernel_mapping(extra_mapping2): | ||||
|             assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|                 "SiluAndMul", | ||||
|                 "SiluAndMulStringDevice", | ||||
|                 "TestKernel", | ||||
|             } | ||||
|             assert ( | ||||
|                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||
|                 == "kernels-community/non-existing" | ||||
|             ) | ||||
|  | ||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "TestKernel", | ||||
|         } | ||||
|         assert ( | ||||
|             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||
|             == "kernels-community/activation" | ||||
|         ) | ||||
|  | ||||
|         with use_kernel_mapping(extra_mapping2, inherit_mapping=False): | ||||
|             assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|                 "SiluAndMul", | ||||
|             } | ||||
|             assert ( | ||||
|                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||
|                 == "kernels-community/non-existing" | ||||
|             ) | ||||
|  | ||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "TestKernel", | ||||
|         } | ||||
|         assert ( | ||||
|             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||
|             == "kernels-community/activation" | ||||
|         ) | ||||
|  | ||||
|     assert set(_KERNEL_MAPPING.get().keys()) == { | ||||
|         "SiluAndMul", | ||||
|         "SiluAndMulStringDevice", | ||||
|     } | ||||
|  | ||||
|  | ||||
| def test_validate_kernel_layer(): | ||||
|     class BadLayer(nn.Module): | ||||
|         def __init__(self, *args, **kwargs): | ||||
|             super().__init__(*args, **kwargs) | ||||
|             self.foo = 42 | ||||
|  | ||||
|     with pytest.raises(TypeError, match="not override"): | ||||
|         _validate_layer(cls=BadLayer, check_cls=SiluAndMul) | ||||
|  | ||||
|     class BadLayer2(nn.Module): | ||||
|         foo: int = 42 | ||||
|  | ||||
|     with pytest.raises(TypeError, match="not contain additional members"): | ||||
|         _validate_layer(cls=BadLayer2, check_cls=SiluAndMul) | ||||
|  | ||||
|     class BadLayer3(nn.Module): | ||||
|         def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ... | ||||
|  | ||||
|     with pytest.raises(TypeError, match="different number of arguments"): | ||||
|         _validate_layer(cls=BadLayer3, check_cls=SiluAndMul) | ||||
|  | ||||
|     class BadLayer4(nn.Module): | ||||
|         def forward(self, *, x: torch.Tensor) -> torch.Tensor: ... | ||||
|  | ||||
|     with pytest.raises(TypeError, match="different kind of arguments"): | ||||
|         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul) | ||||
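
For completeness: since `_DISABLE_KERNEL_MAPPING` is read once at import time, kernel mappings can be switched off globally via the environment before importing the library:

```python
import os

# Must be set before `kernels` is imported; the flag is read at module load.
os.environ["DISABLE_KERNEL_MAPPING"] = "1"

import kernels  # decorated layers now always use their original forward
```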