Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-31 19:54:28 +08:00)

Compare commits: fixup-arg-...v0.5.0 (30 commits)

Commits in this range:

fcb9a80ce6, c25bb32e6e, 2036892762, 0f0de049cf, 59597df03e,
5e938ede40, cf530c283a, 437f910336, 6f1a6067c8, 1d14abcef0,
6fd2112e22, 70f56ff856, 7178b0b86c, 0bbf90a564, 27d6ffcb80,
f7bd21438b, 6174febb4b, ff55bc201b, 3808108d62, c4a16ef462,
9762794dd2, b7d6867c52, fbcd0f2ebd, 5af46eca94, 747dd66876,
920590a592, 5208ac4be5, 22eaba2826, 9521ba79a0, 9861a5bdef
.github/workflows/publish.yml (vendored, new file, +120 lines)

```yaml
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI

on: push

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/<package-name> # Replace <package-name> with your PyPI project name
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

  github-release:
    name: >-
      Sign the Python 🐍 distribution 📦 with Sigstore
      and upload them to GitHub Release
    needs:
      - publish-to-pypi
    runs-on: ubuntu-latest

    permissions:
      contents: write # IMPORTANT: mandatory for making GitHub Releases
      id-token: write # IMPORTANT: mandatory for sigstore

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Sign the dists with Sigstore
        uses: sigstore/gh-action-sigstore-python@v3.0.0
        with:
          inputs: >-
            ./dist/*.tar.gz
            ./dist/*.whl
      - name: Create GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: >-
          gh release create
          "$GITHUB_REF_NAME"
          --repo "$GITHUB_REPOSITORY"
          --notes ""
      - name: Upload artifact signatures to GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        # Upload to GitHub Release using the `gh` CLI.
        # `dist/` contains the built packages, and the
        # sigstore-produced signatures and certificates.
        run: >-
          gh release upload
          "$GITHUB_REF_NAME" dist/**
          --repo "$GITHUB_REPOSITORY"

  publish-to-testpypi:
    name: Publish Python 🐍 distribution 📦 to TestPyPI
    needs:
      - build
    runs-on: ubuntu-latest

    environment:
      name: testpypi
      url: https://test.pypi.org/p/<package-name>

    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/
          skip-existing: true # Only upload when the version is unique.
```
.github/workflows/test.yml (vendored, +5 lines)

```diff
@@ -52,3 +52,8 @@ jobs:
 
       - name: Run tests
         run: uv run pytest tests
+
+      - name: Import check without torch
+        run: |
+          uv pip uninstall torch
+          python -c "import kernels"
```
LICENSE (new file, +201 lines)

```
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
```
README.md (+12 lines)

```diff
@@ -1,5 +1,16 @@
 # kernels
 
+<div align="center">
+<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
+<p align="center">
+    <a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
+
+</p>
+</div>
+<hr/>
+
 The Kernel Hub allows Python libraries and applications to load compute
 kernels directly from the [Hub](https://hf.co/). To support this kind
 of dynamic loading, Hub kernels differ from traditional Python kernel
@@ -47,6 +58,7 @@ the Hub.
 
 - [Using layers](docs/layers.md)
 - [Locking kernel versions](docs/locking.md)
+- [Environment variables](docs/env.md)
 - [Using kernels in a Docker container](docs/docker.md)
 - [Kernel requirements](docs/kernel-requirements.md)
 - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
```
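Since the README hunk above stops mid-paragraph, a minimal usage sketch of the loading API it describes may help; it is not part of the diff. `get_kernel` and the `kernels-community/activation` repo appear elsewhere in this change set, while the `gelu_fast` call and its argument order are an assumption based on upstream examples. A CUDA-enabled Torch build is required:

```python
import torch

from kernels import get_kernel

# Downloads the build variant matching the local Torch/CUDA setup,
# or reuses the local Hugging Face Hub cache.
activation = get_kernel("kernels-community/activation")

x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)  # assumed signature: (out, input)
print(y)
```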
docs/env.md (new file, +10 lines)

```markdown
# Environment variables

## `KERNELS_CACHE`

The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.

## `DISABLE_KERNEL_MAPPING`

Disables kernel mappings for [`layers`](layers.md).
```
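Both variables are read when `kernels` is first imported (`DISABLE_KERNEL_MAPPING` is read at module-import time in `kernels/layer.py`, shown later in this diff), so they have to be set beforehand. A minimal sketch, with a hypothetical cache path:

```python
import os

# Must happen before the first `import kernels`.
os.environ["KERNELS_CACHE"] = "/tmp/kernels-cache"  # hypothetical path
os.environ["DISABLE_KERNEL_MAPPING"] = "1"  # "1" disables layer mappings

import kernels  # noqa: E402
```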
docs/kernel-requirements.md

````diff
@@ -1,8 +1,11 @@
 # Kernel requirements
 
-Kernels on the Hub must fulfill the requirements outlined on this page.
+Kernels on the Hub must fulfill the requirements outlined on this page. By
+ensuring kernels are compliant, they can be used on a wide range of Linux
+systems and Torch builds.
 
 You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
-to build conforming kernels.
+to build compliant kernels.
 
 ## Directory layout
@@ -10,34 +13,27 @@ A kernel repository on the Hub must contain a `build` directory. This
 directory contains build variants of a kernel in the form of directories
 following the template
 `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
-For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
-recommended build variants are:
+For example `build/torch26-cxx98-cu118-x86_64-linux`.
 
-- `torch25-cxx11-cu118-x86_64-linux`
-- `torch25-cxx11-cu121-x86_64-linux`
-- `torch25-cxx11-cu124-x86_64-linux`
-- `torch25-cxx98-cu118-x86_64-linux`
-- `torch25-cxx98-cu121-x86_64-linux`
-- `torch25-cxx98-cu124-x86_64-linux`
-- `torch26-cxx11-cu118-x86_64-linux`
-- `torch26-cxx11-cu124-x86_64-linux`
-- `torch26-cxx11-cu126-x86_64-linux`
-- `torch26-cxx98-cu118-x86_64-linux`
-- `torch26-cxx98-cu124-x86_64-linux`
-- `torch26-cxx98-cu126-x86_64-linux`
-
-This list will be updated as new PyTorch versions are released. Kernels
-that are in pure Python (e.g. Triton kernels) only need to provide a
-single build variant:
-
-- `torch-universal`
-
-Each variant directory should contain a single directory with the same name
+Each variant directory must contain a single directory with the same name
 as the repository (replacing `-` by `_`). For instance, kernels in the
 `kernels-community/activation` repository have directories like
 `build/<variant>/activation`. This directory
 must be a Python package with an `__init__.py` file.
 
+## Build variants
+
+A kernel can be compliant for a specific compute framework (e.g. CUDA) or
+architecture (e.g. x86_64). For compliance with a compute framework and
+architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
+must be available for that combination.
+
 ## Versioning
 
 Kernels are versioned on the Hub using Git tags. Version tags must be of
 the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
 to resolve the version constraints.
 
 ## Native Python module
 
 Kernels will typically contain a native Python module with precompiled
@@ -46,16 +42,31 @@ requirements:
 
 - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
   for compatibility with Python 3.9 and later.
-- Compatible with glibc 2.27 or later. This means that no symbols
-  from later versions may be used. To achieve this, the module should
-  be built against this glibc version. **Warning:** libgcc must also be
-  built against glibc 2.27 to avoid leaking symbols.
 - No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
   must be static.
-- No dynamic library dependencies outside Torch or CUDA libraries
-  installed as dependencies of Torch.
-
-(These requirements will be updated as new PyTorch versions are released.)
+- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
+  This means that the extension **must not** use symbol versions higher than:
+
+  - GLIBC 2.28
+  - GLIBCXX 3.4.24
+  - CXXABI 1.3.11
+  - GCC 7.0.0
+
+  These requirements can be checked with the ABI checker (see below).
+
+- No dynamic library dependencies outside:
+
+  - Torch;
+  - CUDA/ROCm libraries installed as dependencies of Torch.
+
+The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
+[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
+
+```bash
+$ cargo install kernel-abi-check
+$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
+🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
+✅ No compatibility issues found
+```
@@ -98,10 +109,20 @@ requirements:
 - The `forward` method has a signature that is compatible with the
   `forward` method that it is extending.
 
+There are two exceptions to the _no class variables rule_:
+
+1. The `has_backward` variable can be used to indicate whether the layer has
+   a backward pass implemented (`True` when absent).
+2. The `can_torch_compile` variable can be used to indicate whether the layer
+   supports `torch.compile` (`False` when absent).
+
 This is an example of a pure layer:
 
 ```python
 class SiluAndMul(nn.Module):
+    # This layer does not implement backward.
+    has_backward: bool = False
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
````
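The `has_backward` example above comes from the diff itself. As a complementary sketch (not part of the diff), a pure layer could also opt into compilation through the second marker variable:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SiluAndMul(nn.Module):
    # Marker variables described above: no backward pass is provided,
    # but the layer is safe to run under torch.compile.
    has_backward: bool = False
    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]
```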
docs/locking.md

````diff
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 "kernels-community/activation" = ">=0.0.1"
 ```
 
-Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
+Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
 the locked revisions. The locked revision will be used when loading a kernel with
 `get_locked_kernel`:
 
@@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
 
 ## Pre-downloading locked kernels
 
-Locked kernels can be pre-downloaded by running `kernel download .` in your
+Locked kernels can be pre-downloaded by running `kernels download .` in your
 project directory. This will download the kernels to your local Hugging Face
 Hub cache.
````
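The code block that follows "`get_locked_kernel`:" in the full document falls outside this hunk. A minimal sketch of that usage, assuming the signature mirrors `get_kernel` from `kernels/utils.py` later in this diff:

```python
from kernels import get_locked_kernel

# Loads the revision pinned in kernels.lock instead of `main`.
activation = get_locked_kernel("kernels-community/activation")
```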
pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.3.0"
+version = "0.5.0"
 description = "Download compute kernels"
 authors = [
   { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -8,13 +8,13 @@ authors = [
   { name = "David Holtz", email = "david@huggingface.co" },
   { name = "Nicolas Patry", email = "nicolas@huggingface.co" },
 ]
+license = { text = "Apache-2.0" }
 readme = "README.md"
 requires-python = ">= 3.9"
 dependencies = [
-  "huggingface-hub>=0.26.3",
-  "packaging>=24.2",
-  "tomli>=2.0.1; python_version<'3.11'",
-  "torch>=2.5",
+  "huggingface_hub>=0.26.0,<1.0",
+  "packaging>=20.0",
+  "tomli>=2.0; python_version<'3.11'",
 ]
 
 [build-system]
@@ -27,8 +27,12 @@ dev = [
   "pytest >=8",
   # Whatever version is compatible with pytest.
   "pytest-benchmark",
+  "torch >=2.5",
 ]
 
+[project.optional-dependencies]
+torch = ["torch"]
+
 [project.scripts]
 kernels = "kernels.cli:main"
```
kernels/__init__.py

```diff
@@ -2,11 +2,14 @@ from kernels.layer import (
     Device,
     LayerRepository,
     register_kernel_mapping,
+    replace_kernel_forward_from_hub,
     use_kernel_forward_from_hub,
     use_kernel_mapping,
 )
 from kernels.utils import (
     get_kernel,
     get_locked_kernel,
+    has_kernel,
+    install_kernel,
     load_kernel,
 )
@@ -14,10 +17,13 @@ from kernels.utils import (
 __all__ = [
     "get_kernel",
     "get_locked_kernel",
+    "has_kernel",
     "load_kernel",
+    "install_kernel",
     "use_kernel_forward_from_hub",
     "use_kernel_mapping",
     "register_kernel_mapping",
+    "replace_kernel_forward_from_hub",
     "LayerRepository",
     "Device",
 ]
```
kernels/layer.py

```diff
@@ -1,14 +1,18 @@
 import inspect
+import os
 import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, Dict, Union
+from typing import TYPE_CHECKING, Dict, Union
 
 from .utils import get_kernel
 
 if TYPE_CHECKING:
     from torch import nn
 
+_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
+
 
 @dataclass(frozen=True)
 class Device:
@@ -54,11 +58,26 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
 )
 
 
-def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
+def use_kernel_mapping(
+    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
+    *,
+    inherit_mapping: bool = True,
+):
+    """
+    Context manager that sets a mapping for the duration of the context.
+
+    When `inherit_mapping` is set to `True` the current mapping will be
+    extended by `mapping` inside the context. If it is `False`, only
+    `mapping` is used inside the context.
+    """
+
     class ContextManager:
         def __enter__(self):
             # Mappings always stack on previous mappings.
-            self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
+            if inherit_mapping:
+                self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
+            else:
+                self.token = _KERNEL_MAPPING.set({})
             register_kernel_mapping(mapping)
 
         def __exit__(self, exc_type, exc_value, traceback):
```
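The layer tests further down exercise both modes of this context manager; a condensed sketch (repo and layer names are taken from the test fixtures in this diff):

```python
from kernels import Device, LayerRepository, use_kernel_mapping

mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    }
}

# Default behavior: extend whatever mapping is currently active.
with use_kernel_mapping(mapping):
    ...

# New in this change: replace the active mapping entirely.
with use_kernel_mapping(mapping, inherit_mapping=False):
    ...
```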
```diff
@@ -112,18 +131,30 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
 
     fallback_forward = cls.forward
 
-    cached_forward: Dict[LayerRepository, Callable] = {}
+    cached_layer: Dict[LayerRepository, nn.Module] = {}
 
+    def forward(self, x, *args, **kwargs):
+        if _DISABLE_KERNEL_MAPPING:
+            return fallback_forward(self, x, *args, **kwargs)
+
+        needs_backward = self.training
+        is_compiling = _is_torchdynamo_compiling()
+
-    def forward(self, x, **args):
         kernel = _KERNEL_MAPPING.get().get(layer_name)
         if kernel is None:
+            warnings.warn(
+                "\n"
+                f"No kernel mapping found for layer `{layer_name}`. "
+                f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
+                f"you want to use to the mapping. Defaulting to original forward implementation."
+            )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            return fallback_forward(self, x, **args)
+            return fallback_forward(self, x, *args, **kwargs)
 
         device = getattr(x, "device", None)
         if device is None:
-            return fallback_forward(self, x, **args)
+            return fallback_forward(self, x, *args, **kwargs)
 
         repo = kernel.get(Device(type=device.type))
         if repo is None:
@@ -131,12 +162,21 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
             raise ValueError(
                 f"No layer mapping for `{layer_name}` with device type `{device.type}`"
             )
-            return fallback_forward(self, x, **args)
+            return fallback_forward(self, x, *args, **kwargs)
 
         # Short-circuit if we already loaded the layer.
-        layer_forward = cached_forward.get(repo, None)
-        if layer_forward is not None:
-            return layer_forward(self, x, **args)
+        layer = cached_layer.get(repo, None)
+        if layer is not None:
+            # Switch to the fallback when the layer does not support
+            # compilation (when compiling) or backward (when training).
+            needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
+            needs_fallback |= is_compiling and not getattr(
+                layer, "can_torch_compile", False
+            )
+            if needs_fallback:
+                return fallback_forward(self, x, *args, **kwargs)
+            return layer.forward(self, x, *args, **kwargs)
 
         layer = _get_kernel_layer(
             repo_id=repo.repo_id,
@@ -152,10 +192,18 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
         finally:
             cls.forward = orig_forward
 
-        layer_forward = layer.forward
-        cached_forward[repo] = layer_forward
+        cached_layer[repo] = layer
 
-        return layer_forward(self, x, **args)
+        # Switch to the fallback when the layer does not support
+        # compilation (when compiling) or backward (when training).
+        needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
+        needs_fallback |= is_compiling and not getattr(
+            layer, "can_torch_compile", False
+        )
+        if needs_fallback:
+            return fallback_forward(self, x, *args, **kwargs)
+
+        return layer.forward(self, x, *args, **kwargs)
 
     cls.forward = forward
 
```
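Seen from the consumer side, the replaced `forward` above is what makes the decorator work; a minimal sketch assembled from names in this diff (requires CUDA and Hub access; the fallback logic above transparently reverts to the original `forward` while training or compiling):

```python
import torch
import torch.nn.functional as F
from torch import nn

from kernels import (
    Device,
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


register_kernel_mapping(
    {
        "SiluAndMul": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }
)

layer = SiluAndMul()
y = layer(torch.randn(8, 64, device="cuda"))  # dispatches to the Hub kernel
```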
```diff
@@ -212,7 +260,9 @@ def _validate_layer(*, check_cls, cls):
     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
     cls_members = {name for name, _ in inspect.getmembers(cls)}
-    if cls_members - torch_module_members != set():
+    difference = cls_members - torch_module_members
+    # Allow only the `has_backward` and `can_torch_compile` marker variables.
+    if not difference <= {"can_torch_compile", "has_backward"}:
         raise TypeError("Layer must not contain additional members.")
 
     # Check whether the forward signatures are similar.
@@ -229,3 +279,19 @@ def _validate_layer(*, check_cls, cls):
             raise TypeError(
                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
             )
+
+
+def _is_torchdynamo_compiling():
+    # Importing torch._dynamo causes issues with the PyTorch profiler
+    # (https://github.com/pytorch/pytorch/issues/130622), hence we rather rely
+    # on `torch.compiler.is_compiling()` when possible (torch>=2.3).
+    try:
+        import torch
+
+        return torch.compiler.is_compiling()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo  # noqa: F401
+
+            return dynamo.is_compiling()
+        except Exception:
+            return False
```
kernels/utils.py

```diff
@@ -4,6 +4,7 @@ import importlib
 import importlib.metadata
 import inspect
 import json
+import logging
 import os
 import platform
 import sys
@@ -12,29 +13,45 @@ from pathlib import Path
 from types import ModuleType
 from typing import Dict, List, Optional, Tuple
 
-from huggingface_hub import snapshot_download
+from huggingface_hub import file_exists, snapshot_download
 from packaging.version import parse
 
 from kernels.lockfile import KernelLock, VariantLock
 
-CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
+
+def _get_cache_dir() -> Optional[str]:
+    """Returns the kernels cache directory."""
+    cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
+    if cache_dir is not None:
+        logging.warning(
+            "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
+        )
+        return cache_dir
+
+    return os.environ.get("KERNELS_CACHE", None)
+
+
+CACHE_DIR: Optional[str] = _get_cache_dir()
 
 
 def build_variant() -> str:
     import torch
 
-    if torch.version.cuda is None:
-        raise AssertionError(
-            "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
-        )
+    if torch.version.cuda is not None:
+        cuda_version = parse(torch.version.cuda)
+        compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
+    elif torch.version.hip is not None:
+        rocm_version = parse(torch.version.hip.split("-")[0])
+        compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
+    else:
+        raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
 
     torch_version = parse(torch.__version__)
-    cuda_version = parse(torch.version.cuda)
     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
     cpu = platform.machine()
     os = platform.system().lower()
 
-    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
+    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
 
 
 def universal_build_variant() -> str:
```
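A sketch of the strings these functions produce (the concrete values are illustrative, and the helpers are assumed to remain importable from `kernels.utils`):

```python
from kernels.utils import build_variant, universal_build_variant

# CUDA build of Torch 2.6:  e.g. "torch26-cxx11-cu124-x86_64-linux"
# ROCm build of Torch 2.6:  e.g. "torch26-cxx11-rocm62-x86_64-linux"
print(build_variant())

# Pure-Python (e.g. Triton) kernels use a single universal variant.
print(universal_build_variant())  # "torch-universal"
```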
```diff
@@ -144,6 +161,29 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
     return import_from_path(package_name, package_path / package_name / "__init__.py")
 
 
+def has_kernel(repo_id: str, revision: str = "main") -> bool:
+    """
+    Check whether a kernel build exists for the current environment
+    (Torch version and compute framework).
+    """
+    package_name = package_name_from_repo_id(repo_id)
+    variant = build_variant()
+    universal_variant = universal_build_variant()
+
+    if file_exists(
+        repo_id,
+        revision=revision,
+        filename=f"build/{universal_variant}/{package_name}/__init__.py",
+    ):
+        return True
+
+    return file_exists(
+        repo_id,
+        revision=revision,
+        filename=f"build/{variant}/{package_name}/__init__.py",
+    )
+
+
 def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
     """
     Get a pre-downloaded, locked kernel.
```
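A usage sketch for the new helper (the repos come from the test parametrization later in this diff):

```python
from kernels import has_kernel

# True: the repo ships a build for the current Torch/compute framework,
# or a universal (pure Python) build.
print(has_kernel("kernels-community/activation"))

# False: a model repo without kernel builds.
print(has_kernel("google-bert/bert-base-uncased"))
```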
Kernel tests

```diff
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from kernels import get_kernel
+from kernels import get_kernel, has_kernel
 
 
 @pytest.fixture
@@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
     assert torch.allclose(y, expected)
 
 
+@pytest.mark.parametrize(
+    "kernel_exists",
+    [
+        ("kernels-community/activation", "main", True),
+        ("kernels-community/triton-layer-norm", "main", True),
+        # Repo only contains Torch 2.4 kernels (and we don't
+        # support/test against this version).
+        ("kernels-test/only-torch-2.4", "main", False),
+        ("google-bert/bert-base-uncased", "87565a309", False),
+    ],
+)
+def test_has_kernel(kernel_exists):
+    repo_id, revision, kernel = kernel_exists
+    assert has_kernel(repo_id, revision=revision) == kernel
+
+
 def test_universal_kernel(universal_kernel):
     torch.manual_seed(0)
     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
```
Layer tests

```diff
@@ -19,6 +19,12 @@ kernel_layer_mapping = {
             revision="layers",
         )
     },
+    "SiluAndMulNoCompile": {
+        "cuda": LayerRepository(
+            repo_id="kernels-test/op-without-fake-test",
+            layer_name="SiluAndMul",
+        )
+    },
     "SiluAndMulStringDevice": {
         "cuda": LayerRepository(
             repo_id="kernels-community/activation",
@@ -43,6 +49,11 @@ class SiluAndMul(nn.Module):
         return F.silu(input[..., :d]) * input[..., d:]
 
 
+@use_kernel_forward_from_hub("SiluAndMulNoCompile")
+class SiluAndMulNoCompileKernel(SiluAndMul):
+    pass
+
+
 @use_kernel_forward_from_hub("SiluAndMul")
 class SiluAndMulWithKernel(SiluAndMul):
     pass
@@ -53,6 +64,24 @@ class SiluAndMulStringDevice(SiluAndMul):
     pass
 
 
+def test_arg_kinds():
+    @use_kernel_forward_from_hub("ArgKind")
+    class ArgKind(nn.Module):
+        def forward(
+            self,
+            arg1,
+            arg2,
+            *,
+            kwarg1,
+            kwarg2=42,
+        ):
+            return (arg1, arg2, kwarg1, kwarg2)
+
+    arg_kind = ArgKind()
+    assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
+    assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
+
+
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_hub_forward(cls, device):
@@ -83,8 +112,29 @@ def test_layer_fallback_works():
     SiluAndMulWithKernelFallback()
 
 
+@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+def test_torch_compile_layer(cls, device):
+    silu_and_mul = SiluAndMul()
+
+    X = torch.randn((32, 64), dtype=torch.float32, device=device)
+    Y = silu_and_mul(X)
+
+    silu_and_mul_with_kernel = cls()
+    silu_and_mul_with_kernel.eval()
+    silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel)
+
+    Y_compiled = silu_and_mul_compiled(X)
+
+    torch.testing.assert_close(Y_compiled, Y)
+
+
 def test_mapping_contexts():
-    assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
+    assert set(_KERNEL_MAPPING.get().keys()) == {
+        "SiluAndMul",
+        "SiluAndMulStringDevice",
+        "SiluAndMulNoCompile",
+    }
 
     extra_mapping1 = {
         "TestKernel": {
@@ -100,6 +150,7 @@ def test_mapping_contexts():
         assert set(_KERNEL_MAPPING.get().keys()) == {
             "SiluAndMul",
             "SiluAndMulStringDevice",
+            "SiluAndMulNoCompile",
             "TestKernel",
         }
 
@@ -117,6 +168,7 @@ def test_mapping_contexts():
             assert set(_KERNEL_MAPPING.get().keys()) == {
                 "SiluAndMul",
                 "SiluAndMulStringDevice",
+                "SiluAndMulNoCompile",
                 "TestKernel",
             }
             assert (
@@ -127,6 +179,27 @@ def test_mapping_contexts():
         assert set(_KERNEL_MAPPING.get().keys()) == {
             "SiluAndMul",
             "SiluAndMulStringDevice",
+            "SiluAndMulNoCompile",
             "TestKernel",
         }
         assert (
             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
             == "kernels-community/activation"
         )
 
+        with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
+            assert set(_KERNEL_MAPPING.get().keys()) == {
+                "SiluAndMul",
+            }
+            assert (
+                _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+                == "kernels-community/non-existing"
+            )
+
+        assert set(_KERNEL_MAPPING.get().keys()) == {
+            "SiluAndMul",
+            "SiluAndMulStringDevice",
+            "SiluAndMulNoCompile",
+            "TestKernel",
+        }
+        assert (
@@ -137,6 +210,7 @@ def test_mapping_contexts():
     assert set(_KERNEL_MAPPING.get().keys()) == {
         "SiluAndMul",
         "SiluAndMulStringDevice",
+        "SiluAndMulNoCompile",
     }
 
 
@@ -166,3 +240,75 @@ def test_validate_kernel_layer():
 
     with pytest.raises(TypeError, match="different kind of arguments"):
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+
+
+def test_fallback_used_when_training():
+    @use_kernel_forward_from_hub("Linear")
+    class TorchLinear(nn.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Used to check that we called the hub kernel.
+            self.n_calls = 0
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            self.n_calls += 1
+            return super().forward(input)
+
+    linear = TorchLinear(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearImplicitBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearNoBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 1
```